import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import fftconvolve

from photutils.datasets import load_irac_psf
from photutils.profiles import CurveOfGrowth
from photutils.psf_matching import (SplitCosineBellWindow, make_kernel,
                                    make_wiener_kernel)

# Load the IRAC channel 1 and channel 4 PSF HDUs and pull out their
# image arrays.
ch1_hdu, ch4_hdu = (load_irac_psf(channel=chan) for chan in (1, 4))
ch1_psf, ch4_psf = ch1_hdu.data, ch4_hdu.data

# The same regularization value is passed to every kernel builder; the
# window is only passed to ``make_kernel``.
regularization = 0.0001
window = SplitCosineBellWindow(alpha=0.15, beta=0.3)

# kernel1: make_kernel with the split-cosine-bell window.
kernel1 = make_kernel(
    ch1_psf, ch4_psf, window=window, regularization=regularization,
)

# kernel2-4: make_wiener_kernel with, in order, no explicit penalty, the
# 'laplacian' penalty, and the 'biharmonic' penalty.
kernel2, kernel3, kernel4 = (
    make_wiener_kernel(ch1_psf, ch4_psf, regularization=regularization,
                       **extra)
    for extra in ({}, {'penalty': 'laplacian'}, {'penalty': 'biharmonic'})
)

# Convolve the channel 1 PSF with each kernel; mode='same' keeps the
# output the same shape as ``ch1_psf``.
matched1, matched2, matched3, matched4 = (
    fftconvolve(ch1_psf, kern, mode='same')
    for kern in (kernel1, kernel2, kernel3, kernel4)
)

# Integer radii 1..39 and the (x, y) center shared by every curve of
# growth below.
# NOTE(review): the hard-coded center presumably corresponds to the
# middle pixel of the PSF images -- confirm against the array shape.
radii = np.arange(1, 40)
xycen = (40.0, 40.0)

# Curves of growth for the two input PSFs ...
cog_ch1, cog_ch4 = (
    CurveOfGrowth(psf, xycen, radii) for psf in (ch1_psf, ch4_psf)
)
# ... and for each of the four PSF-matched images.
cog_m1, cog_m2, cog_m3, cog_m4 = (
    CurveOfGrowth(image, xycen, radii)
    for image in (matched1, matched2, matched3, matched4)
)

# Apply normalize() with its default settings to every curve before
# plotting.
for cog in (cog_ch1, cog_ch4, cog_m1, cog_m2, cog_m3, cog_m4):
    cog.normalize()

# Legend label, line style, and color for each matched curve, listed in
# the same order as ``cogs_matched``.
labels = [
    'make_kernel',
    'make_wiener_kernel',
    'make_wiener_kernel (Laplacian)',
    'make_wiener_kernel (biharmonic)',
]
cogs_matched = [cog_m1, cog_m2, cog_m3, cog_m4]
ls_list = ['--', '-.', ':', (0, (3, 1, 1, 1))]
# Colors are assigned explicitly.  The two reference curves below are
# drawn with explicit colors, which does not advance matplotlib's
# property cycle, so relying on the cycle would give the first matched
# curve 'C0' -- the same color as the Channel 1 PSF.  Explicit colors
# also guarantee that each method has the same color in both panels.
colors_matched = ['C1', 'C2', 'C3', 'C4']

# Two stacked panels sharing the x axis: a tall panel for the curves and
# a short one for the residuals.
fig, (ax_top, ax_bot) = plt.subplots(
    nrows=2, figsize=(8, 8),
    gridspec_kw={'height_ratios': [3, 1]},
    sharex=True,
)

# Main panel
cog_ch1.plot(ax=ax_top, label='Channel 1 PSF', lw=2,
             color='C0', ls='-')
cog_ch4.plot(ax=ax_top, label='Channel 4 PSF', lw=3,
             color='k')
for cog, label, ls, color in zip(cogs_matched, labels, ls_list,
                                 colors_matched):
    cog.plot(ax=ax_top, label=label, lw=2, ls=ls, color=color)
ax_top.set_ylabel('Normalized Encircled Energy')
ax_top.set_title(
    'Encircled Energy: Channel 1 & 4 PSFs vs. PSF-matched results')
ax_top.legend(fontsize=9)
# Blank the top panel's x label; the bottom panel carries it.
ax_top.set_xlabel('')

# Residual subpanel (matched - ch4)
for cog, label, ls, color in zip(cogs_matched, labels, ls_list,
                                 colors_matched):
    resid = cog.profile - cog_ch4.profile
    ax_bot.plot(cog.radius, resid, lw=2, ls=ls, color=color, label=label)
ax_bot.axhline(0, color='k', lw=1, ls='-')
ax_bot.set_xlabel('Radius (pixels)')
ax_bot.set_ylabel('Residual')
ax_bot.set_title('Matched $-$ Channel 4')

fig.tight_layout()