"""
Functions for gridding and processing CRISPy results on reference images.
This module provides tools to map CRISPy ridge detection results onto a reference image grid, clean and refine
the skeleton structures, and handle input and output operations. Key functionalities include skeleton gridding,
labeling, pruning, and advanced 2D/3D structure manipulation for astrophysical data.
"""
__author__ = 'mcychen'
import numpy as np
from astropy.io import fits
from skimage import morphology
from astropy.utils.console import ProgressBar
from .pruning.pruning import endPoints
from .pruning.structures import get_footprints
# ======================================================================================================================#
# higher-level wrapper
[docs]
def grid_skel(readFile, imgFile, writeFile, **kwargs):
    """
    Grid raw CRISPy results onto a reference image and write them to a FITS file.

    The ridge table in `readFile` is read, gridded and cleaned with `clean_grid` on
    the pixel grid of the image stored in `imgFile`, and the resulting skeleton is
    written to `writeFile` together with the header of the reference image.

    Parameters
    ----------
    readFile : str
        Path to the .txt file containing the ungridded (raw) CRISPy results.
    imgFile : str
        Path to the .fits file of the reference grid image.
    writeFile : str
        Path to the .fits file where the gridded CRISPy results will be saved.
    **kwargs : dict, optional
        Keyword arguments forwarded to `clean_grid`. Values given here override
        the following defaults:

        coord_in_xfirst : bool, default=True
            Whether the input coordinates have the x-dimension first.
        start_index : int, default=0
            Index of the first pixel in the coordinate convention of the table.
        min_length : int, default=6
            Minimum length of structures to retain.
        method : str, default="robust"
            Endpoint-cleaning method, either "robust" or "fast".

    Returns
    -------
    None
        The gridded skeleton is saved directly to `writeFile`.

    Examples
    --------
    >>> from crispy import grid_ridge
    >>> grid_ridge.grid_skel("results.txt", "reference_image.fits", "gridded_results.fits",
    ...                      min_length=10, method="fast")
    """
    # start from the wrapper's defaults and let the caller's keywords win
    options = {"coord_in_xfirst": True, "start_index": 0, "min_length": 6, "method": "robust"}
    options.update(kwargs)

    ridge_coords = read_table(readFile)
    ref_image, ref_header = fits.getdata(imgFile, header=True)

    gridded = clean_grid(ridge_coords, ref_image, **options)
    write_skel(writeFile, gridded, header=ref_header)
# ======================================================================================================================#
# input and output
[docs]
def read_table(fname, useDict=False):
    """
    Read filament skeleton coordinates from a text table.

    The table is loaded with ``numpy.loadtxt(..., unpack=True)``, so every column of
    the file becomes one row of the returned array (one row per coordinate axis).

    Parameters
    ----------
    fname : str or file-like
        Anything accepted by `numpy.loadtxt` (typically the path to the text file
        holding the skeleton coordinates).
    useDict : bool, optional, default=False
        If True, return a dictionary keyed by ``'xind'``, ``'yind'`` (and ``'zind'``
        when the table has three columns). Otherwise return the NumPy array.

    Returns
    -------
    ndarray or dict
        - If `useDict` is False: array of shape (D, n), with one row per table column.
        - If `useDict` is True: dictionary mapping ``'xind'``, ``'yind'`` (and
          ``'zind'`` for three-column tables) to the corresponding coordinate arrays.

    Notes
    -----
    - Tables whose column count is not three are labeled with ``'xind'`` and
      ``'yind'`` only; additional columns are not included in the dictionary.

    Examples
    --------
    >>> from crispy import grid_ridge
    >>> data = grid_ridge.read_table("skeleton_data.txt")
    >>> data_dict = grid_ridge.read_table("skeleton_data.txt", useDict=True)
    >>> print(data_dict.keys())
    """
    values = np.loadtxt(fname, unpack=True)
    if not useDict:
        return values

    # three columns -> 3D table; anything else is labeled as 2D (zip stops at the shorter input)
    if np.shape(values)[0] == 3:
        keys = ('xind', 'yind', 'zind')
    else:
        keys = ('xind', 'yind')
    return dict(zip(keys, values))
[docs]
def write_skel(filename, data, header):
    """
    Save a gridded skeleton image as a FITS file.

    The data are cast to unsigned 8-bit integers before writing, and an existing
    file with the same name is overwritten.

    Parameters
    ----------
    filename : str
        The name of the FITS file to write.
    data : ndarray
        The image data to be written to the FITS file.
    header : fits.Header
        The FITS header to include in the file.
    """
    skeleton_u8 = data.astype('uint8')
    fits.writeto(filename=filename, data=skeleton_u8, header=header, overwrite=True)
