Source code for photutils.profiles.core

# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
Base class for profiles.
"""

import abc
import warnings

import numpy as np
from astropy.utils import lazyproperty
from astropy.utils.exceptions import AstropyUserWarning

from photutils.utils._deprecation import deprecated_positional_kwargs
from photutils.utils._quantity_helpers import process_quantities
from photutils.utils._stats import nanmax, nansum

__all__ = ['ProfileBase']


[docs] class ProfileBase(metaclass=abc.ABCMeta): """ Abstract base class for profile classes. Parameters ---------- data : 2D `~numpy.ndarray` The 2D data array. The data should be background-subtracted. xycen : tuple of 2 floats The ``(x, y)`` pixel coordinate of the source center. radii : 1D float `~numpy.ndarray` An array of radii defining the profile apertures. ``radii`` must be strictly increasing with a minimum value greater than or equal to zero, and contain at least 2 values. The radial spacing does not need to be constant. See the subclass documentation for details on how ``radii`` is interpreted. error : 2D `~numpy.ndarray`, optional The 1-sigma errors of the input ``data``. ``error`` is assumed to include all sources of error, including the Poisson error of the sources (see `~photutils.utils.calc_total_error`). ``error`` must have the same shape as the input ``data``. mask : 2D bool `~numpy.ndarray`, optional A boolean mask with the same shape as ``data`` where a `True` value indicates the corresponding element of ``data`` is masked. Masked data are excluded from all calculations. method : {'exact', 'center', 'subpixel'}, optional The method used to determine the overlap of the aperture on the pixel grid: * ``'exact'`` (default): The exact fractional overlap of the aperture and each pixel is calculated. The aperture weights will contain values between 0 and 1. * ``'center'``: A pixel is considered to be entirely in or out of the aperture depending on whether its center is in or out of the aperture. The aperture weights will contain values only of 0 (out) and 1 (in). * ``'subpixel'``: A pixel is divided into subpixels (see the ``subpixels`` keyword), each of which are considered to be entirely in or out of the aperture depending on whether its center is in or out of the aperture. If ``subpixels=1``, this method is equivalent to ``'center'``. The aperture weights will contain values between 0 and 1. subpixels : int, optional For the ``'subpixel'`` method, resample pixels by this factor in each dimension. That is, each pixel is divided into ``subpixels**2`` subpixels. This keyword is ignored unless ``method='subpixel'``. """ # Define axis labels used by `~photutils.profiles.ProfileBase.plot`. # Subclasses may override these. _xlabel = 'Radius (pixels)' _ylabel = 'Profile' def __init__(self, data, xycen, radii, *, error=None, mask=None, method='exact', subpixels=5): (data, error), unit = process_quantities((data, error), ('data', 'error')) if error is not None and error.shape != data.shape: msg = 'error must have the same shape as data' raise ValueError(msg) self.data = data self.unit = unit self.xycen = xycen self.radii = self._validate_radii(radii) self.error = error self.mask = self._compute_mask(data, error, mask) self.method = method self.subpixels = subpixels self.normalization_value = 1.0 def _validate_radii(self, radii): """ Validate and return the radii array. """ radii = np.array(radii) if radii.ndim != 1 or radii.size < 2: msg = 'radii must be a 1D array and have at least two values' raise ValueError(msg) if radii.min() < 0: msg = 'minimum radii must be >= 0' raise ValueError(msg) if not np.all(radii[1:] > radii[:-1]): msg = 'radii must be strictly increasing' raise ValueError(msg) return radii def _compute_mask(self, data, error, mask): """ Compute the mask array, automatically masking non-finite data or error values. """ badmask = ~np.isfinite(data) if error is not None: badmask |= ~np.isfinite(error) if mask is not None: if mask.shape != data.shape: msg = 'mask must have the same shape as data' raise ValueError(msg) # Keep only non-finite values not already masked by the user badmask &= ~mask combined_mask = mask | badmask # all masked pixels else: combined_mask = badmask if np.any(badmask): msg = ('Input data contains non-finite values (e.g., NaN ' 'or inf) that were automatically masked.') warnings.warn(msg, AstropyUserWarning) return combined_mask @property @abc.abstractmethod def radius(self): """ The profile radius in pixels as a 1D `~numpy.ndarray`. """ @property @abc.abstractmethod def profile(self): """ The radial profile as a 1D `~numpy.ndarray`. """ @property @abc.abstractmethod def profile_error(self): """ The profile errors as a 1D `~numpy.ndarray`. If no ``error`` array was provided, an empty array with shape ``(0,)`` is returned. """ @lazyproperty def _circular_apertures(self): """ A list of `~photutils.aperture.CircularAperture` objects. The first element may be `None`. """ from photutils.aperture import CircularAperture apertures = [] for radius in self.radii: if radius <= 0.0: apertures.append(None) else: apertures.append(CircularAperture(self.xycen, radius)) return apertures def _compute_photometry(self, apertures): """ Compute aperture fluxes, flux errors, and areas for the given apertures. Parameters ---------- apertures : list A list of aperture objects. Elements may be `None`, in which case the corresponding flux, error, and area are set to zero. Returns ------- flux : `~numpy.ndarray` The aperture fluxes. flux_err : `~numpy.ndarray` The aperture flux errors. areas : `~numpy.ndarray` The aperture areas. """ fluxes = [] flux_errs = [] areas = [] for aperture in apertures: if aperture is None: flux, flux_err = [0.0], [0.0] area = 0.0 else: flux, flux_err = aperture.do_photometry( self.data, error=self.error, mask=self.mask, method=self.method, subpixels=self.subpixels) area = aperture.area_overlap(self.data, mask=self.mask, method=self.method, subpixels=self.subpixels) fluxes.append(flux[0]) if self.error is not None: flux_errs.append(flux_err[0]) areas.append(area) fluxes = np.array(fluxes) flux_errs = np.array(flux_errs) areas = np.array(areas) if self.unit is not None: fluxes <<= self.unit flux_errs <<= self.unit return fluxes, flux_errs, areas @lazyproperty def _photometry(self): """ The aperture fluxes, flux errors, and areas as a function of radius. """ return self._compute_photometry(self._circular_apertures)
[docs] @deprecated_positional_kwargs(since='3.0', until='4.0') def normalize(self, method='max'): """ Normalize the profile. Parameters ---------- method : {'max', 'sum'}, optional The method used to normalize the profile: * ``'max'`` (default): The profile is normalized such that its maximum value is 1. * ``'sum'``: The profile is normalized such that its sum (integral) is 1. """ if method == 'max': with warnings.catch_warnings(): warnings.simplefilter('ignore', RuntimeWarning) normalization = nanmax(self.profile) elif method == 'sum': with warnings.catch_warnings(): warnings.simplefilter('ignore', RuntimeWarning) normalization = nansum(self.profile) else: msg = "invalid method, must be 'max' or 'sum'" raise ValueError(msg) if normalization == 0 or not np.isfinite(normalization): msg = ('The profile cannot be normalized because the max or ' 'sum is zero or non-finite.') warnings.warn(msg, AstropyUserWarning) else: # normalization_values accumulate if normalize is run # multiple times (e.g., different methods) self.normalization_value *= normalization # Need to use __dict__ as these are lazy properties self.__dict__['profile'] = self.profile / normalization self.__dict__['profile_error'] = self.profile_error / normalization self._normalize_hook(normalization)
def _normalize_hook(self, normalization): # noqa: B027 """ Hook called by `normalize` after normalizing ``profile`` and ``profile_error``. This hook is only called when normalization succeeds (i.e., when the normalization value is non-zero and finite). Subclasses can override this to normalize additional lazy properties (e.g., ``data_profile``). Parameters ---------- normalization : float The normalization value applied to the profile. """
[docs] def unnormalize(self): """ Unnormalize the profile back to the original state before any calls to `normalize`. """ self.__dict__['profile'] = self.profile * self.normalization_value self.__dict__['profile_error'] = (self.profile_error * self.normalization_value) self._unnormalize_hook() self.normalization_value = 1.0
def _unnormalize_hook(self): # noqa: B027 """ Hook called by `unnormalize` after unnormalizing ``profile`` and ``profile_error``, but before resetting ``normalization_value``. Subclasses can override this to unnormalize additional lazy properties (e.g., ``data_profile``). """ @staticmethod def _trim_to_monotonic(xarr, profile, name): """ Trim arrays to the first monotonically increasing region. This is used by interpolation methods that require a monotonically increasing profile. Parameters ---------- xarr : 1D `~numpy.ndarray` The x-axis values (e.g., radius or half-size). profile : 1D `~numpy.ndarray` The profile values. name : str A descriptive name for the profile used in the error message. Returns ------- xarr, profile : tuple of `~numpy.ndarray` The trimmed arrays. """ finite_mask = np.isfinite(profile) if not np.all(finite_mask): # Keep only the leading finite segment first_nonfinite = np.argmin(finite_mask) xarr = xarr[:first_nonfinite] profile = profile[:first_nonfinite] # np.diff produces an array of length n-1: diff[i] represents # the step from profile[i] to profile[i+1]. A value <= 0 means # the profile stopped increasing at that step. diff = np.diff(profile) <= 0 if np.any(diff): # idx is an index into the *diff* array, not the profile # array. diff[idx] <= 0 means the drop occurs between # profile[idx] and profile[idx+1], so profile[idx] is # the last good value. We therefore need profile[:idx+1] # (inclusive) to retain it. idx = np.argmax(diff) # first non-monotonic step in diff-space xarr = xarr[:idx + 1] profile = profile[:idx + 1] if len(xarr) < 2: msg = (f'The {name} profile is not monotonically ' 'increasing even at the smallest values -- cannot ' 'interpolate. Try using different input values ' '(especially the starting values) and/or using the ' '"exact" aperture overlap method.') raise ValueError(msg) return xarr, profile def __repr__(self): cls_name = self.__class__.__name__ n_radii = len(self.radii) normalized = self.normalization_value != 1.0 return (f'{cls_name}(xycen={self.xycen}, n_radii={n_radii}, ' f'normalized={normalized})')
[docs] @deprecated_positional_kwargs(since='3.0', until='4.0') def plot(self, ax=None, **kwargs): """ Plot the profile. Parameters ---------- ax : `matplotlib.axes.Axes` or `None`, optional The matplotlib axes on which to plot. If `None`, then the current `~matplotlib.axes.Axes` instance is used. **kwargs : dict, optional Any keyword arguments accepted by `matplotlib.pyplot.plot`. Returns ------- lines : list of `~matplotlib.lines.Line2D` A list of lines representing the plotted data. """ import matplotlib.pyplot as plt if ax is None: ax = plt.gca() lines = ax.plot(self.radius, self.profile, **kwargs) ax.set_xlabel(self._xlabel) ylabel = self._ylabel if self.unit is not None: ylabel = f'{ylabel} ({self.unit})' ax.set_ylabel(ylabel) return lines
[docs] @deprecated_positional_kwargs(since='3.0', until='4.0') def plot_error(self, ax=None, **kwargs): """ Plot the profile errors. Parameters ---------- ax : `matplotlib.axes.Axes` or `None`, optional The matplotlib axes on which to plot. If `None`, then the current `~matplotlib.axes.Axes` instance is used. **kwargs : dict, optional Any keyword arguments accepted by `matplotlib.pyplot.fill_between`. Returns ------- poly : `matplotlib.collections.PolyCollection` or `None` A `~matplotlib.collections.PolyCollection` containing the plotted polygons, or `None` if no errors were input. """ if len(self.profile_error) == 0: msg = 'Errors were not input' warnings.warn(msg, AstropyUserWarning) return None import matplotlib.pyplot as plt if ax is None: ax = plt.gca() # Set default fill_between facecolor. # facecolor must be first key, otherwise it will override color # kwarg (i.e., cannot use setdefault here) if 'facecolor' not in kwargs: kws = {'facecolor': (0.5, 0.5, 0.5, 0.3)} kws.update(kwargs) else: kws = kwargs profile = self.profile profile_error = self.profile_error if self.unit is not None: profile = profile.value profile_error = profile_error.value ymin = profile - profile_error ymax = profile + profile_error return ax.fill_between(self.radius, ymin, ymax, **kws)