"""
Utilities for skeleton processing, branch property initialization, and pruning in 2D and 3D structures.
"""
__author__ = 'mcychen'
from scipy import ndimage
from skimage import morphology
import numpy as np
from astropy.utils.console import ProgressBar
import copy
import string
from ._filfinder_length import product_gen
from ._filfinder_length import init_lengths as init_lengths_2D
from .structures import get_base_block, two_con_3D
# ==============================================================================================
[docs]
def branchedPoints(skel, endpt=None):
    """
    Locate the branch (intersection) points of a 2D or 3D skeleton.

    A skeleton pixel is reported as a branch point when it is neither flagged
    by `bodyPoints` nor marked as an endpoint.

    Parameters
    ----------
    skel : ndarray
        Skeleton array. Non-zero values are skeleton pixels, zeros are background.
    endpt : ndarray, optional
        Precomputed endpoint mask with the same shape as `skel`. If `None`,
        the mask is computed with `endPoints(skel)` and a progress message
        is printed.

    Returns
    -------
    pt : ndarray
        Boolean array with the same shape as `skel`; `True` marks branch points.

    Notes
    -----
    - Relies on `bodyPoints` and `endPoints`, so the dimensionality handling is
      whatever those two functions provide.
    - This code is based on the 2D version seen in FilFinder (v1.7.2) by Eric Koch.
    """
    # skeleton pixels that are not body points
    non_body = np.logical_and(skel, np.logical_not(bodyPoints(skel)))

    # if no end-points were supplied, compute them before excluding them
    if endpt is None:
        print("calculating end points...")
        endpt = endPoints(skel)

    return np.logical_and(non_body, np.logical_not(endpt))
# identify body points (points with exactly two neighbours by 3-connectivity)
[docs]
def bodyPoints(skel):
    """
    Identify body points in a 2D or 3D skeleton.

    A set of 3-wide structuring elements is built from the template returned
    by `get_base_block`: each element has one extra cell set in the first
    ("top") layer and one in the last ("bottom") layer along axis 0. The same
    elements with axis 0 swapped with axis 1 (and with axis 2 for 3D input)
    are added, and every unique element is matched against the skeleton with
    `scipy.ndimage.binary_hit_or_miss`.

    Parameters
    ----------
    skel : ndarray
        Skeleton array. Non-zero values are skeleton pixels, zeros are background.

    Returns
    -------
    pt : ndarray
        Boolean array with the same shape as `skel`; `True` marks pixels whose
        neighbourhood matches at least one structuring element.

    Notes
    -----
    - Presumably `get_base_block` returns a 3-per-axis array with only the
      centre set, so that each element describes a pixel with exactly two
      neighbours on opposite sides -- TODO confirm against `structures`.
    - This code is based on the 2D version seen in FilFinder (v1.7.2) by Eric Koch.
    """
    template = get_base_block(skel.ndim, return_cent_idx=False)

    # one element for every (top-layer cell, bottom-layer cell) pair
    primary = []
    for top_idx, _ in np.ndenumerate(template[0]):
        for bottom_idx, _ in np.ndenumerate(template[2]):
            element = template.copy()
            element[(0,) + top_idx] = 1
            element[(2,) + bottom_idx] = 1
            primary.append(element)

    # add the same patterns oriented along the other axes
    elements = list(primary)
    elements.extend(np.swapaxes(element, 0, 1) for element in primary)
    if len(skel.shape) == 3:
        elements.extend(np.swapaxes(element, 0, 2) for element in primary)

    # drop the duplicates created by the axis swaps
    elements = np.unique(np.array(elements), axis=0)

    pt = np.zeros(np.shape(skel), dtype=bool)
    for element in elements:
        pt = np.logical_or(pt, ndimage.binary_hit_or_miss(skel, structure1=element))
    return pt
