# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
This module defines the GriddedPSFModel and related tools.
"""
import copy
import io
import itertools
import os
import warnings
from functools import lru_cache
import astropy
import numpy as np
from astropy.io import fits, registry
from astropy.io.fits.verify import VerifyWarning
from astropy.modeling import Fittable2DModel, Parameter
from astropy.nddata import NDData, reshape_as_blocks
from astropy.utils import minversion
from astropy.visualization import simple_norm
from photutils.utils._parameters import as_pair
# Names exported by ``from <this module> import *``.
__all__ = ['GriddedPSFModel', 'ModelGridPlotMixin', 'stdpsf_reader',
           'webbpsf_reader', 'STDPSFGrid']
# NOTE(review): presumably consumed by the doctest plugin to skip the
# examples in these objects' docstrings (they reference FITS files
# that are not available at test time) -- confirm against the test
# configuration.
__doctest_skip__ = ['GriddedPSFModelRead', 'STDPSFGrid']
[docs]
class ModelGridPlotMixin:
    """
    Mixin class to plot a grid of ePSF models.

    The methods below read the following attributes from the host
    class: ``data`` (3D array of ePSFs), ``meta`` (`dict`), and the 1D
    ``_xgrid`` and ``_ygrid`` arrays holding the unique grid
    coordinates.
    """

    def _reshape_grid(self, data):
        """
        Reshape the 3D ePSF grid as a 2D array of horizontally and
        vertically stacked ePSFs.

        Parameters
        ----------
        data : 3D `~numpy.ndarray`
            Stack of ePSFs with ``len(_ygrid) * len(_xgrid)`` planes,
            each with the same 2D shape as the planes of ``self.data``.

        Returns
        -------
        result : 2D `~numpy.ndarray`
            The mosaic of ePSFs. The input array is left untouched.
        """
        nypsfs = self._ygrid.shape[0]
        nxpsfs = self._xgrid.shape[0]
        ny, nx = self.data.shape[1:]

        # reshape() is used instead of assigning to ``data.shape`` so
        # that the caller's array is never modified in place
        grid = data.reshape(nypsfs, nxpsfs, ny, nx)
        return grid.transpose([0, 2, 1, 3]).reshape(nypsfs * ny, nxpsfs * nx)

    @staticmethod
    def _format_title_prefix(meta):
        """
        Build the "<instrument> <detector> <filter> " plot-title prefix.

        Parameters
        ----------
        meta : dict
            The metadata dictionary. The ``'instrument'`` (or, failing
            that, ``'instrume'``), ``'detector'``, and ``'filter'``
            entries are used when present.

        Returns
        -------
        prefix : str
            The non-empty entries joined by single spaces and followed
            by one trailing space, or an empty string if none of the
            entries are available.
        """
        instrument = meta.get('instrument', '')
        if not instrument:
            # WebbPSF output
            instrument = meta.get('instrume', '')

        parts = []
        for value in (instrument, meta.get('detector', ''),
                      meta.get('filter', '')):
            # WebbPSF outputs a tuple with the comment in the second
            # element
            if isinstance(value, (tuple, list, np.ndarray)):
                value = value[0]
            text = f'{value}'
            if text:
                parts.append(text)

        if not parts:
            return ''
        return ' '.join(parts) + ' '

    def plot_grid(self, *, ax=None, vmax_scale=None, peak_norm=False,
                  deltas=False, cmap=None, dividers=True,
                  divider_color='darkgray', divider_ls='-', figsize=None):
        """
        Plot the grid of ePSF models.

        Parameters
        ----------
        ax : `matplotlib.axes.Axes` or `None`, optional
            The matplotlib axes on which to plot. If `None`, then a new
            figure and axes are created.

        vmax_scale : float, optional
            Scale factor to apply to the image stretch limits. This
            value is multiplied by the peak ePSF value to determine the
            plotting ``vmax``. The defaults are 1.0 for plotting the
            ePSF data and 0.03 for plotting the ePSF difference data
            (``deltas=True``). If ``deltas=True``, the ``vmin`` is set
            to ``-vmax``. If ``deltas=False`` the ``vmin`` is set to
            ``vmax`` / 1e4.

        peak_norm : bool, optional
            Whether to normalize the ePSF data by the peak value. The
            default shows the ePSF flux per pixel.

        deltas : bool, optional
            Set to `True` to show the differences between each ePSF
            and the average ePSF.

        cmap : str or `matplotlib.colors.Colormap`, optional
            The colormap to use. The default is `None`, which uses
            the 'viridis' colormap for plotting ePSF data and the
            'gray_r' colormap for plotting the ePSF difference data
            (``deltas=True``).

        dividers : bool, optional
            Whether to show divider lines between the ePSFs.

        divider_color, divider_ls : str, optional
            Matplotlib color and linestyle options for the divider
            lines between ePSFs. These keywords have no effect unless
            ``dividers=True``.

        figsize : (float, float), optional
            The figure (width, height) in inches. Used only when
            ``ax`` is `None`.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The matplotlib figure object. This is the newly created
            figure if ``ax=None`` and the current pyplot figure
            otherwise. Use ``fig.savefig()`` to save the figure to a
            file.

            Note that when calling this method in a notebook, if you do
            not store the return value of this function, the figure will
            be displayed twice due to the REPL automatically displaying
            the return value of the last function call. Alternatively,
            you can append a semicolon to the end of the function call
            to suppress the display of the return value.
        """
        import matplotlib.pyplot as plt
        from matplotlib import cm

        is_nrcsw = self.meta.get('detector', '') == 'NRCSW'

        data = self.data.copy()
        if deltas:
            # Compute mean ignoring any blank (all zeros) ePSFs.
            # This is the case for MIRI with its non-square FOV.
            blank = np.all(data == 0, axis=(1, 2))
            data -= np.mean(data[~blank], axis=0)
            data[blank] = 0.0

        data = self._reshape_grid(data)

        if ax is None:
            if figsize is None and is_nrcsw:
                figsize = (20, 8)
            fig, ax = plt.subplots(figsize=figsize)
        else:
            fig = plt.gcf()

        if peak_norm:  # normalize relative to peak
            peak = data.max()
            if peak != 0:
                data /= peak

        if deltas:
            if cmap is None:
                cmap = cm.gray_r.copy()
            if vmax_scale is None:
                vmax_scale = 0.03
            vmax = data.max() * vmax_scale
            vmin = -vmax
            # astropy 6.1 renamed the min_cut/max_cut keywords
            if minversion(astropy, '6.1.dev'):
                norm = simple_norm(data, 'linear', vmin=vmin, vmax=vmax)
            else:
                norm = simple_norm(data, 'linear', min_cut=vmin, max_cut=vmax)
        else:
            if cmap is None:
                cmap = cm.viridis.copy()
            if vmax_scale is None:
                vmax_scale = 1.0
            vmax = data.max() * vmax_scale
            vmin = vmax / 1.0e4
            if minversion(astropy, '6.1.dev'):
                norm = simple_norm(data, 'log', vmin=vmin, vmax=vmax,
                                   log_a=1.0e4)
            else:
                norm = simple_norm(data, 'log', min_cut=vmin, max_cut=vmax,
                                   log_a=1.0e4)

        # Set up the coordinate axes to later set tick labels based on
        # detector ePSF coordinates. This sets up axes to have, behind the
        # scenes, the ePSFs centered at integer coords 0, 1, 2, 3 etc.
        # extent = (left, right, bottom, top)
        nypsfs = self._ygrid.shape[0]
        nxpsfs = self._xgrid.shape[0]
        extent = [-0.5, nxpsfs - 0.5, -0.5, nypsfs - 0.5]
        # keep the returned image so the colorbar is always attached to
        # the image drawn here, even if ``ax`` already holds other images
        image = ax.imshow(data, extent=extent, norm=norm, cmap=cmap,
                          origin='lower')

        # Use the axes set up above to set appropriate tick labels
        xticklabels = self._xgrid.astype(int)
        yticklabels = self._ygrid.astype(int)
        if is_nrcsw:
            xticklabels = list(xticklabels[0:5]) * 4
            yticklabels = list(yticklabels[0:5]) * 2
        ax.set_xticks(np.arange(nxpsfs))
        ax.set_xticklabels(xticklabels)
        ax.set_xlabel('ePSF location in detector X pixels')
        ax.set_yticks(np.arange(nypsfs))
        ax.set_yticklabels(yticklabels)
        ax.set_ylabel('ePSF location in detector Y pixels')

        if dividers:
            for ix in range(nxpsfs):
                ax.axvline(ix + 0.5, color=divider_color, ls=divider_ls)
            for iy in range(nypsfs):
                ax.axhline(iy + 0.5, color=divider_color, ls=divider_ls)

        title = self._format_title_prefix(self.meta)
        if deltas:
            ax.set_title(f'{title}(ePSFs − <ePSF>)')
            if peak_norm:
                label = 'Difference relative to average ePSF peak'
            else:
                label = 'Difference relative to average ePSF values'
        else:
            ax.set_title(f'{title}ePSFs')
            if peak_norm:
                label = 'Scale relative to ePSF peak pixel'
            else:
                label = 'ePSF flux per pixel'

        cbar = plt.colorbar(label=label, mappable=image)
        if not deltas:
            cbar.ax.set_yscale('log')

        if is_nrcsw:
            # NIRCam NRCSW STDPSF files contain all detectors.
            # The plot gets extra divider lines and SCA name labels.
            # Draw on ``ax`` (not the pyplot current axes) so the lines
            # land on the axes that holds the image.
            ax.axhline(nypsfs / 2 - 0.5, color='orange')
            for i in range(1, 4):
                ax.axvline(nxpsfs / 4 * i - 0.5, color='orange')

            det_labels = [['A1', 'A3', 'B4', 'B2'], ['A2', 'A4', 'B3', 'B1']]
            for i in range(2):
                for j in range(4):
                    ax.text(j * nxpsfs / 4 - 0.45,
                            (i + 1) * nypsfs / 2 - 0.55,
                            det_labels[i][j], color='orange',
                            verticalalignment='top', fontsize=12)

        fig.tight_layout()

        return fig
class GriddedPSFModelRead(registry.UnifiedReadWrite):
    """
    Read and parse a FITS file into a `GriddedPSFModel` instance.

    This class enables the astropy unified I/O layer for
    `GriddedPSFModel`. This allows easily reading a file in different
    supported data formats using syntax such as::

      >>> from photutils.psf import GriddedPSFModel
      >>> psf_model = GriddedPSFModel.read('filename.fits', format=format)

    Get help on the available readers for `GriddedPSFModel` using the
    ``help()`` method::

      >>> # Get help reading a GriddedPSFModel and list supported formats
      >>> GriddedPSFModel.read.help()

      >>> # Get detailed help on the STDPSF FITS reader
      >>> GriddedPSFModel.read.help('stdpsf')

      >>> # Get detailed help on the WebbPSF FITS reader
      >>> GriddedPSFModel.read.help('webbpsf')

      >>> # Print list of available formats
      >>> GriddedPSFModel.read.list_formats()

    Parameters
    ----------
    *args : tuple, optional
        Positional arguments passed through to data reader. If supplied
        the first argument is typically the input filename.

    format : str
        File format specifier.

    **kwargs : dict, optional
        Keyword arguments passed through to data reader.

    Returns
    -------
    out : `~photutils.psf.GriddedPSFModel`
        A gridded ePSF model corresponding to FITS file contents.
    """

    def __init__(self, instance, cls):
        # registry=None selects the default global registry
        super().__init__(instance, cls, 'read', registry=None)

    def __call__(self, *args, **kwargs):
        # delegate to the registry, which dispatches on the format
        read = self.registry.read
        return read(self._cls, *args, **kwargs)
[docs]
class GriddedPSFModel(ModelGridPlotMixin, Fittable2DModel):
    """
    A fittable 2D model containing a grid of ePSF models.

    The ePSF models are defined at fiducial detector locations and are
    bilinearly interpolated to calculate an ePSF model at an arbitrary
    (x, y) detector position.

    When evaluating this model, it cannot be called with x and y arrays
    that have greater than 2 dimensions.

    Parameters
    ----------
    nddata : `~astropy.nddata.NDData`
        A `~astropy.nddata.NDData` object containing the grid of
        reference ePSF arrays. The data attribute must contain a 3D
        `~numpy.ndarray` containing a stack of the 2D ePSFs with a shape
        of ``(N_psf, ePSF_ny, ePSF_nx)``.

        The meta attribute must be `dict` containing the following:

        * ``'grid_xypos'``: A list of the (x, y) grid positions of
          each reference ePSF. The order of positions should match the
          first axis of the 3D `~numpy.ndarray` of ePSFs. In other
          words, ``grid_xypos[i]`` should be the (x, y) position of
          the reference ePSF defined in ``data[i]``.

        * ``'oversampling'``: The integer oversampling factor(s) of
          the ePSF. If ``oversampling`` is a scalar then it will be
          used for both axes. If ``oversampling`` has two elements,
          they must be in ``(y, x)`` order.

        The meta attribute may contain other properties such as the
        telescope, instrument, detector, and filter of the ePSF.

    flux : float, optional
        Initial value of the ``flux`` model parameter.

    x_0, y_0 : float, optional
        Initial values of the position model parameters.

    fill_value : float or `None`, optional
        Value assigned to evaluated pixels that fall outside the ePSF
        array. If `None`, those pixels keep the value returned by the
        spline interpolator.

    Methods
    -------
    read(\\*args, \\**kwargs)
        Class method to create a `GriddedPSFModel` instance from a
        FITS file through the astropy unified I/O registry (see
        `GriddedPSFModelRead`).

    Notes
    -----
    Internally, the grid of ePSFs will be arranged and stored such that
    it is sorted first by y and then by x.
    """

    flux = Parameter(description='Intensity scaling factor for the ePSF '
                     'model.', default=1.0)
    x_0 = Parameter(description='x position in the output coordinate grid '
                    'where the model is evaluated.', default=0.0)
    y_0 = Parameter(description='y position in the output coordinate grid '
                    'where the model is evaluated.', default=0.0)

    read = registry.UnifiedReadWriteMethod(GriddedPSFModelRead)

    def __init__(self, nddata, *, flux=flux.default, x_0=x_0.default,
                 y_0=y_0.default, fill_value=0.0):

        self._validate_data(nddata)
        self.data, self.grid_xypos = self._define_grid(nddata)
        # use _meta to avoid the meta descriptor
        self._meta = nddata.meta.copy()
        self.oversampling = as_pair('oversampling',
                                    nddata.meta['oversampling'],
                                    lower_bound=(0, 1))
        self.fill_value = fill_value

        self._grid_xpos, self._grid_ypos = np.transpose(self.grid_xypos)
        self._xgrid = np.unique(self._grid_xpos)  # also sorts values
        self._ygrid = np.unique(self._grid_ypos)  # also sorts values
        self.meta['grid_shape'] = (len(self._ygrid), len(self._xgrid))

        # a regular grid has exactly one ePSF per (x, y) combination
        if len(self._xgrid) * len(self._ygrid) != len(self.grid_xypos):
            raise ValueError('"grid_xypos" must form a regular grid.')

        self._xidx = np.arange(self.data.shape[2], dtype=float)
        self._yidx = np.arange(self.data.shape[1], dtype=float)

        # Here we avoid decorating the instance method with @lru_cache
        # to prevent memory leaks; we set maxsize=128 to prevent the
        # cache from growing too large.
        self._calc_interpolator = lru_cache(maxsize=128)(
            self._calc_interpolator_uncached)

        super().__init__(flux, x_0, y_0)

    @staticmethod
    def _validate_data(data):
        """
        Check that ``data`` is an NDData object holding a 3D array and
        the required ``'grid_xypos'`` and ``'oversampling'`` metadata.

        Raises
        ------
        TypeError
            If ``data`` is not an `~astropy.nddata.NDData` instance.

        ValueError
            If the data array is not 3D, if a required meta key is
            missing, or if the number of grid positions differs from
            the number of ePSFs.
        """
        if not isinstance(data, NDData):
            raise TypeError('data must be an NDData instance.')

        if data.data.ndim != 3:
            raise ValueError('The NDData data attribute must be a 3D numpy '
                             'ndarray')

        if 'grid_xypos' not in data.meta:
            raise ValueError('"grid_xypos" must be in the nddata meta '
                             'dictionary.')
        if len(data.meta['grid_xypos']) != data.data.shape[0]:
            raise ValueError('The length of grid_xypos must match the number '
                             'of input ePSFs.')

        if 'oversampling' not in data.meta:
            raise ValueError('"oversampling" must be in the nddata meta '
                             'dictionary.')

    def _define_grid(self, nddata):
        """
        Sort the input ePSF data into a regular grid where the ePSFs are
        sorted first by y and then by x.

        Parameters
        ----------
        nddata : `~astropy.nddata.NDData`
            The input NDData object containing the ePSF data.

        Returns
        -------
        data : 3D `~numpy.ndarray`
            The 3D array of ePSFs.

        grid_xypos : array of (x, y) pairs
            The (x, y) positions of the ePSFs, sorted first by y and
            then by x.
        """
        grid_xypos = np.array(nddata.meta['grid_xypos'])
        # sort by y and then by x (lexsort uses the last key as the
        # primary sort key)
        idx = np.lexsort((grid_xypos[:, 0], grid_xypos[:, 1]))
        grid_xypos = grid_xypos[idx]
        data = nddata.data[idx]

        return data, grid_xypos

    def _cls_info(self):
        """
        Return a list of (label, value) pairs summarizing this model,
        used by ``__str__``.
        """
        cls_info = []

        keys = ('STDPSF', 'instrument', 'detector', 'filter', 'grid_shape')
        for key in keys:
            if key in self.meta:
                name = key.capitalize() if key != 'STDPSF' else key
                cls_info.append((name, self.meta[key]))

        cls_info.extend([('Number of ePSFs', len(self.grid_xypos)),
                         ('ePSF shape (oversampled pixels)',
                          self.data.shape[1:]),
                         ('Oversampling', tuple(self.oversampling))])
        return cls_info

    def __str__(self):
        return self._format_str(keywords=self._cls_info())

    def copy(self):
        """
        Return a copy of this model where only the model parameters are
        copied.

        All other copied model attributes are references to the
        original model. This prevents copying the ePSF grid data, which
        may contain a large array.

        This method is useful if one is interested in only changing
        the model parameters in a model copy. It is used in the PSF
        photometry classes during model fitting.

        Use the `deepcopy` method if you want to copy all of the model
        attributes, including the ePSF grid data.
        """
        newcls = object.__new__(self.__class__)

        for key, val in self.__dict__.items():
            if key in self.param_names:  # copy only the parameter values
                newcls.__dict__[key] = copy.copy(val)
            else:
                newcls.__dict__[key] = val

        return newcls

    def deepcopy(self):
        """
        Return a deep copy of this model.

        The copy gets its own (empty) interpolator cache bound to the
        copied ePSF data.
        """
        new_model = copy.deepcopy(self)
        # An lru_cache wrapper is not duplicated by copy.deepcopy, so
        # without rebinding the copy would keep calling (and clearing)
        # the original model's cached interpolators.
        new_model._calc_interpolator = lru_cache(maxsize=128)(
            new_model._calc_interpolator_uncached)
        return new_model

    def clear_cache(self):
        """
        Clear the internal cache.
        """
        self._calc_interpolator.cache_clear()

    def _cache_info(self):
        """
        Return information about the internal cache.
        """
        return self._calc_interpolator.cache_info()

    @staticmethod
    def _find_start_idx(data, x):
        """
        Find the index of the lower bound where ``x`` should be inserted
        into ``data`` to maintain order.

        The caller takes the two bounding values as
        ``data[index:index + 2]``, so the returned index is chosen such
        that this slice stays inside the array whenever ``data`` has at
        least two elements.

        Parameters
        ----------
        data : 1D `~numpy.ndarray`
            The 1D sorted array to search.

        x : float
            The value to insert.

        Returns
        -------
        index : int
            The index of the lower bound. It is never negative.
        """
        idx = np.searchsorted(data, x)
        if idx == 0:
            idx0 = 0
        elif idx == len(data):  # pragma: no cover
            # clamp so a single-element array cannot yield -1
            idx0 = max(idx - 2, 0)
        else:
            idx0 = idx - 1

        return idx0

    def _find_bounding_points(self, x, y):
        """
        Find the indices of the grid points that bound the input
        ``(x, y)`` position.

        Parameters
        ----------
        x, y : float
            The ``(x, y)`` position where the ePSF is to be evaluated.
            The position must be inside the region defined by the grid
            of ePSF positions.

        Returns
        -------
        indices : list of int
            A list of indices (into ``grid_xypos``) of the bounding
            grid points. There are four of them unless the grid has
            only a single row and/or column.
        """
        x0 = self._find_start_idx(self._xgrid, x)
        y0 = self._find_start_idx(self._ygrid, y)
        xypoints = itertools.product(self._xgrid[x0:x0 + 2],
                                     self._ygrid[y0:y0 + 2])

        # find the grid_xypos indices of the reference xypoints
        return [np.argmin(np.hypot(self._grid_xpos - xx,
                                   self._grid_ypos - yy))
                for xx, yy in xypoints]

    @staticmethod
    def _bilinear_interp(xyref, zref, xi, yi):
        """
        Perform bilinear interpolation of four 2D arrays located at
        points on a regular grid.

        Parameters
        ----------
        xyref : list of 4 (x, y) pairs
            A list of 4 ``(x, y)`` pairs that form a rectangle.

        zref : 3D `~numpy.ndarray`
            A 3D `~numpy.ndarray` whose first axis has length 4 and
            corresponds to ``xyref``, i.e., ``zref[0, :, :]`` is the 2D
            array located at ``xyref[0]``.

        xi, yi : float
            The ``(xi, yi)`` point at which to perform the
            interpolation. The ``(xi, yi)`` point must lie within the
            rectangle defined by ``xyref``.

        Returns
        -------
        result : 2D `~numpy.ndarray`
            The 2D interpolated array.

        Raises
        ------
        ValueError
            If ``xyref`` does not form a rectangle or ``(xi, yi)`` lies
            outside of it.
        """
        xyref = [tuple(i) for i in xyref]
        # remember the permutation so zref can be reordered identically
        idx = sorted(range(len(xyref)), key=xyref.__getitem__)
        xyref = sorted(xyref)  # sort by x, then y
        (x0, y0), (_x0, y1), (x1, _y0), (_x1, _y1) = xyref

        if x0 != _x0 or x1 != _x1 or y0 != _y0 or y1 != _y1:
            raise ValueError('The refxy points do not form a rectangle.')

        if not np.isscalar(xi):
            xi = xi[0]
        if not np.isscalar(yi):
            yi = yi[0]

        if not x0 <= xi <= x1 or not y0 <= yi <= y1:
            raise ValueError('The (x, y) input is not within the rectangle '
                             'defined by xyref.')

        data = np.asarray(zref)[idx]
        weights = np.array([(x1 - xi) * (y1 - yi), (x1 - xi) * (yi - y0),
                            (xi - x0) * (y1 - yi), (xi - x0) * (yi - y0)])
        norm = (x1 - x0) * (y1 - y0)

        return np.sum(data * weights[:, None, None], axis=0) / norm

    def _calc_interpolator_uncached(self, x_0, y_0):
        """
        Return the local interpolation function for the ePSF model at
        (x_0, y_0).

        Note that the interpolator will be cached by _calc_interpolator.
        It can be cleared by calling the clear_cache method.
        """
        from scipy.interpolate import RectBivariateSpline

        ref_indices = None
        inside = (self._xgrid[0] <= x_0 <= self._xgrid[-1]
                  and self._ygrid[0] <= y_0 <= self._ygrid[-1])
        if inside:
            ref_indices = self._find_bounding_points(x_0, y_0)
            if len(ref_indices) != 4:
                # the grid has a single row and/or column, so there is
                # no rectangle to interpolate within
                ref_indices = None

        if ref_indices is None:
            # position is outside of the grid (or the grid is
            # degenerate), so simply use the closest reference ePSF
            ref_index = np.argmin(np.hypot(self._grid_xpos - x_0,
                                           self._grid_ypos - y_0))
            psf_image = self.data[ref_index, :, :]
        else:
            # interpolate between the four bounding reference ePSFs
            xyref = self.grid_xypos[ref_indices]
            psfs = self.data[ref_indices, :, :]
            psf_image = self._bilinear_interp(xyref, psfs, x_0, y_0)

        interpolator = RectBivariateSpline(self._xidx, self._yidx,
                                           psf_image.T, kx=3, ky=3, s=0)

        return interpolator

    def evaluate(self, x, y, flux, x_0, y_0):
        """
        Evaluate the `GriddedPSFModel` for the input parameters.

        Parameters
        ----------
        x, y : `~numpy.ndarray`
            The coordinates at which to evaluate the model. They must
            not have more than 2 dimensions.

        flux, x_0, y_0 : float or size-1 array
            The model parameter values.

        Returns
        -------
        evaluated_model : `~numpy.ndarray`
            The model values at the input coordinates.

        Raises
        ------
        ValueError
            If ``x`` has more than 2 dimensions.
        """
        if x.ndim > 2:
            raise ValueError('x and y must be 1D or 2D.')

        # NOTE: the astropy base Model.__call__() method converts scalar
        # inputs to size-1 arrays before calling evaluate().
        if not np.isscalar(flux):
            flux = flux[0]
        if not np.isscalar(x_0):
            x_0 = x_0[0]
        if not np.isscalar(y_0):
            y_0 = y_0[0]

        # Calculate the local interpolation function for the ePSF at
        # (x_0, y_0). Only the integer part of the position is input in
        # order to have effective caching.
        interpolator = self._calc_interpolator(int(x_0), int(y_0))

        # now evaluate the ePSF at the (x_0, y_0) subpixel position on
        # the input (x, y) values
        xi = self.oversampling[1] * (np.asarray(x, dtype=float) - x_0)
        yi = self.oversampling[0] * (np.asarray(y, dtype=float) - y_0)

        # define origin at the ePSF image center
        ny, nx = self.data.shape[1:]
        xi += (nx - 1) / 2
        yi += (ny - 1) / 2

        evaluated_model = flux * interpolator.ev(xi, yi)

        if self.fill_value is not None:
            # find indices of pixels that are outside the input pixel
            # grid and set these pixels to the fill_value
            invalid = (((xi < 0) | (xi > nx - 1))
                       | ((yi < 0) | (yi > ny - 1)))
            evaluated_model[invalid] = self.fill_value

        return evaluated_model
def _read_stdpsf(filename):
    """
    Read the ePSF stack and grid layout from a STDPSF FITS file.

    Parameters
    ----------
    filename : str
        The STDPSF FITS file to open (passed to ``fits.open``).

    Returns
    -------
    grid_data : dict
        Dictionary with the keys ``'data'`` (primary HDU data),
        ``'npsfs'``, ``'nxpsfs'``, ``'nypsfs'`` (from the ``NAXIS3``,
        ``NXPSFS``, and ``NYPSFS`` header keywords), and ``'xgrid'``
        and ``'ygrid'`` (0-indexed grid coordinates as arrays).

    Raises
    ------
    ValueError
        If the required header keywords are missing or the grid
        position keywords are not in a recognized layout.
    """
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', VerifyWarning)
        with fits.open(filename, ignore_missing_end=True) as hdulist:
            primary = hdulist[0]
            header = primary.header
            data = primary.data

    try:
        npsfs, nxpsfs, nypsfs = (header[key]
                                 for key in ('NAXIS3', 'NXPSFS', 'NYPSFS'))
    except KeyError as exc:
        raise ValueError('Invalid STDPDF FITS file.') from exc

    if 'IPSFX01' in header:
        # one header keyword per grid coordinate
        xgrid = [header[f'IPSFX{num:02d}'] for num in range(1, nxpsfs + 1)]
        ygrid = [header[f'JPSFY{num:02d}'] for num in range(1, nypsfs + 1)]
    elif 'IPSFXA5' in header:
        # each keyword holds several whitespace-separated coordinates
        xgrid = [int(token)
                 for key in ('IPSFXA5', 'IPSFXB5', 'IPSFXC5', 'IPSFXD5')
                 for token in header[key].split()]
        ygrid = [int(token)
                 for key in ('JPSFYA5', 'JPSFYB5')
                 for token in header[key].split()]
    else:
        raise ValueError('Unknown STDPSF FITS file.')

    # STDPSF FITS positions are 1-indexed
    xgrid = np.array(xgrid) - 1
    ygrid = np.array(ygrid) - 1

    # Known grid shapes, as (nypsfs, nxpsfs):
    #   (6, 6)    WFPC2, 4 det
    #   (1, 1)    ACS/HRC
    #   (10, 9)   ACS/WFC, 2 det
    #   (3, 3)    WFC3/IR
    #   (8, 7)    WFC3/UVIS, 2 det
    #   (5, 5)    NIRISS
    #   (5, 5)    NIRCam SW
    #   (10, 20)  NIRCam SW (NRCSW), 8 det
    #   (5, 5)    NIRCam LW
    #   (3, 3)    MIRI

    return {'data': data,
            'npsfs': npsfs,
            'nxpsfs': nxpsfs,
            'nypsfs': nypsfs,
            'xgrid': xgrid,
            'ygrid': ygrid}
def _split_detectors(grid_data, detector_data, detector_id):
    """
    Split an ePSF array into individual detectors.

    In particular::

        * HST WFPC2 STDPSF file contains 4 detectors
        * HST ACS/WFC STDPSF file contains 2 detectors
        * HST WFC3/UVIS STDPSF file contains 2 detectors
        * JWST NIRCam "NRCSW" STDPSF file contains 8 detectors

    Parameters
    ----------
    grid_data : dict
        The dictionary returned by ``_read_stdpsf``; the ``'data'``,
        ``'npsfs'``, ``'nxpsfs'``, ``'nypsfs'``, ``'xgrid'``, and
        ``'ygrid'`` entries are used.

    detector_data : dict
        The detector layout with keys ``'nxdet'`` and ``'nydet'``
        (number of detectors along x and y), ``'det_size'`` (pixel
        offset between adjacent detectors), and ``'det_map'`` (mapping
        of ``detector_id`` to the detector's position index, counted
        along x first and then y).

    detector_id : int
        The detector to extract; must be a key of ``det_map``.

    Returns
    -------
    data : 3D `~numpy.ndarray`
        The ePSFs belonging to the requested detector.

    xgrid, ygrid : 1D `~numpy.ndarray`
        The grid coordinates of that detector, shifted to the
        detector's own pixel frame.
    """
    data = grid_data['data']
    npsfs = grid_data['npsfs']
    nxpsfs = grid_data['nxpsfs']
    nypsfs = grid_data['nypsfs']
    xgrid = grid_data['xgrid']
    ygrid = grid_data['ygrid']
    nxdet = detector_data['nxdet']
    nydet = detector_data['nydet']
    det_map = detector_data['det_map']
    det_size = detector_data['det_size']

    # index of every ePSF, laid out like the full (y, x) grid
    ii = np.arange(npsfs).reshape((nypsfs, nxpsfs))
    # per-detector grid size
    nxpsfs //= nxdet
    nypsfs //= nydet
    ndet = nxdet * nydet
    # one row of ePSF indices per detector block
    ii = reshape_as_blocks(ii, (nypsfs, nxpsfs))
    ii = ii.reshape(ndet, npsfs // ndet)

    # detector_id -> index
    det_idx = det_map[detector_id]
    data = data[ii[det_idx]]

    # column and row of the detector block; the same slicing is used
    # for both axes so layouts with any number of detector rows work
    xp = det_idx % nxdet
    yp = det_idx // nxdet
    xgrid = xgrid[xp * nxpsfs:(xp + 1) * nxpsfs] - xp * det_size
    ygrid = ygrid[yp * nypsfs:(yp + 1) * nypsfs] - yp * det_size

    return data, xgrid, ygrid
def _split_wfc_uvis(grid_data, detector_id):
if detector_id is None:
raise ValueError('detector_id must be specified for ACS/WFC and '
'WFC3/UVIS ePSFs.')
if detector_id not in (1, 2):
raise ValueError('detector_id must be 1 or 2.')
# ACS/WFC1 and WFC3/UVIS1 chip1 (sci, 2) are above chip2 (sci, 1)
# in y-pixel coordinates
xgrid = grid_data['xgrid']
ygrid = grid_data['ygrid']
ygrid = ygrid.reshape((2, ygrid.shape[0] // 2))[detector_id - 1]
if detector_id == 2:
ygrid -= 2048
npsfs = grid_data['npsfs']
data = grid_data['data']
data_ny, data_nx = data.shape[1:]
data = data.reshape((2, npsfs // 2, data_ny, data_nx))[detector_id - 1]
return data, xgrid, ygrid
def _split_wfpc2(grid_data, detector_id):
if detector_id is None:
raise ValueError('detector_id must be specified for WFPC2 ePSFs')
if detector_id not in range(1, 5):
raise ValueError('detector_id must be between 1 and 4, inclusive')
nxdet = 2
nydet = 2
det_size = 800
# det (exten:idx)
# WF2 (2:2) PC (1:3)
# WF3 (3:0) WF4 (4:1)
det_map = {1: 3, 2: 2, 3: 0, 4: 1}
detector_data = {'nxdet': nxdet,
'nydet': nydet,
'det_size': det_size,
'det_map': det_map}
return _split_detectors(grid_data, detector_data, detector_id)
def _split_nrcsw(grid_data, detector_id):
if detector_id is None:
raise ValueError('detector_id must be specified for NRCSW ePSFs')
if detector_id not in range(1, 9):
raise ValueError('detector_id must be between 1 and 8, inclusive')
nxdet = 4
nydet = 2
det_size = 2048
# det (ext:idx)
# A2 (2:4) A4 (4:5) B3 (7:6) B1 (5:7)
# A1 (1:0) A3 (3:1) B4 (8:2) B2 (6:3)
det_map = {1: 0, 3: 1, 8: 2, 6: 3, 2: 4, 4: 5, 7: 6, 5: 7}
detector_data = {'nxdet': nxdet,
'nydet': nydet,
'det_size': det_size,
'det_map': det_map}
return _split_detectors(grid_data, detector_data, detector_id)
def _get_metadata(filename, detector_id):
"""
Get metadata from the filename and ``detector_id``.
"""
if isinstance(filename, io.FileIO):
filename = filename.name
parts = os.path.basename(filename).strip('.fits').split('_')
if len(parts) not in (3, 4):
return None # filename from astropy download_file
detector, filter_name = parts[1:3]
meta = {'STDPSF': filename,
'detector': detector,
'filter': filter_name}
if detector_id is not None:
detector_map = {'WFPC2': ['HST/WFPC2', 'WFPC2'],
'ACSHRC': ['HST/ACS', 'HRC'],
'ACSWFC': ['HST/ACS', 'WFC'],
'WFC3UV': ['HST/WFC3', 'UVIS'],
'WFC3IR': ['HST/WFC3', 'IR'],
'NRCSW': ['JWST/NIRCam', 'NRCSW'],
'NRCA1': ['JWST/NIRCam', 'A1'],
'NRCA2': ['JWST/NIRCam', 'A2'],
'NRCA3': ['JWST/NIRCam', 'A3'],
'NRCA4': ['JWST/NIRCam', 'A4'],
'NRCB1': ['JWST/NIRCam', 'B1'],
'NRCB2': ['JWST/NIRCam', 'B2'],
'NRCB3': ['JWST/NIRCam', 'B3'],
'NRCB4': ['JWST/NIRCam', 'B4'],
'NRCAL': ['JWST/NIRCam', 'A5'],
'NRCBL': ['JWST/NIRCam', 'B5'],
'NIRISS': ['JWST/NIRISS', 'NIRISS'],
'MIRI': ['JWST/MIRI', 'MIRIM']}
try:
inst_det = detector_map[detector]
except KeyError as exc:
raise ValueError(f'Unknown detector {detector}.') from exc
if inst_det[1] == 'WFPC2':
wfpc2_map = {1: 'PC', 2: 'WF2', 3: 'WF3', 4: 'WF4'}
inst_det[1] = wfpc2_map[detector_id]
if inst_det[1] in ('WFC', 'UVIS'):
chip = 2 if detector_id == 1 else 1
inst_det[1] = f'{inst_det[1]}{chip}'
if inst_det[1] == 'NRCSW':
sw_map = {1: 'A1', 2: 'A2', 3: 'A3', 4: 'A4',
5: 'B1', 6: 'B2', 7: 'B3', 8: 'B4'}
inst_det[1] = sw_map[detector_id]
meta['instrument'] = inst_det[0]
meta['detector'] = inst_det[1]
return meta
[docs]
def stdpsf_reader(filename, detector_id=None):
    """
    Generate a `~photutils.psf.GriddedPSFModel` from a STScI
    standard-format ePSF (STDPSF) FITS file.

    .. note::
        Instead of being used directly, this function is intended to be
        used via the `GriddedPSFModel` ``read`` method, e.g.,
        ``model = GriddedPSFModel.read(filename, format='stdpsf')``.

    STDPSF files are FITS files that contain a 3D array of ePSFs with
    the header detailing where the fiducial ePSFs are located in the
    detector coordinate frame.

    The oversampling factor for STDPSF FITS files is assumed to be 4.

    Parameters
    ----------
    filename : str
        The name of the STDPSF FITS file. A URL can also be used.

    detector_id : `None` or int, optional
        For STDPSF files that contain ePSF grids for multiple detectors,
        one will need to identify the detector for which to extract the
        ePSF grid. This keyword is ignored for STDPSF files that do not
        contain ePSF grids for multiple detectors.

        For WFPC2, the detector value (int) should be:

            - 1: PC, 2: WF2, 3: WF3, 4: WF4

        For ACS/WFC and WFC3/UVIS, the detector value should be:

            - 1: WFC2, UVIS2 (sci, 1)
            - 2: WFC1, UVIS1 (sci, 2)

        Note that for these two instruments, chip 1 (WFC1, UVIS1) is
        above chip 2 (WFC2, UVIS2) in the y direction. However, in the
        FLT FITS files, the (sci, 1) extension corresponds to chip 2
        and the (sci, 2) extension corresponds to chip 1.

        For NIRCam NRCSW files that contain ePSF grids for all 8 SW
        detectors, the detector value should be:

            * 1: A1, 2: A2, 3: A3, 4: A4
            * 5: B1, 6: B2, 7: B3, 8: B4

    Returns
    -------
    model : `~photutils.psf.GriddedPSFModel`
        The gridded ePSF model.
    """
    grid_data = _read_stdpsf(filename)

    # multi-detector files are recognized by their total ePSF count
    npsfs = grid_data['npsfs']
    if npsfs in (90, 56):  # ACS/WFC or WFC3/UVIS data (2 chips)
        data, xgrid, ygrid = _split_wfc_uvis(grid_data, detector_id)
    elif npsfs == 36:  # WFPC2 data (4 chips)
        data, xgrid, ygrid = _split_wfpc2(grid_data, detector_id)
    elif npsfs == 200:  # NIRCam SW data (8 chips)
        data, xgrid, ygrid = _split_nrcsw(grid_data, detector_id)
    else:  # single-detector file; use the grid as is
        data, xgrid, ygrid = (grid_data[key]
                              for key in ('data', 'xgrid', 'ygrid'))

    # itertools.product iterates over the last input first, so y is the
    # slow axis; each pair is stored in (x, y) order
    xy_grid = [(xpos, ypos) for ypos, xpos in itertools.product(ygrid, xgrid)]

    meta = {'grid_xypos': xy_grid,
            'oversampling': 4,  # assumption for STDPSF files
            'nxpsfs': xgrid.shape[0],
            'nypsfs': ygrid.shape[0]}

    # try to get additional metadata from the filename because this
    # information is not currently available in the FITS headers
    extra_meta = _get_metadata(filename, detector_id)
    if extra_meta is not None:
        meta.update(extra_meta)

    return GriddedPSFModel(NDData(data, meta=meta))
[docs]
def webbpsf_reader(filename):
    """
    Generate a `~photutils.psf.GriddedPSFModel` from a WebbPSF
    FITS file containing a PSF grid.

    .. note::
        Instead of being used directly, this function is intended to be
        used via the `GriddedPSFModel` ``read`` method, e.g., ``model =
        GriddedPSFModel.read(filename, format='webbpsf')``.

    The WebbPSF FITS file contains a 3D array of ePSFs with the header
    detailing where the fiducial ePSFs are located in the detector
    coordinate frame.

    Parameters
    ----------
    filename : str
        The name of the WebbPSF FITS file. A URL can also be used.

    Returns
    -------
    model : `~photutils.psf.GriddedPSFModel`
        The gridded ePSF model.

    Raises
    ------
    ValueError
        If the header has no ``DET_YX{}`` keywords or no ``OVERSAMP``
        keyword.
    """
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', VerifyWarning)
        with fits.open(filename, ignore_missing_end=True) as hdulist:
            header = hdulist[0].header
            data = hdulist[0].data

    # Handle the case of only one 2D PSF: the model needs a stack of
    # shape (1, ny, nx). np.atleast_3d would append the new axis at the
    # end, giving (ny, nx, 1), so the axis is added explicitly.
    data = np.asanyarray(data)
    if data.ndim == 2:
        data = data[np.newaxis, :, :]

    if not any('DET_YX' in key for key in header):
        raise ValueError('Invalid WebbPSF FITS file; missing "DET_YX{}" '
                         'header keys.')
    if 'OVERSAMP' not in header:
        raise ValueError('Invalid WebbPSF FITS file; missing "OVERSAMP" '
                         'header key.')

    # convert header to meta dict
    header = header.copy(strip=True)
    header.pop('HISTORY', None)
    header.pop('COMMENT', None)
    header.pop('', None)
    meta = dict(header)
    meta = {key.lower(): meta[key] for key in meta}  # use lower-case keys

    # define grid_xypos from DET_YX{} FITS header keywords; each value
    # is a string of the form "(a, b)" whose two numbers are stored, in
    # that order, as one grid_xypos entry
    xypos = []
    for key in meta:
        if 'det_yx' in key:
            vals = header[key].lstrip('(').rstrip(')').split(',')
            xypos.append((float(vals[0]), float(vals[1])))
    meta['grid_xypos'] = xypos

    if 'oversampling' not in meta:
        meta['oversampling'] = meta['oversamp']

    ndd = NDData(data, meta=meta)

    return GriddedPSFModel(ndd)
def is_stdpsf(origin, filepath, fileobj, *args, **kwargs):
    """
    Determine whether ``filepath`` points to a STDPSF FITS file.

    Only ``filepath`` is inspected; the remaining arguments are
    accepted for the identifier call signature and are not used.

    Parameters
    ----------
    origin, fileobj, *args, **kwargs
        Unused.

    filepath : str or `None`
        Path of the candidate file.

    Returns
    -------
    is_stdpsf : bool
        `True` if ``filepath`` has a FITS file extension and its header
        contains the ``NAXIS3``, ``NXPSFS``, and ``NYPSFS`` keywords;
        `False` otherwise (including when ``filepath`` is `None`).
    """
    if filepath is None:
        return False

    fits_suffixes = ('.fits', '.fits.gz', '.fit', '.fit.gz', '.fts',
                     '.fts.gz')
    if not filepath.lower().endswith(fits_suffixes):
        return False

    with warnings.catch_warnings():
        warnings.simplefilter('ignore', VerifyWarning)
        header = fits.getheader(filepath)

    required = ('NAXIS3', 'NXPSFS', 'NYPSFS')
    return all(key in header for key in required)
def is_webbpsf(origin, filepath, fileobj, *args, **kwargs):
    """
    Determine whether ``filepath`` points to a WebbPSF FITS file.

    Only ``filepath`` is inspected; the remaining arguments are
    accepted for the identifier call signature and are not used.

    Parameters
    ----------
    origin, fileobj, *args, **kwargs
        Unused.

    filepath : str or `None`
        Path of the candidate file.

    Returns
    -------
    is_webbpsf : bool
        `True` if ``filepath`` has a FITS file extension and its header
        contains the ``NAXIS3``, ``OVERSAMP``, and ``DET_YX0``
        keywords; `False` otherwise (including when ``filepath`` is
        `None`).
    """
    if filepath is None:
        return False

    fits_suffixes = ('.fits', '.fits.gz', '.fit', '.fit.gz', '.fts',
                     '.fts.gz')
    if not filepath.lower().endswith(fits_suffixes):
        return False

    with warnings.catch_warnings():
        warnings.simplefilter('ignore', VerifyWarning)
        header = fits.getheader(filepath)

    required = ('NAXIS3', 'OVERSAMP', 'DET_YX0')
    return all(key in header for key in required)
[docs]
class STDPSFGrid(ModelGridPlotMixin):
    """
    Class to read and plot "STDPSF" format ePSF model grids.

    STDPSF files are FITS files that contain a 3D array of ePSFs with
    the header detailing where the fiducial ePSFs are located in the
    detector coordinate frame.

    The oversampling factor for STDPSF FITS files is assumed to be 4.

    Parameters
    ----------
    filename : str
        The name of the STDPSF FITS file. A URL can also be used.

    Examples
    --------
    >>> psfgrid = STDPSFGrid('STDPSF_ACSWFC_F814W.fits')
    >>> fig = psfgrid.plot_grid()
    """

    def __init__(self, filename):
        contents = _read_stdpsf(filename)
        self.data = contents['data']
        self._xgrid = contents['xgrid']
        self._ygrid = contents['ygrid']

        # (x, y) pairs with y as the slowly-varying axis
        self.grid_xypos = [(xpos, ypos) for ypos, xpos
                           in itertools.product(self._ygrid, self._xgrid)]

        oversampling = 4  # assumption for STDPSF files
        self.oversampling = as_pair('oversampling', oversampling,
                                    lower_bound=(0, 1))

        meta = {'grid_shape': (len(self._ygrid), len(self._xgrid)),
                'grid_xypos': self.grid_xypos,
                'oversampling': oversampling}

        # try to get additional metadata from the filename because this
        # information is not currently available in the FITS headers
        extra_meta = _get_metadata(filename, None)
        if extra_meta is not None:
            meta.update(extra_meta)

        self.meta = meta

    def __str__(self):
        cls = self.__class__
        lines = [f'<{cls.__module__}.{cls.__name__}>']

        # metadata entries, shown only when available
        summary = [(key if key == 'STDPSF' else key.capitalize(),
                    self.meta[key])
                   for key in ('STDPSF', 'detector', 'filter', 'grid_shape')
                   if key in self.meta]
        summary.append(('Number of ePSFs', len(self.grid_xypos)))
        summary.append(('ePSF shape (oversampled pixels)',
                        self.data.shape[1:]))
        summary.append(('Oversampling', self.oversampling))

        # abbreviate long arrays while formatting the values
        with np.printoptions(threshold=25, edgeitems=5):
            lines.extend(f'{label}: {value}' for label, value in summary)

        return '\n'.join(lines)

    def __repr__(self):
        return self.__str__()
# Register the readers and their file identifiers for GriddedPSFModel
# under the 'stdpsf' and 'webbpsf' format names, inside the
# delay_doc_updates context for GriddedPSFModel.
with registry.delay_doc_updates(GriddedPSFModel):
    registry.register_reader('stdpsf', GriddedPSFModel, stdpsf_reader)
    registry.register_identifier('stdpsf', GriddedPSFModel, is_stdpsf)
    registry.register_reader('webbpsf', GriddedPSFModel, webbpsf_reader)
    registry.register_identifier('webbpsf', GriddedPSFModel, is_webbpsf)