Source code for photutils.segmentation.deblend

# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
This module provides tools for deblending overlapping sources labeled in
a segmentation image.
"""

import warnings
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass
from functools import partial
from multiprocessing import cpu_count, get_context

import numpy as np
from astropy.units import Quantity
from astropy.utils import lazyproperty
from astropy.utils.exceptions import AstropyUserWarning
from scipy.ndimage import label as ndi_label
from scipy.ndimage import sum_labels

from photutils.segmentation.core import SegmentationImage
from photutils.segmentation.detect import _detect_sources
from photutils.segmentation.utils import _make_binary_structure
from photutils.utils._optional_deps import tqdm
from photutils.utils._progress_bars import add_progress_bar
from photutils.utils._stats import nanmax, nanmin, nansum

__all__ = ['deblend_sources']


@dataclass
class _DeblendParams:
    npixels: int
    footprint: np.ndarray
    nlevels: int
    contrast: float
    mode: str


[docs] def deblend_sources(data, segment_img, npixels, *, labels=None, nlevels=32, contrast=0.001, mode='exponential', connectivity=8, relabel=True, nproc=1, progress_bar=True): """ Deblend overlapping sources labeled in a segmentation image. Sources are deblended using a combination of multi-thresholding and `watershed segmentation <https://en.wikipedia.org/wiki/Watershed_(image_processing)>`_. In order to deblend sources, there must be a saddle between them. Parameters ---------- data : 2D `~numpy.ndarray` The 2D array of the image. If filtering is desired, please input a convolved image here. This array should be the same array used in `~photutils.segmentation.detect_sources`. segment_img : `~photutils.segmentation.SegmentationImage` The segmentation image to deblend. npixels : int The minimum number of connected pixels, each greater than ``threshold``, that an object must have to be deblended. ``npixels`` must be a positive integer. labels : int or array_like of int, optional The label numbers to deblend. If `None` (default), then all labels in the segmentation image will be deblended. nlevels : int, optional The number of multi-thresholding levels to use for deblending. Each source will be re-thresholded at ``nlevels`` levels spaced between its minimum and maximum values (non-inclusive). The ``mode`` keyword determines how the levels are spaced. contrast : float, optional The fraction of the total source flux that a local peak must have (at any one of the multi-thresholds) to be deblended as a separate object. ``contrast`` must be between 0 and 1, inclusive. If ``contrast=0`` then every local peak will be made a separate object (maximum deblending). If ``contrast=1`` then no deblending will occur. The default is 0.001, which will deblend sources with a 7.5 magnitude difference. mode : {'exponential', 'linear', 'sinh'}, optional The mode used in defining the spacing between the multi-thresholding levels (see the ``nlevels`` keyword) during deblending. The ``'exponential'`` and ``'sinh'`` modes have more threshold levels near the source minimum and less near the source maximum. The ``'linear'`` mode evenly spaces the threshold levels between the source minimum and maximum. The ``'exponential'`` and ``'sinh'`` modes differ in that the ``'exponential'`` levels are dependent on the source maximum/minimum ratio (smaller ratios are more linear; larger ratios are more exponential), while the ``'sinh'`` levels are not. Also, the ``'exponential'`` mode will be changed to ``'linear'`` for sources with non-positive minimum data values. connectivity : {8, 4}, optional The type of pixel connectivity used in determining how pixels are grouped into a detected source. The options are 8 (default) or 4. 8-connected pixels touch along their edges or corners. 4-connected pixels touch along their edges. The ``connectivity`` must be the same as that used to create the input segmentation image. relabel : bool, optional If `True` (default), then the segmentation image will be relabeled such that the labels are in consecutive order starting from 1. nproc : int, optional The number of processes to use for multiprocessing (if larger than 1). If set to 1, then a serial implementation is used instead of a parallel one. If `None`, then the number of processes will be set to the number of CPUs detected on the machine. Please note that due to overheads, multiprocessing may be slower than serial processing if only a small number of sources are to be deblended. The benefits of multiprocessing require ~1000 or more sources to deblend, with larger gains as the number of sources increase. progress_bar : bool, optional Whether to display a progress bar. If ``nproc = 1``, then the ID shown after the progress bar is the source label being deblended. If multiprocessing is used (``nproc > 1``), the ID shown is the last source label that was deblended. The progress bar requires that the `tqdm <https://tqdm.github.io/>`_ optional dependency be installed. Note that the progress bar does not currently work in the Jupyter console due to limitations in ``tqdm``. Returns ------- segment_image : `~photutils.segmentation.SegmentationImage` A segmentation image, with the same shape as ``data``, where sources are marked by different positive integer values. A value of zero is reserved for the background. See Also -------- :func:`photutils.segmentation.detect_sources` :class:`photutils.segmentation.SourceFinder` """ if isinstance(data, Quantity): data = data.value if not isinstance(segment_img, SegmentationImage): raise TypeError('segment_img must be a SegmentationImage') if segment_img.shape != data.shape: raise ValueError('The data and segmentation image must have ' 'the same shape') if nlevels < 1: raise ValueError('nlevels must be >= 1') if contrast < 0 or contrast > 1: raise ValueError('contrast must be >= 0 and <= 1') if contrast == 1: # no deblending return segment_img.copy() if mode not in ('exponential', 'linear', 'sinh'): raise ValueError('mode must be "exponential", "linear", or "sinh"') if labels is None: labels = segment_img.labels else: labels = np.atleast_1d(labels) segment_img.check_labels(labels) # include only sources that have at least (2 * npixels); # this is required for a source to be deblended into multiple # sources, each with a minimum of npixels mask = (segment_img.areas[segment_img.get_indices(labels)] >= (npixels * 2)) labels = labels[mask] footprint = _make_binary_structure(data.ndim, connectivity) deblend_params = _DeblendParams(npixels, footprint, nlevels, contrast, mode) segm_deblended = segment_img.data.copy() label_indices = segment_img.get_indices(labels) if nproc is None: nproc = cpu_count() # pragma: no cover deblend_label_map = {} max_label = segment_img.max_label if nproc == 1: if progress_bar: # pragma: no cover desc = 'Deblending' label_indices = add_progress_bar(label_indices, desc=desc) nonposmin_labels = [] nmarkers_labels = [] for label, label_idx in zip(labels, label_indices, strict=True): if not isinstance(label_indices, np.ndarray): label_indices.set_postfix_str(f'ID: {label}') source_slice = segment_img.slices[label_idx] source_data = data[source_slice] source_segment = segment_img.data[source_slice] source_deblended, warns = _deblend_source(source_data, source_segment, label, deblend_params) if warns: if 'nonposmin' in warns: nonposmin_labels.append(label) if 'nmarkers' in warns: nmarkers_labels.append(label) if source_deblended is not None: source_mask = source_deblended > 0 new_segm = source_deblended[source_mask] # min label = 1 segm_deblended[source_slice][source_mask] = ( new_segm + max_label) new_labels = _get_labels(new_segm) + max_label deblend_label_map[label] = new_labels max_label += len(new_labels) else: # Use multiprocessing to deblend sources # Prepare the arguments for the worker function all_source_data = [] all_source_segments = [] all_source_slices = [] for label_idx in label_indices: source_slice = segment_img.slices[label_idx] source_data = data[source_slice] source_segment = segment_img.data[source_slice] all_source_data.append(source_data) all_source_segments.append(source_segment) all_source_slices.append(source_slice) args_all = zip(all_source_data, all_source_segments, labels, strict=True) # Create a partial function to pass the deblend_params to the # worker function worker = partial(_deblend_source, deblend_params=deblend_params) # Prepare to store futures and results to preserve the input # order of the labels when using as_completed() futures_dict = {} results = [None] * len(labels) disable_pbar = not progress_bar mp_context = get_context('spawn') with ProcessPoolExecutor(mp_context=mp_context, max_workers=nproc) as executor: # Submit all jobs at once for index, args in enumerate(args_all): futures_dict[executor.submit(worker, *args)] = index with tqdm(total=len(labels), desc='Deblending', disable=disable_pbar) as pbar: # Process the results as they are completed for future in as_completed(futures_dict): pbar.update(1) idx = futures_dict[future] pbar.set_postfix_str(f'ID: {labels[idx]}') results[idx] = future.result() # Process the results nonposmin_labels = [] nmarkers_labels = [] for label, source_slice, source_deblended in zip(labels, all_source_slices, results, strict=True): source_deblended, warns = source_deblended if warns: if 'nonposmin' in warns: nonposmin_labels.append(label) if 'nmarkers' in warns: nmarkers_labels.append(label) if source_deblended is not None: source_mask = source_deblended > 0 new_segm = source_deblended[source_mask] # min label = 1 segm_deblended[source_slice][source_mask] = ( new_segm + max_label) new_labels = _get_labels(new_segm) + max_label deblend_label_map[label] = new_labels max_label += len(new_labels) # process any warnings during deblending warning_info = {} if nonposmin_labels or nmarkers_labels: msg = ('The deblending mode of one or more source labels from the ' f'input segmentation image was changed from "{mode}" to ' '"linear". See the "info" attribute for the list of affected ' 'input labels.') warnings.warn(msg, AstropyUserWarning) if nonposmin_labels: nonposmin_labels = np.array(nonposmin_labels) msg = (f'Deblending mode changed from {mode} to linear due to ' 'non-positive minimum data values.') warn = {'message': msg, 'input_labels': nonposmin_labels} warning_info['nonposmin'] = warn if nmarkers_labels: nmarkers_labels = np.array(nmarkers_labels) msg = (f'Deblending mode changed from {mode} to linear due to ' 'too many potential deblended sources.') warn = {'message': msg, 'input_labels': nmarkers_labels} warning_info['nmarkers'] = warn if relabel: relabel_map = _create_relabel_map(segm_deblended, start_label=1) if relabel_map is not None: segm_deblended = relabel_map[segm_deblended] deblend_label_map = _update_deblend_label_map(deblend_label_map, relabel_map) segm_img = object.__new__(SegmentationImage) segm_img._data = segm_deblended segm_img._deblend_label_map = deblend_label_map # store the warnings in the output SegmentationImage info attribute if warning_info: segm_img.info = {'warnings': warning_info} return segm_img
def _deblend_source(data, segment_data, label, deblend_params): """ Convenience function to deblend a single labeled source. """ deblender = _SingleSourceDeblender(data, segment_data, label, deblend_params) return deblender.deblend_source(), deblender.warnings class _SingleSourceDeblender: """ Class to deblend a single labeled source. Parameters ---------- data : 2D `~numpy.ndarray` The cutout data array for a single source. ``data`` should also already be smoothed by the same filter used in :func:`~photutils.segmentation.detect_sources`, if applicable. segment_data : 2D int `~numpy.ndarray` The cutout segmentation image for a single source. Must have the same shape as ``data``. label : int The label of the source to deblend. This is needed because there may be more than one source label within the cutout. npixels : int The number of connected pixels, each greater than ``threshold``, that an object must have to be detected. ``npixels`` must be a positive integer. nlevels : int The number of multi-thresholding levels to use. Each source will be re-thresholded at ``nlevels`` levels spaced between its minimum and maximum values within the source segment. See the ``mode`` keyword for how the levels are spaced. contrast : float The fraction of the total (blended) source flux that a local peak must have (at any one of the multi-thresholds) to be considered as a separate object. ``contrast`` must be between 0 and 1, inclusive. If ``contrast = 0`` then every local peak will be made a separate object (maximum deblending). If ``contrast = 1`` then no deblending will occur. The default is 0.001, which will deblend sources with a 7.5 magnitude difference. mode : {'exponential', 'linear', 'sinh'} The mode used in defining the spacing between the multi-thresholding levels (see the ``nlevels`` keyword). Returns ------- segment_image : `~photutils.segmentation.SegmentationImage` A segmentation image, with the same shape as ``data``, where sources are marked by different positive integer values. A value of zero is reserved for the background. Note that the returned `SegmentationImage` will have consecutive labels starting with 1. """ def __init__(self, data, segment_data, label, deblend_params): self.data = data self.segment_data = segment_data self.label = label self.npixels = deblend_params.npixels self.footprint = deblend_params.footprint self.nlevels = deblend_params.nlevels self.contrast = deblend_params.contrast self.mode = deblend_params.mode self.segment_mask = segment_data == label data_values = data[self.segment_mask] self.source_min = nanmin(data_values) self.source_max = nanmax(data_values) self.source_sum = nansum(data_values) self.warnings = {} @lazyproperty def linear_thresholds(self): """ Linearly spaced thresholds between the source minimum and maximum (inclusive). The source min/max are excluded later, giving nlevels thresholds between min and max (noninclusive). """ return np.linspace(self.source_min, self.source_max, self.nlevels + 2) @lazyproperty def normalized_thresholds(self): """ Normalized thresholds (from 0 to 1) between the source minimum and maximum (inclusive). """ return ((self.linear_thresholds - self.source_min) / (self.source_max - self.source_min)) def compute_thresholds(self): """ Compute the multi-level detection thresholds for the source. Returns ------- thresholds : 1D `~numpy.ndarray` The multi-level detection thresholds for the source. """ if self.mode == 'exponential' and self.source_min <= 0: self.warnings['nonposmin'] = 'non-positive minimum' self.mode = 'linear' if self.mode == 'linear': thresholds = self.linear_thresholds elif self.mode == 'sinh': a = 0.25 minval = self.source_min maxval = self.source_max thresholds = self.normalized_thresholds thresholds = np.sinh(thresholds / a) / np.sinh(1.0 / a) thresholds *= (maxval - minval) thresholds += minval elif self.mode == 'exponential': minval = self.source_min maxval = self.source_max thresholds = self.normalized_thresholds thresholds = minval * (maxval / minval) ** thresholds return thresholds[1:-1] # do not include source min and max def multithreshold(self): """ Perform multithreshold detection for each source. This method is useful for debugging and testing. Parameters ---------- deblend_mode : bool, optional If `True` then only segmentation images with more than one label will be returned. If `False` then all segmentation images will be returned. Returns ------- segments : list of 2D `~numpy.ndarray` A list of segmentation images, one for each threshold. Only segmentation images with more than one label will be returned. """ thresholds = self.compute_thresholds() segms = [] for threshold in thresholds: segm = _detect_sources(self.data, threshold, self.npixels, self.footprint, self.segment_mask, relabel=False, return_segmimg=False) segms.append(segm) return segms def make_markers(self, return_all=False): """ Make markers (possible sources) for the watershed algorithm. Parameters ---------- return_all : bool, optional If `False` then return only the final segmentation marker image. If `True` then return all segmentation marker images. This keyword is useful for debugging and testing. Returns ------- markers : 2D `~numpy.ndarray` or list of 2D `~numpy.ndarray` A segmentation image that contain markers for possible sources. If ``return_all=True`` then a list of all segmentation marker images is returned. `None` is returned if there is only one source at every threshold. """ thresholds = self.compute_thresholds() segm_lower = _detect_sources(self.data, thresholds[0], self.npixels, self.footprint, self.segment_mask, relabel=False, return_segmimg=False) if return_all: all_segms = [segm_lower] for threshold in thresholds[1:]: segm_upper = _detect_sources(self.data, threshold, self.npixels, self.footprint, self.segment_mask, relabel=False, return_segmimg=False) if segm_upper is None: # 0 or 1 labels continue segm_lower = self.make_marker_segment(segm_lower, segm_upper) if return_all: all_segms.append(segm_lower) if return_all: return all_segms return segm_lower def make_marker_segment(self, segment_lower, segment_upper): """ Make markers (possible sources) for the watershed algorithm. Parameters ---------- segment_lower : 2D `~numpy.ndarray` The "lower" threshold level segmentation image. segment_upper : 2D `~numpy.ndarray` The next-highest threshold level segmentation image. Returns ------- markers : 2D `~numpy.ndarray` A segmentation image that contain markers for possible sources. Notes ----- For a given label in the lower level, find the labels in the upper level (higher threshold value) that are its children (i.e., the labels within the same mask as the lower level). If there are multiple children, then the lower-level parent label is replaced by its children. Parent labels that do not have multiple children in the upper level are kept as is (maximizing the marker size). """ if segment_lower is None: return segment_upper labels = _get_labels(segment_lower) new_markers = False markers = segment_lower.astype(bool) for label in labels: mask = (segment_lower == label) # find label mapping from the lower to upper level upper_labels = _get_labels(segment_upper[mask]) if upper_labels.size >= 2: # new child markers found new_markers = True markers[mask] = segment_upper[mask].astype(bool) if new_markers: # convert bool markers to integer labels return ndi_label(markers, structure=self.footprint)[0] return segment_lower def apply_watershed(self, markers): """ Apply the watershed algorithm to the source markers. Parameters ---------- markers : list of `~photutils.segmentation.SegmentationImage` A list of segmentation images that contain possible sources as markers. The last list element contains all of the potential source markers. Returns ------- segment_data : 2D int `~numpy.ndarray` A 2D int array containing the deblended source labels. Note that the source labels may not be consecutive if a label was removed. """ from skimage.segmentation import watershed # Deblend using watershed. If any source does not meet the contrast # criterion, then remove the faintest such source and repeat until # all sources meet the contrast criterion. remove_marker = True while remove_marker: markers = watershed(-self.data, markers, mask=self.segment_mask, connectivity=self.footprint) labels = _get_labels(markers) if labels.size == 1: # only 1 source left remove_marker = False else: flux_frac = (sum_labels(self.data, markers, index=labels) / self.source_sum) remove_marker = any(flux_frac < self.contrast) if remove_marker: # remove only the faintest source (one at a time) # because several faint sources could combine to meet # the contrast criterion markers[markers == labels[np.argmin(flux_frac)]] = 0.0 return markers def deblend_source(self): """ Deblend a single labeled source. Returns ------- segment_data : 2D int `~numpy.ndarray` A 2D int array containing the deblended source labels. The source labels are consecutive starting at 1. """ if self.source_min == self.source_max: # no deblending return None # define the markers (possible sources) for the watershed algorithm markers = self.make_markers() if markers is None: return None # If there are too many markers (e.g., due to low threshold # and/or small npixels), the watershed step can be very slow # (the threshold of 200 is arbitrary, but seems to work well). # This mostly affects the "exponential" mode, where there are # many levels at low thresholds, so here we try again with # "linear" mode. nlabels = len(_get_labels(markers)) if self.mode != 'linear' and nlabels > 200: del markers # free memory self.warnings['nmarkers'] = 'too many markers' self.mode = 'linear' markers = self.make_markers() if markers is None: return None # deblend using the watershed algorithm using the markers as seeds markers = self.apply_watershed(markers) if not np.array_equal(self.segment_mask, markers.astype(bool)): raise ValueError(f'Deblending failed for source "{self.label}". ' 'Please ensure you used the same pixel ' 'connectivity in detect_sources and ' 'deblend_sources.') if len(_get_labels(markers)) == 1: # no deblending return None # markers may not be consecutive if a label was removed due to # the contrast criterion relabel_map = _create_relabel_map(markers, start_label=1) if relabel_map is not None: markers = relabel_map[markers] return markers def _get_labels(array): """ Get the unique labels greater than zero in an array. Parameters ---------- array : `~numpy.ndarray` The array to get the unique labels from. Returns ------- labels : int `~numpy.ndarray` The unique labels in the array. """ labels = np.unique(array) return labels[labels != 0] def _create_relabel_map(array, start_label=1): """ Create a mapping of original labels to new labels that are consecutive integers. By default, the new labels start from 1. Parameters ---------- array : 2D `~numpy.ndarray` The 2D array to relabel. start_label : int, optional The starting label number. Must be >= 1. The default is 1. Returns ------- relabel_map : 1D `~numpy.ndarray` or None The array mapping the original labels to the new labels. If the labels are already consecutive starting from ``start_label``, then `None` is returned. """ labels = _get_labels(array) # check if the labels are already consecutive starting from # start_label if (labels[0] == start_label and (labels[-1] - start_label + 1) == len(labels)): return None # Create an array to map old labels to new labels relabel_map = np.zeros(labels.max() + 1, dtype=array.dtype) relabel_map[labels] = np.arange(len(labels)) + start_label return relabel_map def _update_deblend_label_map(deblend_label_map, relabel_map): """ Update the deblend_label_map to reflect the new labels that are consecutive integers. Parameters ---------- deblend_label_map : dict A dictionary mapping the original labels to the new deblended labels. relabel_map : 1D `~numpy.ndarray` The array mapping the original labels to the new labels. Returns ------- deblend_label_map : dict The updated deblend_label_map. """ for old_label, new_labels in deblend_label_map.items(): deblend_label_map[old_label] = relabel_map[new_labels] return deblend_label_map