import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import fftconvolve

from photutils.datasets import load_irac_psf
from photutils.psf_matching import (SplitCosineBellWindow, make_kernel,
                                    make_wiener_kernel)

# Load the IRAC PSFs for channels 1 and 4 (in that order). Below, ch1_psf
# is the PSF being matched and ch4_psf is the target it is compared to.
ch1_hdu, ch4_hdu = (load_irac_psf(channel=chan) for chan in (1, 4))
ch1_psf, ch4_psf = ch1_hdu.data, ch4_hdu.data

# Four matching kernels from ch1_psf to ch4_psf, all sharing the same
# regularization value:
#   kernel1 - make_kernel with a split cosine bell window
#   kernel2 - make_wiener_kernel with no penalty argument given
#   kernel3 - make_wiener_kernel with a Laplacian penalty
#   kernel4 - make_wiener_kernel with a biharmonic penalty
regularization = 0.0001
window = SplitCosineBellWindow(alpha=0.15, beta=0.3)

kernel1 = make_kernel(ch1_psf, ch4_psf, window=window,
                      regularization=regularization)

# An empty dict means the penalty keyword is omitted entirely for kernel2.
kernel2, kernel3, kernel4 = (
    make_wiener_kernel(ch1_psf, ch4_psf, regularization=regularization,
                       **extra_kwargs)
    for extra_kwargs in ({}, {'penalty': 'laplacian'},
                         {'penalty': 'biharmonic'})
)

# Convolve ch1_psf with each kernel. mode='same' keeps the output the same
# shape as ch1_psf.
matched1, matched2, matched3, matched4 = (
    fftconvolve(ch1_psf, kern, mode='same')
    for kern in (kernel1, kernel2, kernel3, kernel4)
)

# Residual of each matched PSF against the target channel-4 PSF.
resid1, resid2, resid3, resid4 = (
    matched - ch4_psf
    for matched in (matched1, matched2, matched3, matched4)
)

# Largest absolute residual over all four maps; used as a shared,
# symmetric color limit when plotting.
vmax = np.abs(np.stack([resid1, resid2, resid3, resid4])).max()

# Show the four residual maps on a 2x2 grid. All panels share the
# symmetric range [-vmax, vmax], so the methods can be compared directly.
residuals = [resid1, resid2, resid3, resid4]
titles = [
    'make_kernel',
    'make_wiener_kernel',
    'make_wiener_kernel\n(Laplacian penalty)',
    'make_wiener_kernel\n(biharmonic penalty)',
]

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(9, 7))
for idx, ax in enumerate(axes.flat):
    image = ax.imshow(residuals[idx], origin='lower', cmap='RdBu_r',
                      vmin=-vmax, vmax=vmax)
    fig.colorbar(image, ax=ax)
    ax.set_title(titles[idx])

fig.suptitle('Residuals: PSF-matched minus channel 4 PSF')
fig.tight_layout()