# ======================================================================================================================#
# label structures (this is the only place where sklearn is needed in the package)
[docs]
def label_ridge(coord, eps=1.0, min_samples=5):
    """
    Assign a cluster label to every ridge point with DBSCAN.

    Points that lie close to each other end up in the same cluster, so that
    spatially unconnected ridges receive different labels.

    Parameters
    ----------
    coord : ndarray
        Coordinates of the ridge points, shape (n, D), where `n` is the number
        of points and `D` is the dimensionality.
    eps : float, optional, default=1.0
        Maximum distance between two points to be considered part of the same ridge.
    min_samples : int, optional, default=5
        Minimum number of points required to form a cluster.

    Returns
    -------
    labels : ndarray
        Cluster labels for each point, shape (n,). Points labeled `-1` are considered noise.

    Notes
    -----
    - scikit-learn is imported inside the function, so it is only required when
      this function is actually called.

    Examples
    --------
    >>> import numpy as np
    >>> from crispy import grid_ridge
    >>> coords = np.array([[0, 0], [1, 1], [2, 2], [10, 10], [11, 11]])
    >>> labels = grid_ridge.label_ridge(coords, eps=2.0, min_samples=2)
    >>> print(labels)
    [0 0 0 1 1]
    """
    # local import keeps scikit-learn an optional dependency of the module
    from sklearn.cluster import DBSCAN

    clusterer = DBSCAN(eps=eps, min_samples=min_samples)
    clusterer.fit(coord)
    return clusterer.labels_
[docs]
def clean_grid(coord, refdata, coord_in_xfirst=False, start_index=1, min_length=6, method="robust"):
    """
    Process and grid CRISPy coordinates onto a reference image, labeling and cleaning skeleton structures.

    The ridge points are first split into distinct structures with DBSCAN. Each structure
    is gridded and skeletonized on its own, its endpoints are trimmed where they would
    touch another structure on the grid, and structures that are too short are discarded.

    Parameters
    ----------
    coord : ndarray
        Coordinates of the ridge points, shape (D, n), where `D` is the dimensionality
        (2 or 3) and `n` is the number of points, i.e. the layout returned by `read_table`.
    refdata : ndarray
        Reference image array defining the grid dimensions.
    coord_in_xfirst : bool, optional, default=False
        If True, assumes the input coordinates are ordered with x as the first axis.
        If False, assumes z is the first axis for 3D data or y for 2D data.
    start_index : int, optional, default=1
        Index of the first pixel in the coordinate convention of `coord`.
    min_length : int, optional, default=6
        A structure is kept only if its skeleton has more than `min_length` pixels.
    method : {"robust", "fast"}, optional, default="robust"
        Method for cleaning skeleton endpoints:
        - "robust": Ensures diagonal connections are handled but is computationally intensive.
        - "fast": Faster but may miss diagonally connected structures.
        Any other value leaves the endpoints untouched.

    Returns
    -------
    skel_cube : ndarray
        Boolean array with the same shape as `refdata`, where gridded skeleton structures
        are set to `True`.

    Notes
    -----
    - Points labeled as noise by DBSCAN (label `-1`) are never gridded.
    - When only one structure is present there is nothing to disconnect it from, so no
      endpoint is trimmed.

    Examples
    --------
    >>> import numpy as np
    >>> from crispy import grid_ridge
    >>> y = x = np.linspace(2, 12, 20)
    >>> coords = np.array([np.full(20, 5.0), y, x])  # shape (3, n), ordered z, y, x
    >>> ref_image = np.zeros((20, 20, 20))
    >>> skel_cube = grid_ridge.clean_grid(coords, ref_image, start_index=0, min_length=5, method="fast")
    >>> print(skel_cube.shape)
    (20, 20, 20)
    """
    # DBSCAN wants one point per row, while the ridge table stores one axis per row
    points = coord.T
    labels = label_ridge(points, eps=1.0, min_samples=3)
    # cluster labels run from 0 to max(labels); -1 is reserved for noise
    n_skel = int(np.max(labels)) + 1 if np.size(labels) else 0

    skel_cube = np.zeros(refdata.shape, dtype=bool)
    grid_kwargs = dict(coord_in_xfirst=coord_in_xfirst, start_index=start_index)

    if method == "robust":
        # neighbourhood around an end point within which another structure counts as connected
        footprint = get_footprints(ndim=refdata.ndim, width=5)

    print("---gridding {} distinct skeletons---".format(n_skel))
    for lb in ProgressBar(range(n_skel)):
        # grid and thin the current structure on its own
        skl = grid_skeleton(points[labels == lb].T, refdata, **grid_kwargs)
        skl = morphology.skeletonize(skl).astype(bool)

        if skl.sum() <= min_length:
            # too short to be kept
            continue

        # every non-noise point that belongs to a different structure
        omask = np.logical_and(labels != lb, labels >= 0)
        if omask.any():
            others = grid_skeleton(points[omask].T, refdata, **grid_kwargs)
            others = morphology.skeletonize(others)
        else:
            # no other structure exists; gridding an empty point set must not be attempted
            others = np.zeros(refdata.shape, dtype=bool)

        # remove the endpoint pixels that may connect one structure to another
        if method == "robust":
            # much less efficient, but also catches end points that touch other structures diagonally
            endpts = endPoints(skl)
            try:
                endpts_lg = morphology.binary_dilation(endpts, footprint=footprint)
            except TypeError:
                # older scikit-image releases name the argument `selem`
                endpts_lg = morphology.binary_dilation(endpts, selem=footprint)
            # find where the structures are connected
            overlap_pt = np.logical_and(endpts_lg, others)
            if np.sum(overlap_pt) > 0:
                skl[np.logical_and(endpts, endpts_lg)] = False
        elif method == "fast":
            # drop every skeleton pixel adjacent to another structure once gridded
            # note: may miss diagonally connected structures
            others = morphology.binary_dilation(others)
            skl[np.logical_and(skl, others)] = False

        skel_cube[skl] = True
    return skel_cube