# identify end points (points with only one neighbour by 3-connectivity)
# (only works if the skeleton is one pixel in width by 3-connectivity and not 1-connectivity)
[docs]
def endPoints(skel):
    """
    Identify endpoints in a 2D or 3D skeleton.

    For every cell of the template returned by `get_base_block`, except the
    centre index, a structuring element with that single extra cell set is
    built. The skeleton is matched against each element with
    `scipy.ndimage.binary_hit_or_miss` and the matches are OR-ed together.

    Parameters
    ----------
    skel : ndarray
        Skeleton array. Non-zero values are skeleton pixels, zeros are background.

    Returns
    -------
    ep : ndarray
        Boolean array with the same shape as `skel`; `True` marks endpoints.

    Notes
    -----
    - Presumably the template has only its centre set, so each element
      describes a pixel with exactly one neighbour -- TODO confirm against
      `structures.get_base_block`.
    - Isolated single pixels are not reported as endpoints by this scheme.
    - This code is based on the 2D version seen in FilFinder (v1.7.2) by Eric Koch.
    """
    template, centre = get_base_block(skel.ndim, return_cent_idx=True)

    def single_neighbour(idx):
        # template plus exactly one neighbouring cell
        element = template.copy()
        element[idx] = 1
        return element

    # Note: this does not account for "end points" that are only a pixel long
    elements = [single_neighbour(idx)
                for idx, _ in np.ndenumerate(template)
                if idx != centre]

    ep = np.zeros(np.shape(skel), dtype=bool)
    for element in elements:
        ep = np.logical_or(ep, ndimage.binary_hit_or_miss(skel, structure1=element))
    return ep
[docs]
def walk_through_segment_3D(segment):
    """
    Walk along a branch-free 3D skeleton segment and list its pixels in order.

    The walk starts from whichever of the first two detected endpoints has the
    smaller squared distance to the array origin, and repeatedly steps to the
    single remaining neighbour inside the 3x3x3 block around the current pixel.
    Visited pixels are zeroed in an internal copy; the input is not modified.

    Parameters
    ----------
    segment : ndarray
        3D array of the segment; values >= 1 are treated as skeleton pixels.

    Returns
    -------
    idx_list : list of tuple or None
        Ordered ``(z, y, x)`` coordinates along the segment. A one-pixel
        segment yields a single-element list. `None` is returned (after
        printing an error) when the segment is empty or when the walk meets
        more than one candidate neighbour.

    Notes
    -----
    - Only works when the segment does not touch the array edges, because the
      neighbourhood block is taken with plain slicing.
    - Only the first two endpoints returned by `endPoints` are considered; a
      segment with fewer than two detected endpoints raises `IndexError`.
    - This code is based on the 2D version seen in FilFinder (v1.7.2) by Eric Koch.
    """
    # work on a copy: the walk erases pixels as it goes
    segment = copy.copy(segment)

    n_pix = np.count_nonzero(segment >= 1)
    if n_pix < 1:
        print("[ERROR]: the total number of pixels in the segment is less than 1!")
        return None
    if n_pix == 1:
        z, y, x = np.argwhere(segment >= 1)[0]
        return [(z, y, x)]

    def open_block(z, y, x):
        # 3x3x3 view around (z, y, x) with the centre pixel erased
        view = segment[z - 1:z + 2, y - 1:y + 2, x - 1:x + 2]
        view[1, 1, 1] = 0
        return view

    # start from the endpoint that is closest to the origin
    end_idx = np.argwhere(endPoints(segment))
    first, second = end_idx[0], end_idx[1]
    start = first if np.sum(first ** 2) < np.sum(second ** 2) else second

    z, y, x = start
    idx_list = [(z, y, x)]
    block = open_block(z, y, x)

    # keep stepping while exactly one unvisited neighbour remains
    while np.count_nonzero(block > 0) == 1:
        k, j, i = np.argwhere(block > 0)[0]
        z, y, x = z + k - 1, y + j - 1, x + i - 1
        idx_list.append((z, y, x))
        block = open_block(z, y, x)

    # more than one neighbour left: the skeleton is not one pixel wide here
    if np.count_nonzero(block > 0) > 1:
        print("[ERROR]: the skeleton segment is more than a pixel wide by 3-connectivity")
        return None
    return idx_list
