Source code for jolideco.core

import copy
import logging
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
from astropy.table import Table
from astropy.utils import lazyproperty
from astropy.visualization import simple_norm
from packaging import version
from tqdm.auto import tqdm

from .loss import PoissonLoss, PriorLoss, TotalLoss
from .models import FluxComponents, SpatialFluxComponent
from .utils.io import (
    IO_FORMATS_MAP_RESULT_READ,
    IO_FORMATS_MAP_RESULT_WRITE,
    get_reader,
    get_writer,
)
from .utils.misc import format_class_str
from .utils.torch import TORCH_DEFAULT_DEVICE
from .utils.plot import add_cbar

# NOTE(review): this configures the root logger as an import side effect; it is
# kept unchanged here because callers may rely on it.
logging.basicConfig(level=logging.INFO)

log = logging.getLogger(__name__)

__all__ = ["MAPDeconvolver", "MAPDeconvolverResult"]


# True for PyTorch >= 2.0 on non-Windows platforms; used to decide whether
# ``torch.compile`` is attempted. The platform check uses ``startswith``
# because a plain substring test for "win" also matches "darwin" (macOS).
IS_PYTORCH2 = version.parse(torch.__version__) >= version.parse(
    "2.0"
) and not sys.platform.startswith("win")

# Maps the ``optimizer`` option name to the corresponding torch optimizer class.
OPTIMIZER = {
    "adam": torch.optim.Adam,
    "sgd": torch.optim.SGD,
}


