import matplotlib.pyplot as plt
import numpy as np
from astropy.table import QTable
from astropy.visualization import simple_norm
from photutils.aperture import CircularAperture
from photutils.datasets import make_noise_image
from photutils.detection import DAOStarFinder
from photutils.psf import (CircularGaussianPRF, PSFPhotometry,
                           make_psf_model_image)

# --- Simulated input image -------------------------------------------------
# Circular Gaussian PRF used to render the artificial sources.
psf_model = CircularGaussianPRF(flux=1, fwhm=2.7)
psf_shape = (9, 9)      # passed as ``model_shape`` when rendering each source
n_sources = 10
shape = (101, 101)      # shape of the simulated image

# Noise-free image of the sources plus the table of their true parameters.
# ``seed=0`` keeps the example reproducible.
data, true_params = make_psf_model_image(shape, psf_model, n_sources,
                                         model_shape=psf_shape,
                                         flux=(500, 700),
                                         min_separation=10, seed=0)

# Add zero-mean Gaussian noise with a known standard deviation.
noise_stddev = 1
noise = make_noise_image(data.shape, mean=0, stddev=noise_stddev, seed=0)
data += noise

# The error array must hold the per-pixel 1-sigma uncertainty, not the
# absolute value of this particular noise realisation: ``np.abs(noise)`` is
# close to zero wherever the drawn noise happens to be small, which would
# give those pixels an arbitrarily large weight in the fit.  The noise here
# is purely Gaussian with a known sigma, so the uncertainty is constant.
error = np.full(data.shape, float(noise_stddev))

# --- PSF photometry of a single source ---------------------------------------
# Fresh PRF instance for fitting, with the same flux/fwhm starting values as
# the model used for the simulation.
psf_model = CircularGaussianPRF(flux=1, fwhm=2.7)
fit_shape = (5, 5)
finder = DAOStarFinder(threshold=6.0, fwhm=2.0)
psfphot = PSFPhotometry(psf_model=psf_model, fit_shape=fit_shape,
                        finder=finder, aperture_radius=4)

# Initial (x, y) guess for the one source to fit.  No flux column is given.
# NOTE(review): per the PSFPhotometry docs the initial flux is then estimated
# from a circular aperture of radius ``aperture_radius``, and the finder is
# presumably not consulted when positions are supplied -- confirm.
init_params = QTable([[63], [49]], names=('x', 'y'))
phot = psfphot(data, error=error, init_params=init_params)

# --- Visualise data, fitted model and residual -------------------------------
resid = psfphot.make_residual_image(data)

# Apertures of radius 4 marking the fitted source position(s).
aper = CircularAperture(np.transpose((phot['x_fit'], phot['y_fit'])), r=4)

fig, ax = plt.subplots(ncols=3, figsize=(15, 5))

# One normalisation, derived from the data, is shared by all three panels so
# that they are displayed on the same intensity scale.
norm = simple_norm(data, 'sqrt', percent=99)

# The model image is recovered as data minus residual.
panels = (
    ('Data', data),
    ('Model', data - resid),
    ('Residual Image', resid),
)
for axis, (title, image) in zip(ax, panels):
    # After the loop ``im`` refers to the residual panel's image.
    im = axis.imshow(image, norm=norm, origin='lower')
    axis.set_title(title)
    aper.plot(ax=axis, color='red')

fig.tight_layout()