[docs]
def init_lengths_3D(labelisofil, array_offsets=None, img=None, use_skylength=True):
    """
    Measure the length, mean intensity and pixel list of every labelled branch.

    Parameters
    ----------
    labelisofil : list of ndarray
        3D labelled skeleton arrays (intersections removed, one integer label
        per branch).
    array_offsets : sequence, optional, default=None
        Per-skeleton offsets; ``array_offsets[n][0]`` is added, axis by axis,
        to the branch coordinates before indexing `img`. Defaults to an array
        of ones with shape ``(len(labelisofil), 1, 3)``.
    img : ndarray, optional, default=None
        Intensity cube. Must have the shape of ``labelisofil[0]``. If `None`,
        a cube of ones is used.
    use_skylength : bool, optional, default=True
        If true, axis 0 is ignored when measuring lengths (sky-projected
        length); otherwise all three axes contribute.

    Returns
    -------
    branch_properties : dict or None
        ``{"length": ..., "intensity": ..., "pixels": ...}``, each a list with
        one entry per skeleton holding one value per branch. `None` is
        returned (after printing an error) when `img` has the wrong shape.

    Notes
    -----
    - Single-pixel branches are assigned a length of 0.5.
    - Intensities that are non-finite or negative are excluded from the mean.
    - Each branch is zero-padded by one pixel before `walk_through_segment_3D`
      is called so the branch never touches the array edge.
    - NOTE(review): with the default offsets of ones, sampled coordinates are
      shifted by +1 on every axis relative to the label array -- confirm this
      is intended.
    - This code is based on the 2D version seen in FilFinder (v1.7.2) by Eric Koch.
    """
    print("getting branch_properties")
    n_skel = len(labelisofil)

    if img is not None:
        if img.shape != labelisofil[0].shape:
            print("[ERROR]: the shape of the intensity image does that match that of the skeleton")
            return None
    else:
        img = np.ones(labelisofil[0].shape)

    if array_offsets is None:
        array_offsets = np.ones((n_skel, 1, 3), dtype=int)

    # axis left out of the length measurement (axis 0 for sky-projected lengths)
    drop_axis = 0 if use_skylength else None

    lengths, av_branch_intensity, all_branch_pts = [], [], []
    for n in range(n_skel):
        labelled = copy.copy(labelisofil[n])
        leng, av_intensity, branch_pix = [], [], []

        for lab, obj in enumerate(ndimage.find_objects(labelled), start=1):
            # binary cut-out of this branch, sized to its bounding box
            cutout = labelled[obj]
            branch_pts = np.where(cutout == lab)
            branch_array = np.zeros_like(cutout)
            branch_array[branch_pts] = 1
            # pad so the branch does not touch the edge during the walk
            branch_array = np.pad(branch_array, 1, mode='constant', constant_values=0)

            if branch_array.sum() == 1:
                # single pixel: placeholder length for the longest-path algorithm
                branch_length = 0.5
            else:
                wlk_idx = walk_through_segment_3D(branch_array)
                branch_length = segment_len(wlk_idx, remove_axis=drop_axis)
            leng.append(branch_length)

            # bounding-box start plus the per-skeleton offset, for each axis
            x_offset, y_offset, z_offset = [obj[axis].start + array_offsets[n][0][axis]
                                            for axis in range(3)]
            coords = [(x + x_offset, y + y_offset, z + z_offset)
                      for x, y, z in zip(*branch_pts)]
            samples = [img[crd] for crd in coords]
            usable = [val for val in samples if np.isfinite(val) and not val < 0.0]
            av_intensity.append(np.nanmean(usable))
            branch_pix.append(np.array(coords))

        lengths.append(leng)
        av_branch_intensity.append(av_intensity)
        all_branch_pts.append(branch_pix)

    return {"length": lengths,
            "intensity": av_branch_intensity,
            "pixels": all_branch_pts}