# note: this method is designed to work on PPV structures;
# more general 3D cleaning has yet to be implemented
[docs]
def clean_grid_ppv(coord, refdata, coord_in_xfirst=False, start_index=1, min_length=6, method="robust",
                   del_vel_max=2):
    """
    Process and grid CRISPy coordinates in PPV space onto a reference image, labeling and cleaning skeleton structures.

    The ridge points are split into distinct structures with DBSCAN. Each structure is
    gridded and skeletonized on its own, kept only if its projection along the first axis
    is long enough, trimmed at the endpoints that touch other structures, and stripped of
    segments that run along the first axis for too many pixels.

    Parameters
    ----------
    coord : ndarray
        Coordinates of the ridge points, shape (D, n), where `D` is the dimensionality
        (3 for PPV data) and `n` is the number of points.
    refdata : ndarray
        Reference PPV image array defining the grid dimensions.
    coord_in_xfirst : bool, optional, default=False
        If True, assumes the input coordinates are ordered with x as the first axis.
        If False, assumes z is the first axis.
    start_index : int, optional, default=1
        Index of the first pixel in the coordinate convention of `coord`.
    min_length : int, optional, default=6
        A structure is kept only if its projected skeleton has more than `min_length`
        pixels; the value is also used as `min_size` in the final small-object removal.
    method : {"robust", "fast"}, optional, default="robust"
        Method for cleaning skeleton endpoints:
        - "robust": Handles diagonal connections but is computationally intensive.
        - "fast": Faster but may miss diagonally connected structures.
        Any other value leaves the endpoints untouched.
    del_vel_max : int, optional, default=2
        Maximum number of skeleton voxels allowed along the first axis at a given
        position; positions exceeding it are cleared for that structure.

    Returns
    -------
    skel_cube : ndarray
        Boolean array with the same shape as `refdata`, where gridded skeleton structures
        are set to `True`.

    Notes
    -----
    - Points labeled as noise by DBSCAN (label `-1`) are never gridded.
    - When only one structure is present there is nothing to disconnect it from, so no
      endpoint is trimmed.
    - The current implementation supports only 3D PPV data.

    Examples
    --------
    >>> import numpy as np
    >>> from crispy import grid_ridge
    >>> y = x = np.linspace(2, 12, 20)
    >>> coords = np.array([np.full(20, 5.0), y, x])  # shape (3, n), ordered v, y, x
    >>> ref_image = np.zeros((20, 20, 20))  # Reference PPV image
    >>> skel_cube = grid_ridge.clean_grid_ppv(coords, ref_image, start_index=0, min_length=5)
    >>> print(skel_cube.shape)
    (20, 20, 20)
    """
    # DBSCAN wants one point per row, while the ridge table stores one axis per row
    points = coord.T
    labels = label_ridge(points, eps=1.0, min_samples=3)
    # cluster labels run from 0 to max(labels); -1 is reserved for noise
    n_skel = int(np.max(labels)) + 1 if np.size(labels) else 0

    skel_cube = np.zeros(refdata.shape, dtype=bool)
    grid_kwargs = dict(coord_in_xfirst=coord_in_xfirst, start_index=start_index)

    if method == "robust":
        # built once: neighbourhood around an end point within which another structure counts as connected
        footprint = morphology.cube(5)

    print("---gridding {} distinct skeletons---".format(n_skel))
    for lb in ProgressBar(range(n_skel)):
        # grid and thin the current structure on its own
        skl = grid_skeleton(points[labels == lb].T, refdata, **grid_kwargs)
        skl = morphology.skeletonize(skl).astype(bool)

        if get_2d_length(skl) <= min_length:
            # projected length too short to be kept
            continue

        # every non-noise point that belongs to a different structure
        omask = np.logical_and(labels != lb, labels >= 0)
        if omask.any():
            others = grid_skeleton(points[omask].T, refdata, **grid_kwargs)
            others = morphology.skeletonize(others)
        else:
            # no other structure exists; gridding an empty point set must not be attempted
            others = np.zeros(refdata.shape, dtype=bool)

        # remove overlapping pixels
        if method == "robust":
            # much less efficient, but also catches end points that touch other structures diagonally
            endpts = endPoints(skl)
            try:
                endpts_lg = morphology.binary_dilation(endpts, footprint=footprint)
            except TypeError:
                # older scikit-image releases name the argument `selem`
                endpts_lg = morphology.binary_dilation(endpts, selem=footprint)
            # find where the structures are connected
            overlap_pt = np.logical_and(endpts_lg, others)
            if np.sum(overlap_pt) > 0:
                skl[np.logical_and(endpts, endpts_lg)] = False
        elif method == "fast":
            # drop every skeleton pixel adjacent to another structure once gridded
            # note: may miss diagonally connected structures
            others = morphology.binary_dilation(others)
            skl[np.logical_and(skl, others)] = False

        # clear positions where the skeleton spans more than del_vel_max voxels along the first axis
        skl2d = skl.sum(axis=0)
        skl[:, skl2d > del_vel_max] = False

        skel_cube[skl] = True

    # final cleaning to remove small objects
    skel_cube = morphology.remove_small_objects(skel_cube, min_size=min_length, connectivity=2)
    return skel_cube
