Source code for jolideco.models.npred

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from jolideco.utils.io import (
    IO_FORMATS_NPRED_CALIBRATIONS_READ,
    IO_FORMATS_NPRED_CALIBRATIONS_WRITE,
    document_io_formats,
    get_reader,
    get_writer,
)
from jolideco.utils.misc import format_class_str
from jolideco.utils.torch import convolve_fft_torch, transpose

__all__ = [
    "NPredModel",
    "NPredModels",
    "NPredCalibration",
    "NPredCalibrations",
]


[docs] class NPredModel(nn.Module): """Predicted counts model with multiple components Parameters ---------- exposure : `~torch.Tensor` Exposure tensor psf : `~torch.Tensor` Point spread function rmf : `~torch.Tensor` Energy redistribution matrix. upsampling_factor : int Upsampling factor. """ def __init__(self, exposure, psf=None, rmf=None, upsampling_factor=None): super().__init__() self.register_buffer("exposure", exposure) self.register_buffer("psf", psf) self.register_buffer("rmf", rmf) self.upsampling_factor = upsampling_factor @property def shape_upsampled(self): """Shape of the NPred model""" return tuple(self.exposure.shape) @property def shape(self): """Shape of the NPred model""" shape = list(self.shape_upsampled) shape[-1] //= self.upsampling_factor shape[-2] //= self.upsampling_factor return tuple(shape)
[docs] @classmethod def from_numpy(cls, exposure, psf, upsampling_factor, correct_exposure_edges=True): """Create NPred model from numpy arrays Parameters ---------- flux : `~torch.Tensor` Flux tensor exposure : `~torch.Tensor` Exposure tensor psf : `~torch.Tensor` Point spread function upsampling_factor : int Upsampling factor. correct_exposure_edges : bool Correct psf leakage at the exposure edges. Returns ------- npred_model : `NPredModel` Predicted counts model """ dims = (np.newaxis, np.newaxis) kwargs = { "upsampling_factor": upsampling_factor, "exposure": torch.from_numpy(exposure[dims]), "psf": torch.from_numpy(psf[dims]), } for name in ["psf", "exposure"]: tensor = kwargs[name] if upsampling_factor: tensor = F.interpolate( tensor, scale_factor=upsampling_factor, mode="bilinear" ) if name in ["psf", "flux"] and upsampling_factor: tensor = tensor / upsampling_factor**2 kwargs[name] = tensor if correct_exposure_edges: exposure = kwargs["exposure"] weights = convolve_fft_torch( image=torch.ones_like(exposure), kernel=kwargs["psf"] ) kwargs["exposure"] = exposure / weights return cls(**kwargs)
[docs] @classmethod def from_dataset_numpy( cls, dataset, upsampling_factor=None, correct_exposure_edges=True, ): """Create NPred model from dataset Parameters ---------- dataset : dict of `~numpy.ndarray` Dict containing `"counts"`, `"psf"` and optionally `"exposure"` and `"background"` upsampling_factor : int Upsampling factor for exposure, background and psf. correct_exposure_edges : bool Correct psf leakage at the exposure edges. Returns ------- npred_model : `NPredModel` Predicted counts model """ return cls.from_numpy( exposure=dataset["exposure"], psf=dataset["psf"], upsampling_factor=upsampling_factor, correct_exposure_edges=correct_exposure_edges, )
[docs] def forward(self, flux): """Forward folding model evaluation. Parameters ---------- flux : `~torch.Tensor` Flux tensor Returns ------- npred : `~torch.Tensor` Predicted number of counts """ npred = flux * self.exposure if self.psf is not None: npred = convolve_fft_torch(npred, self.psf) if self.upsampling_factor: npred = F.avg_pool2d( npred, kernel_size=self.upsampling_factor, divisor_override=1 ) if self.rmf is not None: npred_T = transpose(npred[0]) npred = torch.matmul(npred_T, self.rmf) npred = transpose(npred)[None] return torch.clip(npred, 0, torch.inf)
[docs] class NPredModels(nn.ModuleDict): """Flux components Parameters ---------- background : `~torch.Tensor` Background tensor calibration : `NPredCalibration` Calibration model. """ def __init__(self, background, calibration=None, *args, **kwargs): super().__init__(*args, **kwargs) self.register_buffer("background", background) self.calibration = calibration
[docs] def evaluate_per_component(self, fluxes): """Evaluate npred model per component Parameters ---------- fluxes : tuple of `~torch.tensor` Flux components Returns ------- npreds : dict `~torch.tensor` Predicted counts tensor per component """ npreds = {} for (name, npred_model), flux in zip(self.items(), fluxes): if self.calibration is not None: flux = self.calibration(flux=flux, scale=npred_model.upsampling_factor) npreds[name] = npred_model(flux=flux) if self.calibration is not None: npreds["background"] = self.background * self.calibration.background_norm else: npreds["background"] = self.background return npreds
[docs] def evaluate(self, fluxes): """Evaluate npred model Parameters ---------- fluxes : tuple of `~torch.tensor` Flux components Returns ------- npred_total : `~torch.tensor` Predicted counts tensor """ npreds = self.evaluate_per_component(fluxes=fluxes) npred_total = torch.zeros(self.background.shape, device=fluxes[0].device) for npred in npreds.values(): npred_total += npred return npred_total
[docs] @classmethod def from_dataset_numpy(cls, dataset, components, calibration=None): """Create multiple npred models. Parameters ---------- dataset : dict of `~numpy.ndarray` Dataset components : `FluxComponents` Flux components Returns ------- npred_models : `NPredModel` NPredModels """ values = [] for name, component in components.items(): psf = dataset["psf"] if isinstance(psf, dict): psf = psf[name] npred_model = NPredModel.from_numpy( exposure=dataset["exposure"], psf=psf, upsampling_factor=component.upsampling_factor, ) values.append((name, npred_model)) background = torch.from_numpy(dataset["background"][np.newaxis, np.newaxis]) return cls(background, calibration, values)
[docs] class NPredCalibration(nn.Module): """Dataset calibration parameters Attributes ---------- shift_x : `~torch.Tensor` Shift in x direction shift_y: `~torch.Tensor` Shift in y direction background_norm: `~torch.Tensor` Background normalisation parameter frozen : bool Whether to freeze component. """ _grid_sample_kwargs = {"align_corners": False} def __init__( self, shift_x=0.0, shift_y=0.0, background_norm=1.0, frozen=False, ): super().__init__() self.shift_xy = nn.Parameter(torch.Tensor([[shift_x, shift_y]])) value = torch.log(torch.Tensor([background_norm])) self._background_norm = nn.Parameter(value) self.frozen = frozen @property def background_norm(self): """Background norm""" return torch.exp(self._background_norm)
[docs] def parameters(self, recurse=True): """Parameter list""" if self.frozen: return [] return super().parameters(recurse)
[docs] def to_dict(self): """Convert calibration model to dict, with simple data types. Returns ------- data : dict Parameter dict. """ data = {} shift_xy = self.shift_xy.detach().cpu().numpy() data["shift_x"] = float(shift_xy[0, 0]) data["shift_y"] = float(shift_xy[0, 1]) data["background_norm"] = float(self.background_norm.detach().cpu().numpy()) data["frozen"] = self.frozen return data
[docs] @classmethod def from_dict(cls, data): """Create calibration model from dict Parameters ---------- data : dict Parameter dict. Returns ------- calibration : `NPredCalibration` Calibration model. """ return cls(**data)
[docs] def __call__(self, flux, scale): """Apply affine transform to calibrate position. Parameters ---------- flux : `~torch.Tensor` Flux tensor scale : float Upsampling factor scale. Returns ------- flux : `~torch.Tensor` Flux tensor """ size = flux.size() scale = 2 * scale / torch.tensor([[size[-1]], [size[-2]]], device=flux.device) diag = torch.eye(2, device=flux.device) theta = torch.cat([diag, scale * self.shift_xy.T], dim=1)[None] grid = F.affine_grid(theta=theta, size=size) flux_shift = F.grid_sample(flux, grid=grid, **self._grid_sample_kwargs) return flux_shift
def __str__(self): """String representation""" return format_class_str(instance=self)
[docs] class NPredCalibrations(nn.ModuleDict): """Calibration components Parameters ---------- calibration : `NPredCalibration` Calibration model. """ _registry_read = IO_FORMATS_NPRED_CALIBRATIONS_READ _registry_write = IO_FORMATS_NPRED_CALIBRATIONS_WRITE
[docs] def parameters(self, recurse=True): """Parameter list""" parameters = [] for model in self.values(): if not model.frozen: pars = list(model.parameters()) parameters.extend(pars) return parameters
[docs] def to_dict(self): """Convert calibration configuration to dict, with simple data types. Returns ------- data : dict Parameter dict. """ data = {} for name, model in self.items(): data[name] = model.to_dict() return data
[docs] @classmethod def from_dict(cls, data): """Create calibration models from dict Parameters ---------- data : dict Parameter dict. Returns ------- calibrations : `NPredCalibrations` Calibrations """ components = [] for name, component_data in data.items(): component = NPredCalibration.from_dict(data=component_data) components.append((name, component)) return cls(components)
[docs] @classmethod @document_io_formats(registry=_registry_read) def read(cls, filename, format=None): """Read npred calibrations from file Parameters ---------- filename : str or `Path` Output filename format : {formats} Format to use. Returns ------- calibrations : `NPredCalibrations` Calibrations """ reader = get_reader( filename=filename, format=format, registry=cls._registry_read ) return reader(filename)
[docs] @document_io_formats(registry=_registry_write) def write(self, filename, format=None, overwrite=False, **kwargs): """Write npred calibrations to file Parameters ---------- filename : str or `Path` Output filename format : {formats} Format to use. overwrite : bool Overwrite file. """ writer = get_writer( filename=filename, format=format, registry=self._registry_write ) return writer( npred_calibrations=self, filename=filename, overwrite=overwrite, **kwargs )
def __str__(self): """String representation""" return format_class_str(instance=self)