[docs]
def init_branch_properties(labelisofil, ndim, img=None, use_skylength=True):
    """
    Compute branch properties by dispatching on the skeleton dimensionality.

    Parameters
    ----------
    labelisofil : list of ndarray
        Labelled skeleton arrays (intersections removed, one integer label per
        branch).
    ndim : int
        Dimensionality of the skeletons. ``2`` selects `init_lengths_2D`; any
        other value is handled by `init_lengths_3D`.
    img : ndarray, optional
        Intensity image. In the 2D case an array of ones shaped like
        ``labelisofil[0]`` is substituted when `img` is `None`; in the 3D case
        `img` is forwarded unchanged.
    use_skylength : bool, optional
        Forwarded to `init_lengths_3D`; unused in the 2D case. Defaults to True.

    Returns
    -------
    branch_properties
        Whatever the selected implementation returns (for the 3D case, a dict
        with keys ``length``, ``intensity`` and ``pixels``).

    Notes
    -----
    - The 2D case passes offsets of ones with shape ``(len(labelisofil), 1, 2)``.
    - This code is based on the 2D version seen in FilFinder (v1.7.2) by Eric Koch,
      which has been deprecated.
    """
    if ndim != 2:
        return init_lengths_3D(labelisofil, img=img, use_skylength=use_skylength)

    offsets = np.ones((len(labelisofil), 1, 2), dtype=int)
    if img is None:
        img = np.ones(labelisofil[0].shape)
    return init_lengths_2D(labelisofil, offsets, img=img)
[docs]
def segment_len(wlk_idx, remove_axis=None):
    """
    Sum the Euclidean distances between consecutive points of a segment.

    Parameters
    ----------
    wlk_idx : list of tuple
        Ordered coordinates of the segment pixels.
    remove_axis : int, optional, default=None
        Coordinate axis whose differences are zeroed before the distances are
        computed (e.g. ``0`` to drop the first coordinate). If `None`, every
        axis contributes.

    Returns
    -------
    length : float
        Total path length; ``0.0`` for a single-point segment.

    Notes
    -----
    - Distances are measured between the listed coordinates (pixel centres),
      so the extent of the end pixels themselves is not included.
    """
    # one row per step between consecutive points
    steps = np.diff(np.asarray(wlk_idx) * 1.0, axis=0)
    if remove_axis is not None:
        steps[:, remove_axis] = 0.0
    step_lengths = np.sqrt(np.sum(steps ** 2, axis=1))
    return np.sum(step_lengths)