class MAPDeconvolver:
    """Maximum A-Posteriori deconvolver

    Attributes
    ----------
    n_epochs : int
        Number of epochs to train
    beta : float
        Scale factor for the prior.
    learning_rate : float
        Learning rate
    compute_error : bool
        Whether to compute flux error
    stop_early : bool
        Stop training early, once the average results on the last n test
        datasets do not improve any more.
    stop_early_n_average : int
        Number of iterations to average over.
    device : str or `~torch.device`
        Pytorch device
    display_progress : bool
        Whether to display a progress bar
    optimizer : {"adam", "sgd"}
        Optimizer to use
    """

    _default_flux_component = "flux"

    def __init__(
        self,
        n_epochs=1_000,
        beta=1,
        learning_rate=0.1,
        compute_error=False,
        stop_early=False,
        stop_early_n_average=10,
        device=TORCH_DEFAULT_DEVICE,
        display_progress=True,
        optimizer="adam",
    ):
        self.n_epochs = n_epochs
        self.beta = beta
        self.learning_rate = learning_rate
        self.compute_error = compute_error
        self.stop_early = stop_early
        self.stop_early_n_average = stop_early_n_average
        self.display_progress = display_progress

        # ``device`` may be given as a string or as a ``torch.device``. The
        # latter does not support the ``in`` operator, so test its string form.
        if "cuda" in str(device) and not torch.cuda.is_available():
            log.warning(
                f"Device {device} not available, falling back to {TORCH_DEFAULT_DEVICE}"
            )
            device = TORCH_DEFAULT_DEVICE

        self.device = torch.device(device)
        self.optimizer = optimizer

    def to_dict(self):
        """Convert deconvolver configuration to dict, with simple data types.

        Returns
        -------
        data : dict
            Parameter dict. The device is stored as its string representation.
        """
        data = {}
        data.update(self.__dict__)
        data["device"] = str(self.device)
        return data

    def __str__(self):
        """String representation"""
        return format_class_str(instance=self)

    def run(
        self, datasets, datasets_validation=None, components=None, calibrations=None
    ):
        """Run the MAP deconvolver

        Parameters
        ----------
        datasets : dict of [str, dict]
            Dictionary containing a name of the dataset as key and a dictionary
            containing, the data like "counts", "psf", "background" and
            "exposure".
        datasets_validation : dict of [str, dict]
            Dictionary containing a name of the validation dataset as key and a
            dictionary containing, the data like "counts", "psf", "background"
            and "exposure". Required if ``stop_early`` is set.
        components : `FluxComponents` or `FluxComponent`
            Flux components.
        calibrations : `NPredCalibrations`
            Optional model calibrations.

        Returns
        -------
        result : `MAPDeconvolverResult`
            Result object holding the fitted and initial components, the
            calibrations, the trace of the loss and the configuration.

        Raises
        ------
        ValueError
            If early stopping is requested but no validation datasets are
            provided.
        """
        # An empty mapping is rejected as well: no validation loss is created
        # for it below, so early stopping would have nothing to monitor.
        if self.stop_early and not datasets_validation:
            raise ValueError("Early stopping requires providing test datasets")

        if isinstance(components, SpatialFluxComponent):
            components = {self._default_flux_component: components}

        components = FluxComponents(components)

        # Keep copies of the initial state for the result object
        components_init = copy.deepcopy(components)
        calibrations_init = copy.deepcopy(calibrations)

        # Use torch's JIT compilation feature if available...
        if IS_PYTORCH2:
            components_compiled = torch.compile(components)
        else:
            components_compiled = components

        # NOTE(review): this assignment replaces the value chosen above, so the
        # result of ``torch.compile`` is currently not used. Statement order is
        # kept as in the original; confirm whether this is intentional.
        components_compiled = components.to(self.device)

        poisson_loss = PoissonLoss.from_datasets(
            datasets=datasets,
            components=components_compiled,
            device=self.device,
            calibrations=calibrations,
        )

        if datasets_validation:
            poisson_loss_validation = PoissonLoss.from_datasets(
                datasets=datasets_validation,
                components=components_compiled,
                calibrations=calibrations,
                device=self.device,
            )
        else:
            poisson_loss_validation = None

        prior_loss = PriorLoss(priors=components_compiled.priors)

        total_loss = TotalLoss(
            poisson_loss=poisson_loss,
            poisson_loss_validation=poisson_loss_validation,
            prior_loss=prior_loss,
            beta=self.beta,
        )

        # Calibration parameters, if any, are optimized jointly with the
        # component parameters.
        parameters = list(components_compiled.parameters())

        if calibrations:
            parameters.extend(calibrations.parameters())

        optimizer = OPTIMIZER[self.optimizer](
            params=parameters,
            lr=self.learning_rate,
        )

        disable = not self.display_progress

        # The progress bar counts one step per dataset and epoch
        with tqdm(total=self.n_epochs * len(datasets), disable=disable) as pbar:
            for epoch in range(self.n_epochs):
                pbar.set_description(f"Epoch {epoch + 1}")

                components.train()

                for counts, npred_model in poisson_loss.iter_by_dataset:
                    optimizer.zero_grad()

                    # evaluate npred model
                    fluxes = components_compiled.to_flux_tuple()
                    npred = npred_model.evaluate(fluxes=fluxes)

                    # compute Poisson loss
                    loss = poisson_loss.loss_function(npred, counts)

                    # compute prior losses
                    loss_prior = prior_loss(fluxes=fluxes)

                    loss_total = (
                        loss - self.beta * loss_prior / total_loss.prior_weight
                    )

                    loss_total.backward()
                    optimizer.step()
                    pbar.update(1)

                components.eval()

                # The trace entry uses the fluxes evaluated for the last
                # dataset of this epoch.
                total_loss.append_trace(fluxes=fluxes)

                row = total_loss.trace[-1]

                if (
                    self.stop_early
                    and len(total_loss.trace) > self.stop_early_n_average
                ):
                    # Stop once the latest validation loss exceeds the mean of
                    # the last ``stop_early_n_average`` trace entries.
                    range_mean = slice(-self.stop_early_n_average, None)
                    trace_loss_validation = total_loss.trace[
                        "datasets-validation-total"
                    ]
                    loss_test_average = np.mean(trace_loss_validation[range_mean])

                    if row["datasets-validation-total"] > loss_test_average:
                        break

                pbar.set_postfix(
                    total=row["total"],
                    datasets_total=row["datasets-total"],
                    priors_total=row["priors-total"],
                )

        if self.compute_error:
            flux_errors = total_loss.fluxes_error(fluxes=fluxes)
            components.set_flux_errors(flux_errors=flux_errors)

        config = self.to_dict()

        return MAPDeconvolverResult(
            config=config,
            components=components,
            components_init=components_init,
            trace_loss=total_loss.trace,
            calibrations=calibrations,
            calibrations_init=calibrations_init,
        )