# ======================================================================================================================#
# grid function
[docs]
def grid_skeleton(coord, refdata, coord_in_xfirst=False, start_index=1):
    """
    Map CRISPy skeleton coordinates onto a reference image grid.

    Every ridge point is rounded to the nearest pixel and the corresponding pixel of a
    mask with the shape of the reference image is set to 1.

    Parameters
    ----------
    coord : ndarray
        Coordinates of the ridge points, shape (D, n), where `D` is the dimensionality
        (2 or 3) and `n` is the number of points. The array is not modified.
    refdata : ndarray
        Reference image array defining the grid dimensions.
    coord_in_xfirst : bool, optional, default=False
        If True, the first and last rows of `coord` are exchanged before gridding,
        i.e. the input is taken to be ordered x, y(, z) instead of (z,) y, x.
    start_index : int, optional, default=1
        Index of the first pixel in the coordinate convention of `coord`; it is
        subtracted from the rounded coordinates to obtain array indices.

    Returns
    -------
    mask : ndarray
        Float array with the same shape as `refdata`, where gridded points are set to 1
        and all other pixels are 0. An empty `coord` yields an all-zero mask.

    Raises
    ------
    IndexError
        If a point falls outside the reference grid. Points that end up at a negative
        index are rejected explicitly instead of wrapping around to the opposite edge.

    Examples
    --------
    >>> import numpy as np
    >>> from crispy import grid_ridge
    >>> coords = np.array([[1, 2, 11], [1, 2, 11], [1, 2, 11]])  # shape (3, n)
    >>> ref_image = np.zeros((20, 20, 20))  # Reference image
    >>> mask = grid_ridge.grid_skeleton(coords, ref_image)
    >>> print(mask.shape, int(mask.sum()))
    (20, 20, 20) 3
    """
    coord = np.asarray(coord)

    # if the passed in coordinates are in the order of x, y, and z, instead of z, y, and x
    if coord_in_xfirst:
        # swap on a copy so the caller's array keeps its original axis order
        coord = coord.copy()
        coord[[0, -1]] = coord[[-1, 0]]

    # round the pixel coordinates to the nearest integer
    if not np.issubdtype(coord.dtype, np.integer):
        coord = np.rint(coord).astype(int)
    index = coord - start_index

    if index.size and index.min() < 0:
        # numpy would silently wrap negative indices around to the far edge of the grid
        raise IndexError(
            "ridge coordinates fall below the first pixel of the reference grid; "
            "there may be a mismatch in the start_index"
        )

    mask = np.zeros(shape=refdata.shape)
    # one index array per axis; with no points this is a no-op rather than `mask[()] = 1`
    mask[tuple(index)] = 1
    return mask