[docs]
def remove_bad_ppv_branches(labBodyPtAry, num_lab, refStructure=None, max_pp_length=9.0, v2pp_ratio=1.5, method="full"):
    """
    Remove unphysical branches from a labelled 3D skeleton in PPV space.

    A branch is removed from the reference structure when its extent in the
    position-position (PP) plane does not exceed `max_pp_length` and its
    velocity-to-PP ratio is at least `v2pp_ratio`. Axis 0 is treated as the
    velocity axis.

    Parameters
    ----------
    labBodyPtAry : ndarray
        3D array in which each branch carries a unique positive integer label.
    num_lab : int
        Number of labelled branches. Only used by the "quick" method.
    refStructure : ndarray, optional, default=None
        Reference structure (full skeleton) from which branches are removed.
        A copy is modified; the caller's array is untouched. If `None`, a 0/1
        array derived from `labBodyPtAry` is used.
    max_pp_length : float, optional, default=9.0
        Only branches whose PP length is at most this value are candidates for
        removal.
    v2pp_ratio : float, optional, default=1.5
        Candidates whose velocity-to-PP ratio is greater than or equal to this
        value are removed.
    method : {"full", "quick"}, optional, default="full"
        - "full": walk each branch with `walk_through_segment_3D` and measure
          lengths with `segment_len`.
        - "quick": approximate both lengths by counting occupied pixels in the
          projections of the branch.

    Returns
    -------
    filtered_structure : ndarray or None
        The reference structure with the selected branches set to zero, passed
        through `morphology.remove_small_objects(min_size=2, connectivity=3)`.
        `None` is returned (after printing an error) when the shapes of
        `labBodyPtAry` and `refStructure` differ or `method` is not recognised.

    Notes
    -----
    - In the "full" method, missing labels, single-pixel branches and branches
      that `walk_through_segment_3D` cannot traverse are left in place.
    - NOTE(review): when the reference structure is an integer array it is
      handed to `remove_small_objects` as-is; scikit-image treats integer
      input as already-labelled objects, so a 0/1 integer array may not have
      its small connected pieces removed -- confirm whether a boolean array
      was intended.
    """
    if refStructure is None:
        # derive a 0/1 reference structure from the labelled branches
        refStructure = labBodyPtAry.copy()
        refStructure[refStructure != 0] = 1
    else:
        # never modify the caller's array
        refStructure = refStructure.copy()

    if labBodyPtAry.shape != refStructure.shape:
        print("[ERROR]: the shape fo labBodyPtAry and refStructure are not the same!")
        return None

    if method == "full":
        objects = ndimage.find_objects(labBodyPtAry)
        for i, obj in enumerate(ProgressBar(objects)):
            # find_objects yields None for label values that are not present
            if obj is None:
                continue

            # binary cut-out of this branch, sized to its bounding box
            branch = np.zeros_like(labBodyPtAry[obj])
            branch[labBodyPtAry[obj] == i + 1] = 1
            # pad so the branch does not touch the edge during the walk
            branch = np.pad(branch, 1, mode='constant', constant_values=0)

            # a single pixel has no measurable length or ratio
            if np.count_nonzero(branch) <= 1:
                continue

            wlk_idx = walk_through_segment_3D(branch)
            if wlk_idx is None:
                # the walk already reported why the branch could not be traversed
                continue

            skylength = segment_len(wlk_idx, remove_axis=0)
            fulllength = segment_len(wlk_idx)
            vlength = np.sqrt(fulllength ** 2 - skylength ** 2)
            if skylength <= max_pp_length:
                # a branch running purely along the velocity axis has zero sky
                # length; its ratio is infinite and it is removed
                with np.errstate(divide='ignore', invalid='ignore'):
                    ratio = vlength / skylength
                if ratio >= v2pp_ratio:
                    refStructure[labBodyPtAry == i + 1] = 0

    elif method == "quick":
        # a quick way to approximate the ratio between the velocity extent and the on-sky extent
        # (the accuracy decreases as the total length increases)
        for n in ProgressBar(list(range(num_lab))):
            branch = labBodyPtAry == n + 1
            # number of pixels in the on-sky projection (proxy for on-sky length)
            size_pp = (np.logical_or.reduce(branch, axis=0)).sum()
            if size_pp == 0:
                # label not present in the array: nothing to remove
                continue
            if size_pp <= max_pp_length:
                # number of occupied velocity channels
                v_length = (np.logical_or.reduce(branch, axis=(1, 2))).sum()
                # remove branches that have high velocity-to-length ratio
                ratio = 1.0 * v_length / size_pp
                if ratio >= v2pp_ratio:
                    refStructure[branch] = 0

    else:
        print(("[ERROR]: the entered method {0} is not recongnized.".format(method)))
        return None

    return morphology.remove_small_objects(refStructure, min_size=2, connectivity=3)