class MAPDeconvolverResult:
    """MAP deconvolver result

    Parameters
    ----------
    config : `dict`
        Configuration from the `MAPDeconvolver`
    components : `FluxComponents`
        Flux components.
    components_init : `FluxComponents`
        Initial flux components.
    trace_loss : `~astropy.table.Table` or dict
        Trace of the total loss.
    calibrations : `NPredCalibrations`
        Optional calibrations.
    calibrations_init : `NPredCalibrations`
        Optional initial calibrations.
    wcs : object
        Optional WCS; stored on the instance as ``_wcs``.
    """

    def __init__(
        self,
        config,
        components,
        components_init,
        trace_loss,
        calibrations=None,
        calibrations_init=None,
        wcs=None,
    ):
        self._config = config
        self._wcs = wcs
        self.trace_loss = trace_loss
        self._components, self._components_init = components, components_init
        self._calibrations, self._calibrations_init = calibrations, calibrations_init

    @property
    def components(self):
        """Flux components (`FluxComponents`)"""
        return self._components

    @property
    def components_init(self):
        """Initial flux components (`FluxComponents`)"""
        return self._components_init

    @property
    def calibrations(self):
        """Calibrations (`NPredCalibrations`)"""
        return self._calibrations

    @property
    def calibrations_init(self):
        """Initial calibrations (`NPredCalibrations`)"""
        return self._calibrations_init

    @property
    def flux_total(self):
        """Total flux, taken from ``components.flux_total_numpy``"""
        return self.components.flux_total_numpy

    @property
    def flux_upsampled_total(self):
        """Total upsampled flux, taken from ``components.flux_upsampled_total_numpy``"""
        return self.components.flux_upsampled_total_numpy

    @lazyproperty
    def config(self):
        """Configuration data (`dict`)"""
        return self._config

    def plot_trace_loss(self, ax=None, which=None, **kwargs):
        """Plot trace loss

        Parameters
        ----------
        ax : `~matplotlib.pyplot.Axes`
            Plot axes. If None, the current axes are used.
        which : list of str
            Which traces to plot.
        **kwargs : dict
            Keyword arguments forwarded to the plotting helper.

        Returns
        -------
        ax : `~matplotlib.pyplot.Axes`
            Plot axes
        """
        from .utils.plot import plot_trace_loss

        if ax is None:
            ax = plt.gca()

        plot_trace_loss(ax=ax, trace_loss=self.trace_loss, which=which, **kwargs)
        return ax

    def peek(self, figsize=(12, 5), kwargs_norm=None):
        """Plot the result and the trace of the loss function

        Parameters
        ----------
        figsize : tuple
            Figure size
        kwargs_norm : dict
            Keyword arguments passed to `~astropy.visualization.simple_norm`.
            If not given, an asinh stretch with a minimum cut at zero is used.
        """
        fig, axes = plt.subplots(nrows=1, ncols=2, figsize=figsize)
        ax_trace, ax_flux = axes[0], axes[1]

        # left panel: trace of the loss
        self.plot_trace_loss(ax=ax_trace)

        if not kwargs_norm:
            kwargs_norm = {"min_cut": 0, "stretch": "asinh", "asinh_a": 0.01}

        # right panel: total flux image with colorbar
        flux = self.components.flux_total_numpy
        image = ax_flux.imshow(
            flux,
            origin="lower",
            norm=simple_norm(flux, **kwargs_norm),
            interpolation="None",
        )
        add_cbar(im=image, ax=ax_flux, fig=fig)

    @property
    def config_table(self):
        """Configuration data as table (`~astropy.table.Table`)"""
        table = Table()

        # one single-row column per configuration entry
        for name, entry in self.config.items():
            table[name] = [entry]

        return table

    def write(self, filename, overwrite=False, format="fits"):
        """Write result to file

        Parameters
        ----------
        filename : str or `Path`
            Output filename
        overwrite : bool
            Overwrite file.
        format : {"fits"}
            Format to use.
        """
        write_result = get_writer(
            filename=filename, format=format, registry=IO_FORMATS_MAP_RESULT_WRITE
        )
        write_result(result=self, filename=filename, overwrite=overwrite)

    @classmethod
    def read(cls, filename, format="fits"):
        """Read result from file

        Parameters
        ----------
        filename : str or `Path`
            Input filename
        format : {"fits"}
            Format to use.

        Returns
        -------
        result : `~MAPDeconvolverResult`
            Result object
        """
        read_result = get_reader(
            filename=filename, format=format, registry=IO_FORMATS_MAP_RESULT_READ
        )
        return read_result(filename=filename)