# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
Utility functions for the psf_matching subpackage.
"""
import numpy as np
from scipy.fft import fft2, fftshift, ifftshift
from scipy.ndimage import zoom
# Public API of this module; the underscore-prefixed helpers defined
# below are private and are not exported.
__all__ = ['resize_psf']
def _validate_kernel_inputs(source_psf, target_psf, window):
    """
    Check and normalize the inputs shared by the kernel-making functions.

    Both PSFs are copied into new float arrays, validated, checked for
    matching shapes, and finally scaled so that each sums to one. The
    caller's input arrays are never modified.

    Parameters
    ----------
    source_psf : array-like
        The source PSF array.

    target_psf : array-like
        The target PSF array.

    window : callable or None
        The window function, or `None` for no window.

    Returns
    -------
    source_psf : `~numpy.ndarray`
        A normalized float copy of the source PSF.

    target_psf : `~numpy.ndarray`
        A normalized float copy of the target PSF.

    Raises
    ------
    ValueError
        If either PSF is not 2D, has an even dimension, contains NaN or
        Inf values, or sums to zero, or if the two PSFs do not have the
        same shape.

    TypeError
        If ``window`` is neither `None` nor callable.
    """
    # np.array copies by default, so the in-place normalization below
    # only touches these private float copies.
    src = np.array(source_psf, dtype=float)
    tgt = np.array(target_psf, dtype=float)

    for arr, label in ((src, 'source_psf'), (tgt, 'target_psf')):
        _validate_psf(arr, label)

    if src.shape != tgt.shape:
        msg = ('source_psf and target_psf must have the same shape '
               '(i.e., registered with the same pixel scale).')
        raise ValueError(msg)

    if not (window is None or callable(window)):
        msg = 'window must be a callable.'
        raise TypeError(msg)

    # Scale each PSF to unit total flux (the sums are non-zero, as
    # guaranteed by _validate_psf).
    for arr in (src, tgt):
        arr /= arr.sum()

    return src, tgt
def _validate_psf(psf, name):
    """
    Check that a PSF array is usable for PSF matching.

    A valid PSF is two-dimensional, has an odd size along both axes,
    contains only finite values, and has a non-zero sum (so that it
    can be normalized). The checks run in that order, and only the
    first failure is reported.

    Parameters
    ----------
    psf : `~numpy.ndarray`
        The PSF array to check.

    name : str
        The parameter name used in error messages.

    Raises
    ------
    ValueError
        If the PSF is not 2D, has an even dimension, contains NaN or
        Inf values, or sums to zero.
    """
    if psf.ndim != 2:
        msg = f'{name} must be a 2D array.'
        raise ValueError(msg)

    n_rows, n_cols = psf.shape
    if not (n_rows % 2 and n_cols % 2):
        msg = f'{name} must have odd dimensions, got shape {psf.shape}.'
        raise ValueError(msg)

    if not np.isfinite(psf).all():
        msg = f'{name} contains NaN or Inf values.'
        raise ValueError(msg)

    # A zero total would make the later normalization divide by zero.
    if psf.sum() == 0:
        msg = f'{name} must have a non-zero sum; it cannot be normalized.'
        raise ValueError(msg)
def _validate_window_array(window_array, expected_shape):
    """
    Validate window function output.

    Parameters
    ----------
    window_array : any
        The array returned by the window function.

    expected_shape : tuple
        The expected shape of the window array.

    Raises
    ------
    ValueError
        If the window array is not a 2D `~numpy.ndarray`, has the wrong
        shape, or contains values outside the range [0, 1]. NaN values
        are treated as out of range.
    """
    if not isinstance(window_array, np.ndarray) or window_array.ndim != 2:
        msg = ('window function must return a 2D array, got '
               f'{type(window_array).__name__} with '
               f'ndim={getattr(window_array, "ndim", "undefined")}.')
        raise ValueError(msg)

    if window_array.shape != expected_shape:
        msg = (f'window function must return an array with shape '
               f'{expected_shape}, got {window_array.shape}.')
        raise ValueError(msg)

    # Test membership in [0, 1] rather than looking for values outside
    # it: every comparison involving NaN is False, so a check of the
    # form "w < 0 or w > 1" would let NaN values pass silently.
    in_range = (window_array >= 0) & (window_array <= 1)
    if not np.all(in_range):
        msg = ('window function values must be in the range [0, 1], '
               f'got range [{np.min(window_array)}, '
               f'{np.max(window_array)}].')
        raise ValueError(msg)
def _convert_psf_to_otf(psf, shape):
"""
Convert a point-spread function to an optical transfer function.
This computes the FFT of a PSF array after centering it in a
zero-padded array of the output shape and applying `ifftshift` to
move the PSF center to position [0, 0].
The PSF is first placed at the center of the zero-padded array,
ensuring its center aligns with the array's center. The zero-padding
is needed when the input kernel (e.g., a 3x3 Laplacian) is smaller
than the target shape, so that the resulting OTF has the correct
size for element-wise operations with other same-shaped OTFs.
The `ifftshift` operation then moves the PSF center from the array
center to position [0, 0], which is the standard convention for
computing OTFs via FFT. This ensures correct complex phase in
the resulting OTF for general use. Note that when only the power
spectrum (|OTF|^2) is needed, the shift has no effect because it
only changes the phase.
Parameters
----------
psf : 2D `~numpy.ndarray`
The PSF array. The PSF must have odd dimensions and be centered
on the central pixel. The PSF shape must be smaller than or
equal to the target shape in both dimensions.
shape : tuple of int
The desired output shape.
Returns
-------
otf : 2D `~numpy.ndarray`
The optical transfer function (complex array).
"""
if np.all(psf == 0):
return np.zeros(shape, dtype=complex)
if psf.ndim != 2:
msg = 'psf must be a 2D array.'
raise ValueError(msg)
if psf.shape[0] % 2 == 0 or psf.shape[1] % 2 == 0:
msg = f'psf must have odd dimensions, got shape {psf.shape}.'
raise ValueError(msg)
inshape = psf.shape
if any(i > s for i, s in zip(inshape, shape, strict=True)):
msg = (f'The PSF shape {inshape} is larger than the target '
f'shape {shape} in at least one dimension.')
raise ValueError(msg)
# Zero-pad to the output shape with PSF centered in the array
padded = np.zeros(shape, dtype=psf.dtype)
# Calculate where to place PSF so its center aligns with padded
# array center
center = tuple(s // 2 for s in shape)
psf_center = tuple(s // 2 for s in inshape)
start = tuple(c - pc for c, pc in zip(center, psf_center, strict=True))
padded[start[0]:start[0] + inshape[0],
start[1]:start[1] + inshape[1]] = psf
# Shift the centered PSF so its center moves to [0, 0]
padded = ifftshift(padded)
return fft2(padded)
def _apply_window_to_fourier(fourier_array, window, shape):
    """
    Multiply a standard-layout Fourier array by a centered window.

    ``window(shape)`` is expected to return an array whose DC element
    sits at the array center, whereas ``fourier_array`` uses the
    standard FFT layout with the DC component at the corner. Instead of
    moving the Fourier array into the centered layout and back, the
    window is moved into the standard layout with `ifftshift` and then
    applied element-wise. Both shifts are fixed permutations of the
    array elements, so the two approaches give identical values.

    Parameters
    ----------
    fourier_array : 2D `~numpy.ndarray`
        A complex Fourier-domain array with the DC component at the
        corner.

    window : callable
        The window function. Must accept a single ``shape`` tuple and
        return a 2D array with values in [0, 1].

    shape : tuple of int
        The shape passed to the window function and the expected shape
        of the window output.

    Returns
    -------
    result : 2D `~numpy.ndarray`
        A new windowed Fourier-domain array in the standard FFT layout
        (DC at the corner). The input array is not modified.
    """
    centered_window = window(shape)
    _validate_window_array(centered_window, shape)

    # ifftshift(fftshift(F) * W) == F * ifftshift(W), element by element
    corner_window = ifftshift(centered_window)

    # Copy first so the in-place product leaves the caller's array
    # untouched and keeps the dtype of the input.
    windowed = np.array(fourier_array)
    windowed *= corner_window
    return windowed
def resize_psf(psf, input_pixel_scale, output_pixel_scale, *, order=3):
    """
    Resize a PSF using spline interpolation of the requested order.

    The total flux of the PSF is conserved during the resizing.

    Parameters
    ----------
    psf : 2D `~numpy.ndarray`
        The 2D data array of the PSF. The PSF must have odd dimensions.
        It is assumed to be centered on the central pixel.

    input_pixel_scale : float
        The pixel scale of the input ``psf``. The units must match
        ``output_pixel_scale``.

    output_pixel_scale : float
        The pixel scale of the output ``psf``. The units must match
        ``input_pixel_scale``.

    order : int, optional
        The order of the spline interpolation (0-5). The default is 3.

    Returns
    -------
    result : 2D `~numpy.ndarray`
        The resampled/interpolated 2D data array. The output always
        has odd dimensions. The natural resampled size is computed
        by taking the ceiling of ``input_size * (input_pixel_scale
        / output_pixel_scale)`` for each axis, then adding 1 to
        any axis whose size is even. This guarantees the output is
        centered and usable for PSF matching. The input and output
        arrays cover the same full extent (including the outer
        half-pixels) and share the same central pixel, so the effective
        output pixel scale per axis is exactly
        ``input_pixel_scale * input_size / output_size``. This equals
        ``output_pixel_scale`` when the natural size is already an odd
        integer; otherwise it is slightly smaller.

    Raises
    ------
    ValueError
        If ``psf`` is not a 2D array, has even dimensions, contains NaN
        or Inf values, or sums to zero; if the pixel scales are not
        positive and finite; or if the resampled PSF sums to zero so
        that its flux cannot be conserved.
    """
    psf = np.asarray(psf, dtype=float)

    # The finiteness test also rejects NaN, which slips past a plain
    # "<= 0" comparison because every comparison with NaN is False.
    scales = (input_pixel_scale, output_pixel_scale)
    if not all(np.isfinite(scale) and scale > 0 for scale in scales):
        msg = ('input_pixel_scale and output_pixel_scale must be '
               'positive and finite.')
        raise ValueError(msg)

    _validate_psf(psf, 'psf')

    ratio = input_pixel_scale / output_pixel_scale

    # Compute target shape using ceiling (never discard pixels). The
    # product is rounded to 9 decimals first so that floating-point
    # noise (e.g., 9 * (0.1 / 0.3) == 3.0000000000000004) does not
    # push an exact integer size up by a pixel.
    in_shape = np.array(psf.shape)
    natural_shape = np.round(in_shape * ratio, decimals=9)
    out_shape = np.maximum(1, np.ceil(natural_shape).astype(int))

    # Add 1 to any even dimension to guarantee an odd output, which is
    # required for PSF matching.
    out_shape += out_shape % 2 == 0

    # Per-axis zoom factors for the forced-odd target shape
    zoom_factors = out_shape / in_shape

    # grid_mode=True makes zoom scale the full pixel extent (pixel
    # edges included) instead of the distance between the outermost
    # pixel centers. Each output pixel then spans exactly
    # input_size / output_size input pixels, and the central pixel
    # stays central. 'grid-constant' pads with zeros (cval=0) and still
    # interpolates just beyond the input edges, which grid_mode=True
    # needs when upsampling.
    psf_sum = psf.sum()
    result = zoom(psf, zoom_factors, order=order, mode='grid-constant',
                  grid_mode=True)

    # Rescale to conserve the total flux. A resampled PSF that sums to
    # zero (e.g., a very coarse output grid landing on zero-valued
    # pixels) cannot be rescaled, so fail loudly instead of returning
    # NaN/Inf values.
    result_sum = result.sum()
    if result_sum == 0:
        msg = ('The resampled PSF has zero total flux, so its flux '
               'cannot be conserved; the output pixel scale may be too '
               'coarse for this PSF.')
        raise ValueError(msg)

    return result * (psf_sum / result_sum)