[docs]
def pre_graph_3D(labelisofil, branch_properties, interpts, ends, w=0.0):
    """
    Build node and edge lists describing each 3D skeleton as a graph.

    End nodes are the branch labels found at the endpoint coordinates;
    intersection nodes are named with successive items produced by
    ``product_gen(string.ascii_uppercase)``. Edges connect an intersection to
    an end node, or two intersections, whenever they share a branch.

    Parameters
    ----------
    labelisofil : list of ndarray
        3D labelled skeleton arrays (intersections removed, one integer label
        per branch).
    branch_properties : dict
        Must provide the keys ``"length"`` and ``"intensity"``, each indexed
        first by skeleton and then by ``label - 1``.
    interpts : list of list
        For each skeleton, a list of intersections; each intersection is an
        iterable of ``(z, y, x)`` pixel coordinates.
    ends : list
        For each skeleton, the endpoint coordinates used to index
        ``labelisofil[n]``.
    w : float, optional, default=0.0
        Relative weighting between normalised branch length (``w = 0``) and
        normalised branch intensity (``w = 1``).

    Returns
    -------
    edge_list : list of list
        Per skeleton, tuples ``(node_1, node_2, branch)`` where ``branch`` is
        ``(label, weight, length, intensity)``.
    nodes : list of list
        Per skeleton, the end-node labels followed by the intersection names.
    loop_edges : list of list
        Per skeleton, the extra edges between two intersections that are
        joined by more than one branch (all but the lowest-weight branch).

    Raises
    ------
    ValueError
        Raised by the weighting helper if `w` is outside ``[0.0, 1.0]``; the
        check only runs when at least one node weight is computed.

    Notes
    -----
    - The 3x3x3 neighbourhood of every intersection pixel is taken with plain
      slicing, so intersection pixels are assumed not to lie on the array
      boundary -- TODO confirm with callers.
    - Presumably ``product_gen`` yields 'A', 'B', ..., then multi-letter names;
      its definition is not part of this module.
    - This code is based on the 2D version seen in FilFinder (v1.7.2) by Eric Koch.
    """
    num = len(labelisofil)
    end_nodes = []
    inter_nodes = []
    nodes = []
    edge_list = []
    loop_edges = []

    def path_weighting(idx, length, intensity, w=0.5):
        """
        Return the relative weight of branch `idx` for the shortest path algorithm.

        The weight mixes the branch's share of the summed lengths and its
        share of the summed intensities: ``(1 - w) * length_share + w * intensity_share``.
        Raises ValueError if `w` is outside [0.0, 1.0].
        MC note: this weighting scheme may be potentially flawed
        """
        if w > 1.0 or w < 0.0:
            raise ValueError(
                "Relative weighting w must be between 0.0 and 1.0.")
        return (1 - w) * (length[idx] / np.sum(length)) + \
            w * (intensity[idx] / np.sum(intensity))

    lengths = branch_properties["length"]
    branch_intensity = branch_properties["intensity"]
    for n in range(num):
        inter_nodes_temp = []
        # Create end_nodes, which contains (label, weight, length, intensity) for the branch at each
        # endpoint, and nodes, to which the intersection names are appended later
        end_nodes.append([(labelisofil[n][i],
                           path_weighting(int(labelisofil[n][i] - 1),
                                          lengths[n],
                                          branch_intensity[n],
                                          w),
                           lengths[n][int(labelisofil[n][i] - 1)],
                           branch_intensity[n][int(labelisofil[n][i] - 1)])
                          for i in ends[n]])
        nodes.append([labelisofil[n][i] for i in ends[n]])
        # Intersection nodes are given by the intersections points of the filament.
        for intersec in interpts[n]:
            uniqs = []
            for i in intersec:  # Intersections can contain multiple pixels
                z, y, x = i
                # 3x3x3 neighbourhood of the intersection pixel (astype makes a copy)
                int_arr = labelisofil[n][z - 1: z + 2, y - 1: y + 2, x - 1: x + 2]
                int_arr = int_arr.astype(int)
                int_arr[1, 1, 1] = 0
                # every branch label touching this pixel (note: `x` is reused as the label here)
                for x in np.unique(int_arr[np.nonzero(int_arr)]):
                    uniqs.append((x,
                                  path_weighting(x - 1,
                                                 lengths[n],
                                                 branch_intensity[n],
                                                 w),
                                  lengths[n][x - 1],
                                  branch_intensity[n][x - 1]))
            # Intersections with multiple pixels can give the same branches.
            # Get rid of duplicates
            uniqs = list(set(uniqs))
            inter_nodes_temp.append(uniqs)
        # Add the intersection labels. Also append those to nodes
        inter_nodes.append(list(zip(product_gen(string.ascii_uppercase),
                                    inter_nodes_temp)))
        for alpha, node in zip(product_gen(string.ascii_uppercase),
                               inter_nodes_temp):
            nodes[n].append(alpha)
        # Edges are created from the information contained in the nodes.
        edge_list_temp = []
        loops_temp = []
        for i, inters in enumerate(inter_nodes[n]):
            # branches shared between this intersection and the end nodes
            end_match = list(set(inters[1]) & set(end_nodes[n]))
            for k in end_match:
                edge_list_temp.append((inters[0], k[0], k))
            for j, inters_2 in enumerate(inter_nodes[n]):
                if i != j:
                    # branches shared between two different intersections
                    match = list(set(inters[1]) & set(inters_2[1]))
                    new_edge = None
                    if len(match) == 1:
                        new_edge = (inters[0], inters_2[0], match[0])
                    elif len(match) > 1:
                        # Multiple connections (a loop): keep the lowest-weight branch as the edge
                        multi = [match[l][1] for l in range(len(match))]
                        keep = multi.index(min(multi))
                        new_edge = (inters[0], inters_2[0], match[keep])
                        # Keep the other edges information in another list
                        for jj in range(len(multi)):
                            if jj == keep:
                                continue
                            loop_edge = (inters[0], inters_2[0], match[jj])
                            # skip the edge if it, or its reverse, is already recorded
                            dup_check = loop_edge not in loops_temp and \
                                (loop_edge[1], loop_edge[0], loop_edge[2]) \
                                not in loops_temp
                            if dup_check:
                                loops_temp.append(loop_edge)
                    if new_edge is not None:
                        # skip the edge if it, or its reverse, is already recorded
                        dup_check = (new_edge[1], new_edge[0], new_edge[2]) \
                            not in edge_list_temp \
                            and new_edge not in edge_list_temp
                        if dup_check:
                            edge_list_temp.append(new_edge)
        # Store the de-duplicated edges and loop edges of this skeleton
        edge_list.append(edge_list_temp)
        loop_edges.append(loops_temp)
    return edge_list, nodes, loop_edges
