# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
Base class for profiles.
"""
import abc
import warnings
import numpy as np
from astropy.utils import lazyproperty
from astropy.utils.exceptions import AstropyUserWarning
from photutils.aperture.core import _update_method_subpixels_docstring
from photutils.utils._deprecation import deprecated_positional_kwargs
from photutils.utils._quantity_helpers import process_quantities
from photutils.utils._stats import nanmax, nansum
__all__ = ['ProfileBase']
@_update_method_subpixels_docstring
class ProfileBase(metaclass=abc.ABCMeta):
    # numpydoc ignore: PR01,PR02,PR04,PR07
    """
    Abstract base class for profile classes.

    Parameters
    ----------
    data : 2D `~numpy.ndarray`
        The 2D data array. The data should be background-subtracted.

    xycen : tuple of 2 floats
        The ``(x, y)`` pixel coordinate of the source center.

    radii : 1D float `~numpy.ndarray`
        An array of radii defining the profile apertures. ``radii`` must
        be strictly increasing with a minimum value greater than or
        equal to zero, and contain at least 2 values. The radial spacing
        does not need to be constant. See the subclass documentation for
        details on how ``radii`` is interpreted.

    error : 2D `~numpy.ndarray`, optional
        The 1-sigma errors of the input ``data``. ``error`` is assumed
        to include all sources of error, including the Poisson error of
        the sources (see `~photutils.utils.calc_total_error`). ``error``
        must have the same shape as the input ``data``.

    mask : 2D bool `~numpy.ndarray`, optional
        A boolean mask with the same shape as ``data`` where a `True`
        value indicates the corresponding element of ``data`` is masked.
        Masked data are excluded from all calculations. Array-like or
        non-boolean input is converted to a boolean array.

    <method_subpixels_descriptions>
    """

    # Define axis labels used by `~photutils.profiles.ProfileBase.plot`.
    # Subclasses may override these.
    _xlabel = 'Radius (pixels)'
    _ylabel = 'Profile'

    def __init__(self, data, xycen, radii, *, error=None, mask=None,
                 method='exact', subpixels=5):
        # NOTE(review): process_quantities is assumed to return the
        # inputs with any units stripped plus their common unit (or
        # None); the rest of this class only relies on ``unit`` being
        # None when there is no unit -- confirm against the helper.
        (data, error), unit = process_quantities((data, error),
                                                 ('data', 'error'))

        if error is not None and error.shape != data.shape:
            msg = 'error must have the same shape as data'
            raise ValueError(msg)

        self.data = data
        self.unit = unit
        self.xycen = xycen
        self.radii = self._validate_radii(radii)
        self.error = error
        self.mask = self._compute_mask(data, error, mask)
        self.method = method
        self.subpixels = subpixels

        # Product of all normalization factors applied so far; 1.0
        # means the profile has not been normalized.
        self.normalization_value = 1.0

    def _validate_radii(self, radii):
        """
        Validate and return the radii array.

        Parameters
        ----------
        radii : array-like
            The input radii.

        Returns
        -------
        radii : 1D `~numpy.ndarray`
            A new array holding the validated radii.

        Raises
        ------
        ValueError
            If ``radii`` is not 1D, has fewer than two values, has a
            negative minimum, or is not strictly increasing.
        """
        radii = np.array(radii)
        if radii.ndim != 1 or radii.size < 2:
            msg = 'radii must be a 1D array and have at least two values'
            raise ValueError(msg)

        if radii.min() < 0:
            msg = 'minimum radii must be >= 0'
            raise ValueError(msg)

        if not np.all(radii[1:] > radii[:-1]):
            msg = 'radii must be strictly increasing'
            raise ValueError(msg)

        return radii

    def _compute_mask(self, data, error, mask):
        """
        Compute the mask array, automatically masking non-finite data or
        error values.

        Parameters
        ----------
        data : 2D `~numpy.ndarray`
            The data array.

        error : 2D `~numpy.ndarray` or `None`
            The error array, if any.

        mask : 2D array-like or `None`
            The user-supplied mask, if any. It is converted to a boolean
            array before use.

        Returns
        -------
        combined_mask : 2D bool `~numpy.ndarray`
            `True` where a pixel is masked by the user or holds a
            non-finite data or error value.

        Raises
        ------
        ValueError
            If ``mask`` does not have the same shape as ``data``.

        Warns
        -----
        AstropyUserWarning
            If non-finite values that were not already masked by the
            user had to be masked automatically.
        """
        badmask = ~np.isfinite(data)
        if error is not None:
            badmask |= ~np.isfinite(error)

        if mask is not None:
            # Coerce to a boolean array so that lists and integer (0/1)
            # masks work: ``~`` on an integer array is a bitwise NOT,
            # which cannot be combined in place with the boolean
            # ``badmask`` below.
            mask = np.asarray(mask, dtype=bool)
            if mask.shape != data.shape:
                msg = 'mask must have the same shape as data'
                raise ValueError(msg)

            # Keep only non-finite values not already masked by the user
            badmask &= ~mask
            combined_mask = mask | badmask  # all masked pixels
        else:
            combined_mask = badmask

        # Warn only about pixels the user did not mask themselves
        if np.any(badmask):
            msg = ('Input data contains non-finite values (e.g., NaN '
                   'or inf) that were automatically masked.')
            warnings.warn(msg, AstropyUserWarning)

        return combined_mask

    @property
    @abc.abstractmethod
    def radius(self):
        """
        The profile radius in pixels as a 1D `~numpy.ndarray`.
        """

    @property
    @abc.abstractmethod
    def profile(self):
        """
        The radial profile as a 1D `~numpy.ndarray`.
        """

    @property
    @abc.abstractmethod
    def profile_error(self):
        """
        The profile errors as a 1D `~numpy.ndarray`.

        If no ``error`` array was provided, an empty array with shape
        ``(0,)`` is returned.
        """

    @lazyproperty
    def _circular_apertures(self):
        """
        A list of `~photutils.aperture.CircularAperture` objects.

        The first element may be `None`.
        """
        from photutils.aperture import CircularAperture

        # A non-positive radius gets a `None` placeholder instead of an
        # aperture; `_compute_photometry` maps `None` to zero flux and
        # zero area.
        return [None if radius <= 0.0
                else CircularAperture(self.xycen, radius)
                for radius in self.radii]

    def _compute_photometry(self, apertures):
        """
        Compute aperture fluxes, flux errors, and areas for the given
        apertures.

        Parameters
        ----------
        apertures : list
            A list of aperture objects. Elements may be `None`, in which
            case the corresponding flux, error, and area are set to
            zero.

        Returns
        -------
        flux : `~numpy.ndarray`
            The aperture fluxes.

        flux_err : `~numpy.ndarray`
            The aperture flux errors. This array is empty if no
            ``error`` array was input.

        areas : `~numpy.ndarray`
            The aperture areas.
        """
        fluxes = []
        flux_errs = []
        areas = []
        for aperture in apertures:
            if aperture is None:
                flux, flux_err = [0.0], [0.0]
                area = 0.0
            else:
                flux, flux_err = aperture.do_photometry(
                    self.data, error=self.error, mask=self.mask,
                    method=self.method, subpixels=self.subpixels)
                area = aperture.area_overlap(self.data, mask=self.mask,
                                             method=self.method,
                                             subpixels=self.subpixels)

            fluxes.append(flux[0])
            # Errors are collected only when an error array was input,
            # so ``flux_errs`` stays empty otherwise.
            if self.error is not None:
                flux_errs.append(flux_err[0])
            areas.append(area)

        fluxes = np.array(fluxes)
        flux_errs = np.array(flux_errs)
        areas = np.array(areas)

        # Re-attach the input unit
        if self.unit is not None:
            fluxes <<= self.unit
            flux_errs <<= self.unit

        return fluxes, flux_errs, areas

    @lazyproperty
    def _photometry(self):
        """
        The aperture fluxes, flux errors, and areas as a function of
        radius.
        """
        return self._compute_photometry(self._circular_apertures)

    @deprecated_positional_kwargs(since='3.0', until='4.0')
    def normalize(self, method='max'):
        """
        Normalize the profile.

        Parameters
        ----------
        method : {'max', 'sum'}, optional
            The method used to normalize the profile:

            * ``'max'`` (default):
              The profile is normalized such that its maximum value is
              1.
            * ``'sum'``:
              The profile is normalized such that its sum (integral) is
              1.

        Raises
        ------
        ValueError
            If ``method`` is not ``'max'`` or ``'sum'``.

        Warns
        -----
        AstropyUserWarning
            If the normalization value is zero or non-finite. In that
            case the profile is left unchanged.
        """
        if method == 'max':
            reducer = nanmax
        elif method == 'sum':
            reducer = nansum
        else:
            msg = "invalid method, must be 'max' or 'sum'"
            raise ValueError(msg)

        # Silence RuntimeWarnings from the reduction; a zero or
        # non-finite result is reported below instead.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', RuntimeWarning)
            normalization = reducer(self.profile)

        if normalization == 0 or not np.isfinite(normalization):
            msg = ('The profile cannot be normalized because the max or '
                   'sum is zero or non-finite.')
            warnings.warn(msg, AstropyUserWarning)
        else:
            # normalization_values accumulate if normalize is run
            # multiple times (e.g., different methods)
            self.normalization_value *= normalization

            # Need to use __dict__ as these are lazy properties
            self.__dict__['profile'] = self.profile / normalization
            self.__dict__['profile_error'] = (self.profile_error
                                              / normalization)
            self._normalize_hook(normalization)

    def _normalize_hook(self, normalization):  # noqa: B027
        """
        Hook called by `normalize` after normalizing ``profile`` and
        ``profile_error``.

        This hook is only called when normalization succeeds (i.e., when
        the normalization value is non-zero and finite).

        Subclasses can override this to normalize additional lazy
        properties (e.g., ``data_profile``).

        Parameters
        ----------
        normalization : float
            The normalization value applied to the profile.
        """

    def unnormalize(self):
        """
        Unnormalize the profile back to the original state before any
        calls to `normalize`.
        """
        # Need to use __dict__ as these are lazy properties
        self.__dict__['profile'] = self.profile * self.normalization_value
        self.__dict__['profile_error'] = (self.profile_error
                                          * self.normalization_value)

        # The hook runs before the reset so that it can still read the
        # accumulated normalization value.
        self._unnormalize_hook()
        self.normalization_value = 1.0

    def _unnormalize_hook(self):  # noqa: B027
        """
        Hook called by `unnormalize` after unnormalizing ``profile`` and
        ``profile_error``, but before resetting ``normalization_value``.

        Subclasses can override this to unnormalize additional lazy
        properties (e.g., ``data_profile``).
        """

    @staticmethod
    def _trim_to_monotonic(xarr, profile, name):
        """
        Trim arrays to the first monotonically increasing region.

        This is used by interpolation methods that require a
        monotonically increasing profile.

        Parameters
        ----------
        xarr : 1D `~numpy.ndarray`
            The x-axis values (e.g., radius or half-size).

        profile : 1D `~numpy.ndarray`
            The profile values.

        name : str
            A descriptive name for the profile used in the error
            message.

        Returns
        -------
        xarr, profile : tuple of `~numpy.ndarray`
            The trimmed arrays.

        Raises
        ------
        ValueError
            If fewer than two points remain after trimming.
        """
        finite_mask = np.isfinite(profile)
        if not np.all(finite_mask):
            # Keep only the leading finite segment. argmin on a boolean
            # array gives the index of the first `False` value.
            first_nonfinite = np.argmin(finite_mask)
            xarr = xarr[:first_nonfinite]
            profile = profile[:first_nonfinite]

        # np.diff produces an array of length n-1: diff[i] represents
        # the step from profile[i] to profile[i+1]. A value <= 0 means
        # the profile stopped increasing at that step.
        diff = np.diff(profile) <= 0
        if np.any(diff):
            # idx is an index into the *diff* array, not the profile
            # array. diff[idx] <= 0 means the drop occurs between
            # profile[idx] and profile[idx+1], so profile[idx] is
            # the last good value. We therefore need profile[:idx+1]
            # (inclusive) to retain it.
            idx = np.argmax(diff)  # first non-monotonic step in diff-space
            xarr = xarr[:idx + 1]
            profile = profile[:idx + 1]

        if len(xarr) < 2:
            msg = (f'The {name} profile is not monotonically '
                   'increasing even at the smallest values -- cannot '
                   'interpolate. Try using different input values '
                   '(especially the starting values) and/or using the '
                   '"exact" aperture overlap method.')
            raise ValueError(msg)

        return xarr, profile

    def __repr__(self):
        cls_name = self.__class__.__name__
        n_radii = len(self.radii)
        # normalization_value is reset to exactly 1.0 by `unnormalize`
        normalized = self.normalization_value != 1.0
        return (f'{cls_name}(xycen={self.xycen}, n_radii={n_radii}, '
                f'normalized={normalized})')

    @deprecated_positional_kwargs(since='3.0', until='4.0')
    def plot(self, ax=None, **kwargs):
        """
        Plot the profile.

        Parameters
        ----------
        ax : `matplotlib.axes.Axes` or `None`, optional
            The matplotlib axes on which to plot. If `None`, then the
            current `~matplotlib.axes.Axes` instance is used.

        **kwargs : dict, optional
            Any keyword arguments accepted by `matplotlib.pyplot.plot`.

        Returns
        -------
        lines : list of `~matplotlib.lines.Line2D`
            A list of lines representing the plotted data.
        """
        import matplotlib.pyplot as plt

        if ax is None:
            ax = plt.gca()

        lines = ax.plot(self.radius, self.profile, **kwargs)
        ax.set_xlabel(self._xlabel)

        ylabel = self._ylabel
        if self.unit is not None:
            ylabel = f'{ylabel} ({self.unit})'
        ax.set_ylabel(ylabel)

        return lines

    @deprecated_positional_kwargs(since='3.0', until='4.0')
    def plot_error(self, ax=None, **kwargs):
        """
        Plot the profile errors.

        Parameters
        ----------
        ax : `matplotlib.axes.Axes` or `None`, optional
            The matplotlib axes on which to plot. If `None`, then the
            current `~matplotlib.axes.Axes` instance is used.

        **kwargs : dict, optional
            Any keyword arguments accepted by
            `matplotlib.pyplot.fill_between`.

        Returns
        -------
        poly : `matplotlib.collections.PolyCollection` or `None`
            A `~matplotlib.collections.PolyCollection` containing the
            plotted polygons, or `None` if no errors were input.

        Warns
        -----
        AstropyUserWarning
            If no errors were input.
        """
        if len(self.profile_error) == 0:
            msg = 'Errors were not input'
            warnings.warn(msg, AstropyUserWarning)
            return None

        import matplotlib.pyplot as plt

        if ax is None:
            ax = plt.gca()

        # Set default fill_between facecolor.
        # facecolor must be first key, otherwise it will override color
        # kwarg (i.e., cannot use setdefault here)
        if 'facecolor' not in kwargs:
            kws = {'facecolor': (0.5, 0.5, 0.5, 0.3)}
            kws.update(kwargs)
        else:
            kws = kwargs

        profile = self.profile
        profile_error = self.profile_error
        # Pass plain values (without units) to fill_between
        if self.unit is not None:
            profile = profile.value
            profile_error = profile_error.value
        ymin = profile - profile_error
        ymax = profile + profile_error

        return ax.fill_between(self.radius, ymin, ymax, **kws)