[docs]
def make_skeleton(coord, refdata, rm_sml_obj=True, coord_in_xfirst=False, start_index=1, min_length=6):
    """
    Map CRISPy skeleton coordinates onto a reference grid and clean the skeleton.

    The ridge coordinates are gridded with `grid_skeleton`. Optionally, connected
    structures that are too small, or whose projected skeleton is too short, are
    removed before the binary skeleton map is returned.

    Parameters
    ----------
    coord : ndarray
        Coordinates of the ridge points, shape (D, n), where `D` is the dimensionality
        (2 or 3) and `n` is the number of points.
    refdata : ndarray
        Reference image array defining the grid dimensions.
    rm_sml_obj : bool, optional, default=True
        If True, removes connected structures smaller than `min_length` voxels as well
        as structures whose (projected) skeleton has fewer than `min_length` pixels.
    coord_in_xfirst : bool, optional, default=False
        If True, assumes the input coordinates are ordered with x as the first axis.
        If False, assumes z is the first axis for 3D data or y for 2D data.
    start_index : int, optional, default=1
        Index of the first pixel in the coordinate convention of `coord`.
    min_length : int, optional, default=6
        Minimum length (in pixels) for structures to be retained.

    Returns
    -------
    mask : ndarray
        Integer array with the same shape as `refdata`, equal to 1 on the retained
        skeleton and 0 elsewhere.

    Notes
    -----
    - Structures are identified with full connectivity (`connectivity=2` in
      scikit-image), so diagonal neighbours count as connected.
    - For 3D data the length criterion is evaluated on the projection along the first
      axis; for 2D data it is evaluated on the structure itself.
    - The length criterion may not pick up a short spine that has many branches.

    Examples
    --------
    >>> import numpy as np
    >>> from crispy import grid_ridge
    >>> y = x = np.arange(2, 14)
    >>> coords = np.array([np.full(12, 5), y, x])  # shape (3, n), ordered z, y, x
    >>> ref_image = np.zeros((20, 20, 20))  # Reference image
    >>> mask = grid_ridge.make_skeleton(coords, ref_image, start_index=0, min_length=5)
    >>> print(mask.shape)
    (20, 20, 20)
    """
    mask = grid_skeleton(coord, refdata, coord_in_xfirst=coord_in_xfirst, start_index=start_index)

    # remove objects shorter than a certain length (this assumes the skeleton is truly 1-pixel wide)
    if rm_sml_obj:
        mask = mask.astype('bool')
        # connectivity=2 so that pixels touching only diagonally belong to the same object
        mask = morphology.remove_small_objects(mask, min_size=min_length, connectivity=2)

        # label each connected structure
        mask, num = morphology.label(mask, connectivity=2, return_num=True)

        # remove structures whose projected skeleton is shorter than min_length
        for i in range(1, num + 1):
            mask_i = mask == i
            # project 3D structures along the first axis; 2D structures are used as they are,
            # because a 1-D projection cannot be skeletonized
            fil = np.any(mask_i, axis=0) if mask_i.ndim == 3 else mask_i
            fil = morphology.skeletonize(fil)
            if np.sum(fil) < min_length:
                mask[mask_i] = 0

    # collapse labels to a 0/1 map; a comparison avoids the 0/0 -> NaN of `mask / mask`
    return (mask > 0).astype('int')