[docs]
def main_length_3D(max_path, edge_list, labelisofil, interpts, branch_lengths, img_scale, verbose=False, save_png=False,
save_name=None):
"""
Compute the main lengths of 3D skeletons and generate longest path arrays.
This function calculates the overall lengths of skeletons in a 3D image by identifying
and preserving the longest paths. Intersections are added back to the skeletons, and
extraneous pixels introduced by intersections are removed.
Parameters
----------
max_path : list
List of paths corresponding to the longest lengths for each skeleton.
edge_list : list
List of edges representing connectivity information for the skeleton graphs.
labelisofil : list of ndarray
List of 3D labeled skeleton arrays. Each array contains skeleton branches
with unique integer labels and no intersection points.
interpts : list of list of ndarray
Intersection points for each skeleton, with each entry containing the
coordinates of pixels belonging to an intersection.
branch_lengths : list
Lengths of individual branches in each skeleton.
img_scale : float
Conversion factor from pixel units to physical units.
verbose : bool, optional, default=False
If True, prints detailed information about the process (currently disabled for 3D).
save_png : bool, optional, default=False
If True, saves 2D visualizations of the skeletons (disabled for 3D).
save_name : str, optional, default=None
Name for saving output PNGs (currently unused for 3D).
Returns
-------
main_lengths : list
List of overall lengths for each skeleton, in physical units.
longpath_arrays : ndarray
Binary 3D array representing the longest paths for all skeletons. Non-zero
values indicate pixels belonging to the longest paths.
Notes
-----
- Intersections are added back to the skeleton, and extraneous pixels are removed
using a recursive pruning process.
- The `max_path` input determines the longest path in each skeleton.
- This function is adapted from the 2D `main_length` function in FilFinder and
includes modifications for 3D structures.
- This code is based on the 2D version seen in FilFinder (v1.7.2) by Eric Koch.
"""
main_lengths = []
longpath_cube = np.zeros(labelisofil[0].shape, dtype=bool)
for num, (path, edges, inters, skel_arr, lengths) in \
enumerate(zip(max_path, edge_list, interpts, labelisofil,
branch_lengths)):
if not lengths:
# if the lengths list is empty
print("empty!")
print(("lengths: {}".format(lengths)))
# may want to check not doing anything does not cause problem
elif len(path) == 1:
main_lengths.append(lengths[0] * img_scale)
skeleton = skel_arr # for viewing purposes when verbose
else:
skeleton = np.zeros(skel_arr.shape)
# Add edges along longest path
good_edge_list = [(path[i], path[i + 1])
for i in range(len(path) - 1)]
# Find the branches along the longest path.
for i in good_edge_list:
for j in edges:
if (i[0] == j[0] and i[1] == j[1]) or \
(i[0] == j[1] and i[1] == j[0]):
label = j[2][0]
skeleton[np.where(skel_arr == label)] = 1
# Add intersections along longest path
intersec_pts = []
for label in path:
try:
label = int(label)
except ValueError:
pass
if not isinstance(label, int):
k = 1
while list(zip(product_gen(string.ascii_uppercase),
[1] * k))[-1][0] != label:
k += 1
intersec_pts.extend(inters[k - 1])
skeleton[tuple(zip(*inters[k - 1]))] = 2
# Remove unnecessary pixels
count = 0
while True:
for pt in intersec_pts:
# If we have already eliminated the point, continue
if skeleton[pt] == 0:
continue
skeleton[pt] = 0
lab_try, n = ndimage.label(skeleton, two_con_3D)
if n > 1:
skeleton[pt] = 1
else:
count += 1
if count == 0:
break
count = 0
# main_lengths.append(skeleton_length(skeleton) * img_scale)
# This is a place holding hack at the moment and main_lengths does not actually hold the lengths of the
# longest paths
main_lengths.append(1.0 * img_scale)
longpath_cube[skeleton.astype(bool)] = True
return main_lengths, longpath_cube.astype(int)
[docs]
def classify_structure(skeleton):
    """
    Split a skeleton array into labelled branches, intersections and endpoints.

    The connected skeletons are labelled first. Endpoints and branch points
    are then located with `endPoints` and `branchedPoints`, the branch points
    are removed, and the remaining pieces of every skeleton are labelled as
    individual branches.

    Parameters
    ----------
    skeleton : ndarray
        Skeleton array. Non-zero values are skeleton pixels, zeros are background.

    Returns
    -------
    labelisofil : list of ndarray
        One full-size labelled array per skeleton, with the branch points
        removed and each remaining branch carrying a unique integer.
    interpts : list of list of list of tuple
        For each skeleton, its intersections; each intersection is the list of
        coordinates of one connected group of branch points.
    ends : list of list of tuple
        Endpoint coordinates grouped by skeleton.

    Notes
    -----
    - Labelling uses a connectivity equal to the number of array dimensions.
    - Two progress messages are printed.
    - NOTE(review): groups without any coordinate are dropped, so `ends` can
      hold fewer entries than there are skeletons (e.g. a skeleton without
      endpoints) and would then no longer line up with `labelisofil` --
      confirm whether callers rely on index alignment.
    - This code is based on the 2D version seen in FilFinder (v1.7.2) by Eric Koch.
    """
    def coords_by_label(labelled, num, refStructure):
        '''
        Collect, per label 1..num, the coordinates where refStructure is non-zero;
        labels without any such coordinate are left out.
        '''
        groups = []
        for lab in range(1, num + 1):
            hits = np.argwhere(np.logical_and(labelled == lab, refStructure != 0))
            group = [tuple(crd) for crd in hits]
            if group:
                groups.append(group)
        return groups

    # label the skeletons, using maximum connectivity for the dimensions
    connectivity = skeleton.ndim
    SkLb, SkNum = morphology.label(skeleton, connectivity=connectivity, return_num=True)

    # end-points, grouped by the skeleton they belong to
    print("getting end-points")
    end_mask = endPoints(skeleton)
    ends = coords_by_label(labelled=SkLb, num=SkNum, refStructure=end_mask)

    # branched-points (i.e., intersection points)
    print("getting branched-points")
    branch_mask = branchedPoints(skeleton, endpt=end_mask)

    # the skeleton with its intersection points taken out
    pruned = skeleton.copy()
    pruned[branch_mask != 0] = 0

    interpts = []
    labelisofil = []
    for lab in range(1, SkNum + 1):
        outside = SkLb != lab

        # connected groups of branch points inside this skeleton
        local_bp = branch_mask.copy()
        local_bp[outside] = 0
        bp_lab, bp_num = morphology.label(local_bp, connectivity=connectivity, return_num=True)
        interpts.append(coords_by_label(labelled=bp_lab, num=bp_num, refStructure=branch_mask))

        # individual branches of this skeleton
        local_skel = pruned.copy()
        local_skel[outside] = 0
        labelisofil.append(morphology.label(local_skel, connectivity=connectivity, return_num=False))

    return labelisofil, interpts, ends