# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
Tools for generating 2D image cutouts.
"""
import numpy as np
from astropy.nddata import extract_array, overlap_slices
from astropy.utils import lazyproperty
from photutils.utils._deprecation import deprecated_positional_kwargs
__all__ = ['CutoutImage']
class CutoutImage:
    """
    Create a cutout object from a 2D array.

    The resulting object holds a 2D cutout array in its ``data``
    attribute. With ``copy=False`` (default) no explicit copy is made,
    so the cutout can be a view into the input ``data`` array; with
    ``copy=True`` the cutout is always a copy of the original data.

    Parameters
    ----------
    data : `~numpy.ndarray`
        The 2D array from which the cutout is extracted.

    position : tuple of 2 ints
        The ``(y, x)`` position of the cutout center, in the pixel
        coordinates of the ``data`` array.

    shape : tuple of 2 ints
        The requested ``(ny, nx)`` shape of the cutout array.

    mode : {'trim', 'partial', 'strict'}, optional
        How the cutout is built when it is not fully contained in
        ``data``:

        * ``'trim'``: only the overlapping elements are returned, so
          the cutout may be smaller than the requested ``shape``.
        * ``'partial'``: the cutout keeps the requested ``shape`` and
          elements that do not overlap ``data`` are set to
          ``fill_value``.
        * ``'strict'``: the cutout must lie entirely within ``data``;
          otherwise an `~astropy.nddata.utils.PartialOverlapError` is
          raised.

        In every mode, a cutout that does not overlap ``data`` at all
        raises an `~astropy.nddata.utils.NoOverlapError`.

    fill_value : float or int, optional
        Only used with ``mode='partial'``: the value assigned to cutout
        elements that do not overlap ``data``. It must have the same
        ``dtype`` as the input ``data`` array.

    copy : bool, optional
        If `True`, the cutout data is a copy of the original ``data``.
        If `False` (default), no explicit copy is made and the cutout
        data can be a view into ``data``.

    Notes
    -----
    With ``mode='partial'`` and ``fill_value=np.nan`` (the default
    fill value), a cutout that is not fully contained within ``data``
    requires ``data`` to have a float data type.

    Examples
    --------
    >>> import numpy as np
    >>> from photutils.utils import CutoutImage
    >>> data = np.arange(20.0).reshape(5, 4)
    >>> cutout = CutoutImage(data, (2, 2), (3, 3))
    >>> print(cutout.data)  # doctest: +FLOAT_CMP
    [[ 5.  6.  7.]
     [ 9. 10. 11.]
     [13. 14. 15.]]

    >>> cutout2 = CutoutImage(data, (0, 0), (3, 3), mode='partial')
    >>> print(cutout2.data)  # doctest: +FLOAT_CMP
    [[nan nan nan]
     [nan  0.  1.]
     [nan  4.  5.]]
    """

    @deprecated_positional_kwargs(since='3.0', until='4.0')
    def __init__(self, data, position, shape, mode='trim', fill_value=np.nan,
                 copy=False):
        # Keep the request as given: ``input_shape`` is the requested
        # shape, while ``shape`` (assigned last) is the shape actually
        # produced, which can be smaller in 'trim' mode.
        self.position = position
        self.input_shape = tuple(shape)
        self.mode = mode
        self.fill_value = fill_value
        self.copy = copy

        array = np.asanyarray(data)
        self._overlap_slices = overlap_slices(array.shape, shape, position,
                                              mode=mode)
        self.data = self._make_cutout(array)
        self.shape = self.data.shape

    def _make_cutout(self, data):
        """
        Extract the cutout array from the input data.

        Parameters
        ----------
        data : `~numpy.ndarray`
            The 2D data array from which to extract the cutout array.

        Returns
        -------
        cutout_data : `~numpy.ndarray`
            The 2D cutout data array, copied when ``self.copy`` is
            `True`.
        """
        extracted = extract_array(data, self.input_shape, self.position,
                                  mode=self.mode,
                                  fill_value=self.fill_value,
                                  return_position=False)
        if not self.copy:
            return extracted
        return np.copy(extracted)

    # NumPy calls `obj.__array__(dtype)` positionally with
    # `np.asarray(obj, dtype=int)`, so dtype must remain a positional
    # argument.
    def __array__(self, dtype=None, *, copy=None):
        """
        Return the cutout data as a NumPy array (e.g., for matplotlib).

        Parameters
        ----------
        dtype : `~numpy.dtype`, optional
            The data type of the returned array. If `None`, the data
            type of the cutout data array is kept.

        copy : bool or `None`, optional
            Forwarded unchanged to `numpy.array`; `True` requests a
            copy of the underlying cutout data.

        Returns
        -------
        array : `~numpy.ndarray`
            The cutout data.
        """
        return np.array(self.data, copy=copy, dtype=dtype)

    def __str__(self):
        cls = type(self)
        lines = (f'<{cls.__module__}.{cls.__name__}>',
                 f'Shape: {self.data.shape}')
        return '\n'.join(lines)

    def __repr__(self):
        name = type(self).__name__
        return f'{name}(position={self.position}, shape={self.shape})'

    @lazyproperty
    def slices_original(self):
        """
        A tuple of slice objects, in axis order, giving the minimal
        bounding box of the cutout in the original array.

        For ``mode='partial'``, the slices cover only the valid
        (non-filled) cutout values.
        """
        slices_large, _ = self._overlap_slices
        return slices_large

    @lazyproperty
    def slices_cutout(self):
        """
        A tuple of slice objects, in axis order, giving the minimal
        bounding box of the cutout in the cutout array.

        For ``mode='partial'``, the slices cover only the valid
        (non-filled) cutout values.
        """
        _, slices_small = self._overlap_slices
        return slices_small

    def _calc_bbox(self, slices):
        """
        Build the `~photutils.aperture.BoundingBox` described by a pair
        of ``(y, x)`` slices.

        Parameters
        ----------
        slices : tuple of slice
            The ``(y, x)`` slices for the bounding box.

        Returns
        -------
        bbox : `~photutils.aperture.BoundingBox`
            The bounding box spanned by the slice start/stop values.
        """
        # Imported here rather than at module level to prevent a
        # circular import
        from photutils.aperture import BoundingBox

        yslice, xslice = slices[0], slices[1]
        return BoundingBox(ixmin=xslice.start, ixmax=xslice.stop,
                           iymin=yslice.start, iymax=yslice.stop)

    @lazyproperty
    def bbox_original(self):
        """
        The `~photutils.aperture.BoundingBox` of the minimal rectangular
        region of the cutout, in the coordinates of the original array.

        For ``mode='partial'``, the bounding box covers only the valid
        (non-filled) cutout values.
        """
        return self._calc_bbox(self.slices_original)

    @lazyproperty
    def bbox_cutout(self):
        """
        The `~photutils.aperture.BoundingBox` of the minimal rectangular
        region of the cutout, in the coordinates of the cutout array.

        For ``mode='partial'``, the bounding box covers only the valid
        (non-filled) cutout values.
        """
        return self._calc_bbox(self.slices_cutout)

    def _calc_xyorigin(self, slices):
        """
        Compute the ``(x, y)`` origin of the cutout in the original
        array, accounting for partial overlaps.

        Parameters
        ----------
        slices : tuple of slice
            The ``(y, x)`` slices of the overlap region in the original
            array.

        Returns
        -------
        xyorigin : `~numpy.ndarray`
            The ``(x, y)`` integer index of the origin pixel of the
            cutout with respect to the original array.
        """
        # In 'partial' mode the cutout keeps filled rows/columns before
        # the overlap region, so step back by the cutout-side start of
        # that region to reach the cutout's true corner.
        if self.mode == 'partial':
            xshift = self.slices_cutout[1].start
            yshift = self.slices_cutout[0].start
        else:
            xshift, yshift = 0, 0
        return np.array((slices[1].start - xshift,
                         slices[0].start - yshift))

    @lazyproperty
    def xyorigin(self):
        """
        A `~numpy.ndarray` containing the ``(x, y)`` integer index of
        the origin pixel of the cutout with respect to the original
        array.

        With ``mode='partial'``, the origin index can be negative when
        the cutout extends beyond the lower edge(s) of the original
        array.
        """
        return self._calc_xyorigin(self.slices_original)
def _make_cutouts(data, xpos, ypos, cutout_shape, *, fill_value=0.0):
    """
    Make 2D cutouts from a data array at the given positions.

    Positions are rounded to the nearest integer pixel with
    `numpy.round` (values exactly halfway between two integers round
    to the nearest even integer). The rounded position maps to index
    ``(ny // 2, nx // 2)`` of each cutout. Pixels that fall outside the
    image boundary are filled with ``fill_value``.

    Parameters
    ----------
    data : 2D `~numpy.ndarray`
        The 2D image array.

    xpos : 1D `~numpy.ndarray`
        The x pixel positions of the cutout centers.

    ypos : 1D `~numpy.ndarray`
        The y pixel positions of the cutout centers.

    cutout_shape : tuple of int
        The ``(ny, nx)`` shape of each cutout.

    fill_value : float, optional
        The value used to fill pixels that fall outside the image
        boundary. The default is 0.0. Use ``np.nan`` when out-of-bounds
        pixels must be distinguishable from real data (e.g., for
        sigma-clipped statistics on partial cutouts).

    Returns
    -------
    cutouts : 3D `~numpy.ndarray`
        A 3D array of shape ``(n_sources, ny, nx)`` containing the
        cutout data.

    overlap_mask : 3D `~numpy.ndarray` of bool
        A boolean array with the same shape as ``cutouts``. `True`
        indicates a pixel that came from ``data``. `False` indicates
        a pixel that was filled with ``fill_value`` because it fell
        outside the image boundary.

        Per-source overlap status can be derived from this mask:

        * Fully inside the image: ``overlap_mask[i].all()``
        * No overlap (entirely outside): ``~overlap_mask[i].any()``
        * Partial overlap: neither of the above

    Raises
    ------
    ValueError
        If ``data`` is not 2D, ``xpos``/``ypos`` are not 1D or differ
        in length, or ``cutout_shape`` does not have exactly 2
        elements.

    Notes
    -----
    A source whose ``x`` or ``y`` position is non-finite (NaN or
    infinity) is treated as not overlapping the image. Its cutout is
    entirely ``fill_value`` and its ``overlap_mask`` is all `False`.
    An image with a zero-length axis likewise yields cutouts with no
    overlap.
    """
    data = np.asarray(data)
    if data.ndim != 2:
        msg = 'data must be a 2D array'
        raise ValueError(msg)

    xpos = np.atleast_1d(np.asarray(xpos))
    ypos = np.atleast_1d(np.asarray(ypos))
    if xpos.ndim != 1 or ypos.ndim != 1:
        msg = 'xpos and ypos must be 1D arrays'
        raise ValueError(msg)
    if len(xpos) != len(ypos):
        msg = 'xpos and ypos must have the same length'
        raise ValueError(msg)

    if len(cutout_shape) != 2:
        msg = 'cutout_shape must have exactly 2 elements'
        raise ValueError(msg)

    ky, kx = cutout_shape
    hy, hx = ky // 2, kx // 2
    ny, nx = data.shape

    # Casting NaN/inf to int is undefined. NumPy emits a RuntimeWarning
    # and the resulting integer is platform dependent; for example, NaN
    # becomes 0 on AArch64, which would silently select real pixels.
    # Use a placeholder center for non-finite positions and mask those
    # sources out explicitly below.
    finite = np.isfinite(xpos) & np.isfinite(ypos)
    yc = np.round(np.where(finite, ypos, 0)).astype(int)
    xc = np.round(np.where(finite, xpos, 0)).astype(int)

    # Build index grids that broadcast to shape (n_sources, ky, kx)
    dy = np.arange(ky) - hy
    dx = np.arange(kx) - hx
    y_idx = yc[:, np.newaxis, np.newaxis] + dy[np.newaxis, :, np.newaxis]
    x_idx = xc[:, np.newaxis, np.newaxis] + dx[np.newaxis, np.newaxis, :]

    # A pixel overlaps only if it lies inside the image boundary and its
    # source has a finite position
    overlap_mask = ((y_idx >= 0) & (y_idx < ny)
                    & (x_idx >= 0) & (x_idx < nx)
                    & finite[:, np.newaxis, np.newaxis])

    if data.size == 0:
        # There is no pixel to sample. Clipping to ``shape - 1`` would
        # give index -1 and raise IndexError on the empty axis.
        # overlap_mask is all False here, so the placeholder values
        # below are never used.
        sampled = np.zeros(overlap_mask.shape, dtype=data.dtype)
    else:
        # Clip out-of-bounds indices to the valid range so that fancy
        # indexing does not raise. Those pixels are replaced by
        # fill_value below.
        y_safe = np.clip(y_idx, 0, ny - 1)
        x_safe = np.clip(x_idx, 0, nx - 1)
        sampled = data[y_safe, x_safe]

    cutouts = np.where(overlap_mask, sampled, fill_value)
    return cutouts, overlap_mask