[docs]
def get_2d_length(skl3d):
    """
    Calculate the sky-projected length of a 3D skeleton.

    The skeleton is collapsed along its first axis, the resulting 2D footprint is
    skeletonized, and the number of pixels left in that 2D skeleton is returned.

    Parameters
    ----------
    skl3d : ndarray
        3D binary array representing the skeletonized structure, where `True` or `1`
        indicates skeleton points and `False` or `0` represents the background.

    Returns
    -------
    length : int
        The total number of pixels in the projected 2D skeleton.

    Notes
    -----
    - The collapse is a logical OR along axis 0, so a position counts once no matter
      how many voxels of the skeleton lie along that axis.

    Examples
    --------
    >>> import numpy as np
    >>> from crispy import grid_ridge
    >>> skl3d = np.zeros((10, 10, 10), dtype=bool)
    >>> skl3d[:, 5, 2:8] = True  # a sheet that projects onto a short line
    >>> length = grid_ridge.get_2d_length(skl3d)
    """
    # logical OR along the first axis gives the 2D footprint of the structure
    footprint2d = np.any(skl3d, axis=0).astype('bool')
    thinned = morphology.skeletonize(footprint2d)
    return np.sum(thinned)
[docs]
def uniq_per_pix(coord, mask, coord_in_xfirst=False, start_index=1):
    """
    Reduce a list of ridge coordinates to one representative point per pixel.

    For every `True` pixel of `mask`, the ridge points that round to that pixel are
    collected and the one holding the median value of the last coordinate row is kept.

    Parameters
    ----------
    coord : ndarray
        Ridge coordinates, shape (D, n), where `D` is the number of dimensions
        (e.g., 2 or 3) and `n` is the number of points. The array is not modified.
    mask : ndarray
        Binary mask array on the reference grid, where `True` indicates pixels of interest.
    coord_in_xfirst : bool, optional, default=False
        If True, the first and last rows of `coord` are exchanged before processing,
        i.e. the input is taken to be ordered x, y(, z) instead of (z,) y, x.
    start_index : int, optional, default=1
        Index of the first pixel in the coordinate convention of `coord`; it is
        subtracted from the rounded coordinates to obtain array indices.

    Returns
    -------
    coord_uniq : ndarray
        Reduced set of coordinates, shape (D, m), where `m` is the number of mask
        pixels that contain at least one ridge point. The rows follow the order used
        internally, i.e. with the first and last rows exchanged if `coord_in_xfirst`
        is True.

    Notes
    -----
    - The representative point is selected by the median of the last row of the
      (possibly reordered) coordinates; for an even number of points the upper of
      the two middle values is used.
    - Mask pixels without any ridge point are reported with a printed message and
      skipped; this usually indicates a mismatch in `start_index`.
    - Integer and floating-point coordinates are both accepted.

    Examples
    --------
    >>> import numpy as np
    >>> from crispy import grid_ridge
    >>> coords = np.array([[0.1, 0.2, 1.0], [0.0, 0.3, 2.1]])  # shape (2, n)
    >>> mask = np.zeros((3, 3), dtype=bool)
    >>> mask[0, 0] = True
    >>> reduced = grid_ridge.uniq_per_pix(coords, mask, start_index=0)
    >>> print(reduced.shape)
    (2, 1)
    """
    coord = np.asarray(coord)

    # if the passed in coordinates are in the order of x, y, and z, instead of z, y, and x
    if coord_in_xfirst:
        # swap on a copy so the caller's array keeps its original axis order
        coord = coord.copy()
        coord[[0, -1]] = coord[[-1, 0]]

    # round the pixel coordinates to the nearest integer (integer input is used as is)
    if np.issubdtype(coord.dtype, np.integer):
        crds_int = coord
    else:
        crds_int = np.rint(coord).astype(int)

    # one row per point, in the same layout as the indices returned by argwhere
    crds_int = (crds_int - start_index).T
    points = coord.T

    coord_uniq = []
    for idx in np.argwhere(mask):
        crd_at_pix = points[np.all(crds_int == idx, axis=1)]
        if crd_at_pix.size == 0:
            print("[ERROR]: crd at pix size: {}; there may be a mismatch in the start_index".format(crd_at_pix.size))
            # nothing to pick a median from in this pixel
            continue
        # index of the point with the median last-coordinate value within the pixel
        last_vals = crd_at_pix[:, -1]
        med_idx = np.argsort(last_vals)[len(last_vals) // 2]
        coord_uniq.append(crd_at_pix[med_idx])

    if not coord_uniq:
        # keep the documented (D, m) layout even when no pixel is populated
        return np.empty((coord.shape[0], 0), dtype=coord.dtype)
    return np.array(coord_uniq).T