shift_wakers_multiproc
- crispy.scms.shift_wakers_multiproc(G, X, h, d, c, mask, ncpu)
Shift walkers towards density ridges using the subspace-constrained mean shift (SCMS) algorithm, parallelized with multiprocessing.
This function parallelizes the SCMS walker-shifting step to reduce run time on large datasets: it divides the walkers into chunks and processes the chunks concurrently on multiple CPUs.
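For readers unfamiliar with SCMS, the sketch below shows what a single serial shift might look like in NumPy. It assumes an unweighted Gaussian kernel and the common Hessian-based formulation: each walker moves by its mean-shift vector projected onto the eigenvectors of the kernel density Hessian with the D - d smallest eigenvalues, i.e. the directions normal to a d-dimensional ridge. `scms_step` is a hypothetical name; it ignores the c and mask inputs and is not crispy's implementation.

```python
import numpy as np


def scms_step(G, X, h, d):
    """One serial SCMS step (hypothetical helper, not crispy's code).

    G: walkers, shape (m, D, 1); X: data points, shape (n, D, 1).
    Returns shifted walkers, shape (m, D, 1), and each walker's displacement, shape (m,).
    """
    g = G[:, :, 0]                                   # (m, D)
    x = X[:, :, 0]                                   # (n, D)
    D = g.shape[1]
    diff = x[None, :, :] - g[:, None, :]             # (m, n, D): x_j - g_i
    # Assumption: unweighted Gaussian kernel values between every walker and data point.
    w = np.exp(-0.5 * np.sum(diff**2, axis=-1) / h**2)   # (m, n)
    wsum = w.sum(axis=1)                             # (m,)

    # Mean-shift vector: kernel-weighted mean of the data minus the walker position.
    ms = np.einsum("mn,mnd->md", w, diff) / wsum[:, None]

    # KDE Hessian up to a positive factor: sum_j w_ij [(x_j - g_i)(x_j - g_i)^T / h^2 - I].
    H = np.einsum("mn,mni,mnj->mij", w, diff, diff) / h**2 - wsum[:, None, None] * np.eye(D)

    # eigh sorts eigenvalues in ascending order, so the first D - d eigenvectors
    # span the directions normal to a d-dimensional ridge.
    _, vecs = np.linalg.eigh(H)
    V = vecs[:, :, : D - d]                          # (m, D, D - d)

    # Move each walker by its mean-shift vector projected onto that normal space.
    step = np.einsum("mij,mkj,mk->mi", V, V, ms)     # V V^T ms
    return (g + step)[:, :, None], np.linalg.norm(step, axis=1)
```

The returned displacement plays the same role as the documented error output.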
- Parameters:
G (ndarray) – Initial coordinates of the walkers, shape (m, D, 1), where m is the number of walkers and D is the dimensionality.
X (ndarray) – Coordinates of the data points, shape (n, D, 1), where n is the number of data points.
h (float) – Smoothing bandwidth for the Gaussian kernel.
d (int) – Target dimensionality of the ridge subspace.
c (ndarray) – Weighted Gaussian values computed for the data points and walkers, shape (m, n).
mask (ndarray of bool) – Boolean mask indicating which data points are valid (True) for each walker, shape (m, n).
ncpu (int or None) – Number of CPUs to use for parallel processing. If None, defaults to the number of available CPUs.
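The exact construction of c and mask is not specified here. As a hedged illustration only, the hypothetical helper below builds arrays of the documented (m, n) shapes under two assumptions: c[i, j] is the Gaussian kernel value between walker i and data point j, optionally multiplied by a per-point weight, and mask keeps only data points within a chosen multiple of h of each walker.

```python
import numpy as np


def gaussian_weights_and_mask(G, X, h, weights=None, cutoff=None):
    """Hypothetical helper: build (m, n) kernel values and a validity mask.

    G: walkers (m, D, 1); X: data (n, D, 1); weights: optional per-point weights (n,);
    cutoff: optional radius, in units of h, beyond which data points are masked out.
    """
    diff = X[None, :, :, 0] - G[:, None, :, 0]      # (m, n, D)
    sq_dist = np.sum(diff**2, axis=-1)              # (m, n) squared distances
    c = np.exp(-0.5 * sq_dist / h**2)               # assumed Gaussian kernel values
    if weights is not None:
        c = c * np.asarray(weights)[None, :]        # assumed per-data-point weighting
    if cutoff is None:
        mask = np.ones(c.shape, dtype=bool)         # every data point is valid
    else:
        mask = sq_dist <= (cutoff * h) ** 2         # keep points within cutoff * h
    return c, mask
```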
- Returns:
G_updated (ndarray) – Updated coordinates of the walkers after the SCMS shift, shape (m, D, 1).
error (ndarray) – Convergence error for each walker, shape (m,): the displacement of each walker in this shift step, used to decide when the walkers have converged.
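Because error is a per-walker displacement, a caller would typically apply the shift repeatedly until every displacement falls below a tolerance. The driver below is a minimal sketch of that loop; `run_until_converged`, `tol`, and `max_iter` are hypothetical names. It also stops shifting walkers that have already converged, a common optimization that may or may not match crispy's own driver.

```python
import numpy as np


def run_until_converged(step, G, tol=1e-5, max_iter=1000):
    """Apply `step` repeatedly until every walker's displacement drops below `tol`.

    `step(G_active) -> (G_new, error)` is any shift function with the documented
    contract: walkers of shape (k, D, 1) in, shifted walkers and a (k,) error out.
    Returns the final walkers and a boolean array flagging walkers that never converged.
    """
    G = np.array(G, dtype=float)          # work on a copy of the input walkers
    active = np.ones(G.shape[0], dtype=bool)
    for _ in range(max_iter):
        if not active.any():
            break
        idx = np.flatnonzero(active)      # indices of walkers still moving
        G_new, error = step(G[idx])
        G[idx] = G_new
        # Freeze walkers whose latest displacement fell below the tolerance.
        active[idx[error < tol]] = False
    return G, active
```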
Notes
The walkers (G) are divided into chunks, and each chunk is processed independently on a separate CPU.
Internally, this function calls shift_walkers on each chunk, so every walker is updated by the same SCMS step as in the serial shift_walkers function.
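Multiprocessing pays off mainly when the number of walkers or data points is large, because the cost of starting worker processes and transferring the input arrays to them is then small relative to the computation.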
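The chunking described above can be sketched with the standard library's concurrent.futures. This is an illustrative stand-in, not crispy's code: `shift_in_chunks` and its `step` callback are hypothetical, and the executor class is a parameter so a thread pool can be swapped in for quick testing.

```python
import concurrent.futures as cf
import os

import numpy as np


def shift_in_chunks(step, G, ncpu=None, executor_cls=cf.ProcessPoolExecutor):
    """Shift walkers chunk by chunk in parallel workers and reassemble the results.

    `step(G_chunk) -> (G_chunk_new, error_chunk)` stands in for a serial shift
    function such as shift_walkers. With processes it must be picklable, i.e. a
    module-level function or a functools.partial of one.
    """
    if ncpu is None:
        ncpu = os.cpu_count() or 1
    # array_split allows m not divisible by ncpu; drop empty chunks when m < ncpu.
    chunks = [ch for ch in np.array_split(G, ncpu, axis=0) if len(ch)]
    if not chunks:
        return np.array(G, dtype=float), np.empty(0)
    with executor_cls(max_workers=len(chunks)) as pool:
        results = list(pool.map(step, chunks))   # map returns results in chunk order
    G_new = np.concatenate([r[0] for r in results], axis=0)
    error = np.concatenate([r[1] for r in results], axis=0)
    return G_new, error
```

With the default process pool, make the call under `if __name__ == "__main__":` on platforms that start workers with spawn (Windows, macOS). Fixed arguments bound into `step`, such as the data array X, are pickled and sent with every chunk, which is part of the overhead that makes parallelism worthwhile only for large inputs.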
Examples
Perform a parallel SCMS shift for walkers:
>>> import numpy as np
>>> from crispy import scms
>>> data = np.random.random((100, 3, 1))  # 3D data points
>>> walkers = np.random.random((10, 3, 1))  # Initial walker positions
>>> c = np.random.random((10, 100))  # Weighted Gaussian values
>>> mask = np.random.choice([True, False], size=(10, 100))  # Boolean mask
>>> h = 1.0
>>> d = 1
>>> ncpu = 4  # Use 4 CPUs
>>> G_updated, error = scms.shift_wakers_multiproc(walkers, data, h, d, c, mask, ncpu)