Source code for jolideco.models.core

import logging
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from astropy.coordinates import SkyCoord
from astropy.utils import lazyproperty
from astropy.visualization import simple_norm
from astropy.wcs import WCS

from jolideco.priors.core import Prior, Priors, UniformPrior
from jolideco.utils.io import (
    IO_FORMATS_FLUX_COMPONENT_READ,
    IO_FORMATS_FLUX_COMPONENT_WRITE,
    IO_FORMATS_FLUX_COMPONENTS_READ,
    IO_FORMATS_FLUX_COMPONENTS_WRITE,
    IO_FORMATS_SPARSE_FLUX_COMPONENT_READ,
    IO_FORMATS_SPARSE_FLUX_COMPONENT_WRITE,
    document_io_formats,
    get_reader,
    get_writer,
)
from jolideco.utils.misc import format_class_str
from jolideco.utils.plot import add_cbar
from jolideco.utils.torch import grid_weights

# Module-level logger, named after this module
log = logging.getLogger(__name__)


# Names exported by `from jolideco.models.core import *`
__all__ = [
    "SpatialFluxComponent",
    "FluxComponents",
    "SparseSpatialFluxComponent",
]


class SparseSpatialFluxComponent(nn.Module):
    """Sparse flux component to represent a list of point sources

    Parameters
    ----------
    flux : `~torch.Tensor`
        Initial flux tensor
    x_pos : `~torch.Tensor`
        x position in pixel coordinates
    y_pos : `~torch.Tensor`
        y position in pixel coordinates
    shape : tuple of int
        Image shape
    use_log_flux : bool
        Use log scaling for flux
    prior : `Prior`
        Prior for this flux component.
    frozen : bool
        Whether to freeze component.
    wcs : `~astropy.wcs.WCS`
        World coordinate transform object
    """

    is_sparse = True
    upsampling_factor = 1
    # Target shapes used to broadcast per-source values against the pixel
    # index grids when evaluating `flux`
    _shape_eval = (-1, 1, 1, 1, 1)
    _shape_eval_x = (1, 1, 1, 1, -1)
    _shape_eval_y = (1, 1, 1, -1, 1)
    _registry_write = IO_FORMATS_SPARSE_FLUX_COMPONENT_WRITE
    _registry_read = IO_FORMATS_SPARSE_FLUX_COMPONENT_READ

    def __init__(
        self,
        flux,
        x_pos,
        y_pos,
        shape,
        use_log_flux=True,
        prior=None,
        frozen=False,
        wcs=None,
    ):
        super().__init__()

        if prior is None:
            prior = UniformPrior()

        if use_log_flux:
            flux = torch.log(flux)

        self.prior = prior
        self.frozen = frozen
        self._wcs = wcs
        # Normalise to a tuple so that `shape` can concatenate it, even if a
        # list or another sequence type was passed in.
        self._shape = tuple(shape)
        self._flux = nn.Parameter(flux.type(torch.float32))
        self.x_pos = nn.Parameter(x_pos.type(torch.float32))
        self.y_pos = nn.Parameter(y_pos.type(torch.float32))
        self._use_log_flux = use_log_flux

    def parameters(self, recurse=True):
        """Parameter list, empty if the component is frozen"""
        if self.frozen:
            return []
        else:
            return super().parameters(recurse)

    @property
    def x_pos_numpy(self) -> np.ndarray:
        """x pos as numpy array"""
        return self.x_pos.detach().cpu().numpy()

    @property
    def y_pos_numpy(self) -> np.ndarray:
        """y pos as numpy array"""
        return self.y_pos.detach().cpu().numpy()

    @property
    def sky_coord(self) -> SkyCoord:
        """Positions as SkyCoord"""
        return SkyCoord.from_pixel(
            xp=self.x_pos_numpy, yp=self.y_pos_numpy, wcs=self.wcs
        )

    @classmethod
    def from_numpy(cls, flux, x_pos, y_pos, **kwargs):
        """Create sparse flux component from numpy arrays

        Scalars are accepted as well and promoted to one element arrays.

        Parameters
        ----------
        flux : `~numpy.ndarray`
            Initial flux tensor
        x_pos : `~numpy.ndarray`
            x position in pixel coordinates
        y_pos : `~numpy.ndarray`
            y position in pixel coordinates
        **kwargs : dict
            Keyword arguments forwarded to `SparseSpatialFluxComponent`

        Returns
        -------
        sparse_flux_component : `SparseSpatialFluxComponent`
            Sparse flux component
        """
        flux = np.atleast_1d(flux)
        x_pos = np.atleast_1d(x_pos)
        y_pos = np.atleast_1d(y_pos)

        flux = torch.from_numpy(flux.astype(np.float32))
        x_pos = torch.from_numpy(x_pos.astype(np.float32))
        y_pos = torch.from_numpy(y_pos.astype(np.float32))
        return cls(flux=flux, x_pos=x_pos, y_pos=y_pos, **kwargs)

    @classmethod
    def from_sky_coord(cls, skycoord, wcs, **kwargs):
        """Create sparse flux component from sky coordinates

        Parameters
        ----------
        skycoord: `~astropy.coordinates.SkyCoord`
            Sky coordinates
        wcs : `~astropy.wcs.WCS`
            World coordinate transform object
        **kwargs : dict
            Keyword arguments forwarded to `from_numpy`, this must
            include ``flux`` and ``shape``.

        Returns
        -------
        sparse_flux_component : `SparseSpatialFluxComponent`
            Sparse flux component
        """
        # NOTE(review): astropy documents `SkyCoord.to_pixel` as returning
        # (xp, yp), but the result is unpacked as (y_pos, x_pos) here. This
        # mirrors the (y, x) unpacking in `flux` - confirm it is intended.
        # NOTE(review): `wcs` is only used for the conversion and is not
        # forwarded to the new component - confirm whether it should be.
        y_pos, x_pos = skycoord.to_pixel(wcs=wcs)
        return cls.from_numpy(x_pos=x_pos, y_pos=y_pos, **kwargs)

    @property
    def wcs(self) -> WCS:
        """World coordinate transform object"""
        return self._wcs

    @property
    def shape(self) -> tuple:
        """Shape of the flux component, with two leading axes of size one"""
        return (1, 1) + self._shape

    @lazyproperty
    def indices(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Pixel index grids, broadcastable against the source positions

        Returns
        -------
        idx, idy : tuple of `~torch.Tensor`
            Indices running along the last and second to last axis.
        """
        idx = torch.arange(self._shape[1], dtype=torch.float32)
        idy = torch.arange(self._shape[0], dtype=torch.float32)
        return idx.reshape(self._shape_eval_x), idy.reshape(self._shape_eval_y)

    @property
    def use_log_flux(self) -> bool:
        """Use log flux"""
        return self._use_log_flux

    @property
    def flux_numpy(self) -> np.ndarray:
        """Flux as numpy array"""
        flux_cpu = self.flux.detach().cpu()
        return flux_cpu.numpy()[0, 0]

    @property
    def flux(self) -> torch.Tensor:
        """Flux (`~torch.Tensor`)"""
        # NOTE(review): `indices` returns (idx, idy) but is unpacked as
        # (y, x), so `x_pos` is compared against the indices running along
        # the second to last axis - confirm this is intended.
        y, x = self.indices
        x0 = self.x_pos.reshape(self._shape_eval)
        y0 = self.y_pos.reshape(self._shape_eval)

        weights = grid_weights(x=x, y=y, x0=x0, y0=y0)

        if self.use_log_flux:
            flux = torch.exp(self._flux)
        else:
            flux = self._flux

        # One weighted image per source, summed over the source axis
        flux = weights * flux.reshape(self._shape_eval)
        return flux.sum(axis=0)

    @property
    def flux_upsampled(self) -> torch.Tensor:
        """Upsampled flux, identical to `flux` for sparse components"""
        return self.flux

    def plot(self, ax=None, kwargs_norm=None, **kwargs):
        """Plot flux component as sky image

        Parameters
        ----------
        ax : `~matplotlib.pyplot.Axes`
            Plotting axes
        kwargs_norm: dict
            Keyword arguments passed to `~astropy.visualization.simple_norm`
        **kwargs : dict
            Keywords forwarded to `~matplotlib.pyplot.imshow`

        Returns
        -------
        ax : `~matplotlib.pyplot.Axes`
            Plotting axes
        """
        if ax is None:
            ax = plt.subplot(projection=self.wcs)

        kwargs_norm = kwargs_norm or {"min_cut": 0, "stretch": "asinh", "asinh_a": 0.01}

        flux = self.flux_numpy
        norm = simple_norm(flux, **kwargs_norm)

        kwargs.setdefault("norm", norm)
        kwargs.setdefault("interpolation", "None")

        im = ax.imshow(flux, origin="lower", **kwargs)
        add_cbar(im=im, ax=ax, fig=ax.figure)
        return ax

    def to_dict(self, **kwargs) -> dict[str, Any]:
        """Convert sparse flux component configuration to dict, with simple data types.

        Parameters
        ----------
        **kwargs : dict
            Accepted for call compatibility with other components, unused.

        Returns
        -------
        data : dict
            Parameter dict.
        """
        # TODO: add all parameters, flux_upsampled could be filename
        data = {}
        data["use_log_flux"] = self.use_log_flux
        data["frozen"] = self.frozen
        data["shape"] = self.shape

        # Always store the flux in linear units
        if self.use_log_flux:
            flux = torch.exp(self._flux)
        else:
            flux = self._flux

        data["flux"] = flux.detach().cpu().numpy()
        data["x_pos"] = self.x_pos_numpy
        data["y_pos"] = self.y_pos_numpy
        data["prior"] = self.prior.to_dict()
        return data

    def __str__(self):
        """String representation"""
        return format_class_str(instance=self)

    @document_io_formats(registry=_registry_write)
    def write(self, filename, format=None, overwrite=False, **kwargs):
        """Write flux component to file

        Parameters
        ----------
        filename : str or `Path`
            Output filename
        overwrite : bool
            Overwrite file.
        format : {formats}
            Format to use.
        """
        writer = get_writer(
            filename=filename, format=format, registry=self._registry_write
        )
        return writer(
            flux_component=self, filename=filename, overwrite=overwrite, **kwargs
        )

    @classmethod
    @document_io_formats(registry=_registry_read)
    def read(cls, filename, format=None):
        """Read sparse flux component from file

        Parameters
        ----------
        filename : str or `Path`
            Input filename
        format : {formats}
            Format to use.

        Returns
        -------
        flux_component : `SparseSpatialFluxComponent`
            Flux component
        """
        reader = get_reader(
            filename=filename, format=format, registry=cls._registry_read
        )
        return reader(filename)
def freeze_mask(module, grad_input, grad_output):
    """Freeze masked parameters

    Backward hook that multiplies the gradients by ``module.mask``, such
    that entries where the mask is zero / False receive no gradient.

    Parameters
    ----------
    module : `~torch.nn.Module`
        Module the hook is registered on, it must have a ``mask`` attribute.
    grad_input : tuple of `~torch.Tensor` or `~torch.Tensor`
        Gradients with respect to the module inputs, entries may be None.
    grad_output : tuple of `~torch.Tensor`
        Gradients with respect to the module outputs, unused.

    Returns
    -------
    grad_input : tuple of `~torch.Tensor` or `~torch.Tensor`
        Masked gradients, or the unchanged input if no mask is set.
    """
    mask = module.mask

    # Compare against None explicitly: the truth value of a tensor with more
    # than one element is ambiguous and raises an error.
    if mask is None:
        return grad_input

    if isinstance(grad_input, torch.Tensor):
        return grad_input * mask

    # Full backward hooks receive a tuple, which cannot be multiplied by a
    # tensor directly. Entries without gradient are passed through as None.
    return tuple(grad if grad is None else grad * mask for grad in grad_input)
class SpatialFluxComponent(nn.Module):
    """Flux component

    Parameters
    ----------
    flux_upsampled : `~torch.Tensor`
        Initial flux tensor, must be four dimensional.
    flux_upsampled_error : `~torch.Tensor`
        Flux tensor error
    mask : `~torch.Tensor`
        Optional mask with the same shape as the flux tensor. The flux is
        multiplied by the mask.
    use_log_flux : bool
        Use log scaling for flux
    upsampling_factor : int
        Spatial upsampling factor for the flux.
    prior : `Prior`
        Prior for this flux component.
    frozen : bool
        Whether to freeze component.
    wcs : `~astropy.wcs.WCS`
        World coordinate transform object
    """

    is_sparse = False
    _registry_read = IO_FORMATS_FLUX_COMPONENT_READ
    _registry_write = IO_FORMATS_FLUX_COMPONENT_WRITE

    def __init__(
        self,
        flux_upsampled,
        flux_upsampled_error=None,
        mask=None,
        use_log_flux=True,
        upsampling_factor=1,
        prior=None,
        frozen=False,
        wcs=None,
    ):
        super().__init__()

        if not flux_upsampled.ndim == 4:
            raise ValueError(
                f"Flux tensor must be four dimensional. Got {flux_upsampled.ndim}"
            )

        if use_log_flux:
            flux_upsampled = torch.log(flux_upsampled)

        self._flux_upsampled = nn.Parameter(flux_upsampled)
        self._flux_upsampled_error = flux_upsampled_error

        if mask is not None and not mask.shape == flux_upsampled.shape:
            raise ValueError(
                "Flux and mask need to have the same shape, got "
                f"{flux_upsampled.shape} and {mask.shape}"
            )

        self.mask = mask
        self._use_log_flux = use_log_flux
        self.upsampling_factor = int(upsampling_factor)

        if prior is None:
            prior = UniformPrior()

        self.prior = prior
        self.frozen = frozen
        self._wcs = wcs
        self.register_full_backward_hook(freeze_mask)

    def to_dict(self, include_data=None) -> dict[str, Any]:
        """Convert flux component configuration to dict, with simple data types.

        Parameters
        ----------
        include_data : None or {"numpy"}
            Optionally include data array in the given format

        Returns
        -------
        data : dict
            Parameter dict.
        """
        # TODO: add all parameters, flux_upsampled could be filename
        data = {}
        data["use_log_flux"] = self.use_log_flux
        data["upsampling_factor"] = int(self.upsampling_factor)
        data["frozen"] = self.frozen
        data["prior"] = self.prior.to_dict()

        if include_data == "numpy":
            data["flux_upsampled"] = self.flux_upsampled_numpy

        return data

    @classmethod
    def from_dict(cls, data):
        """Create flux component from dict

        Parameters
        ----------
        data : dict
            Parameter dict. The ``"flux_upsampled"`` entry is required and
            can be a filename (str or `Path`), a `~torch.Tensor` or a two
            dimensional `~numpy.ndarray`.

        Returns
        -------
        flux_component : `SpatialFluxComponent`
            Flux component
        """
        kwargs = data.copy()

        prior_data = kwargs.pop("prior", None)

        if prior_data:
            kwargs["prior"] = Prior.from_dict(data=prior_data)

        value = kwargs["flux_upsampled"]

        # Filenames may be given as `Path` as well, not only as plain str
        if isinstance(value, (str, Path)):
            filename = Path(value)
            flux = cls.read(filename).flux_upsampled
        elif not isinstance(value, torch.Tensor):
            # Add the two leading axes required by the constructor
            flux = torch.from_numpy(value[np.newaxis, np.newaxis].astype(np.float32))
        else:
            flux = value

        kwargs["flux_upsampled"] = flux
        return cls(**kwargs)

    def __str__(self):
        """String representation"""
        return format_class_str(instance=self)

    @property
    def wcs(self) -> WCS:
        """World coordinate transform object"""
        return self._wcs

    def parameters(self, recurse=True):
        """Parameter list, empty if the component is frozen"""
        if self.frozen:
            return []
        else:
            return super().parameters(recurse)

    @classmethod
    def from_numpy(cls, flux, mask=None, **kwargs):
        """Create flux component from downsampled data.

        If ``upsampling_factor`` is given in the keyword arguments, flux and
        mask are interpolated bilinearly by that factor.

        Parameters
        ----------
        flux : `~numpy.ndarray`
            Flux init array with 2 dimensions
        mask : `~numpy.ndarray`
            Optional mask array with 2 dimensions
        **kwargs : dict
            Keyword arguments passed to `SpatialFluxComponent`

        Returns
        -------
        flux_component : `SpatialFluxComponent`
            Flux component
        """
        upsampling_factor = kwargs.get("upsampling_factor", None)

        # convert to pytorch tensors
        flux = torch.from_numpy(flux[np.newaxis, np.newaxis].astype(np.float32))

        if upsampling_factor:
            flux = F.interpolate(flux, scale_factor=upsampling_factor, mode="bilinear")

        if mask is not None:
            mask = torch.from_numpy(mask[np.newaxis, np.newaxis].astype(bool))

            if upsampling_factor:
                mask = F.interpolate(
                    mask.type(torch.float32),
                    scale_factor=upsampling_factor,
                    mode="bilinear",
                )
                # Interpolation yields fractional values, convert back to bool
                mask = mask > 0.5

        return cls(flux_upsampled=flux, mask=mask, **kwargs)

    @classmethod
    def from_flux_init_datasets(cls, datasets, **kwargs):
        """Compute flux init from datasets by averaging over the raw flux estimate.

        Parameters
        ----------
        datasets : list of dict
            List of dictionaries containing, "counts", "psf",
            "background" and "exposure".
        **kwargs : dict
            Keyword arguments passed to `SpatialFluxComponent`

        Returns
        -------
        flux_component : `SpatialFluxComponent`
            Flux component initialised with the averaged flux estimate.
        """
        fluxes = []

        for dataset in datasets:
            flux = dataset["counts"] / dataset["exposure"] - dataset["background"]
            fluxes.append(flux)

        # NaN values are ignored in the average
        flux_init = np.nanmean(fluxes, axis=0)
        return cls.from_numpy(flux=flux_init, **kwargs)

    @property
    def shape(self) -> tuple[int, int, int, int]:
        """Shape of the (upsampled) flux tensor"""
        return self._flux_upsampled.shape

    @property
    def shape_image(self) -> tuple[int, int]:
        """Image shape of the (upsampled) flux tensor"""
        return self.shape[-2:]

    @property
    def use_log_flux(self) -> bool:
        """Use log flux"""
        return self._use_log_flux

    @property
    def flux_upsampled(self) -> torch.Tensor:
        """Upsampled flux in linear units, with the mask applied"""
        flux = self._flux_upsampled

        if self.use_log_flux:
            flux = torch.exp(flux)

        if self.mask is not None:
            flux = flux * self.mask

        return flux

    @property
    def flux(self) -> torch.Tensor:
        """Flux as torch tensor"""
        flux = self.flux_upsampled

        if self.upsampling_factor:
            # divisor_override=1 turns the average pooling into a sum
            flux = F.avg_pool2d(
                flux,
                kernel_size=self.upsampling_factor,
                divisor_override=1,
            )

        return flux

    @property
    def flux_upsampled_error(self) -> torch.Tensor:
        """Flux error as torch tensor"""
        return self._flux_upsampled_error

    @property
    def flux_numpy(self) -> np.ndarray:
        """Flux as numpy array"""
        flux_cpu = self.flux.detach().cpu()
        return flux_cpu.numpy()[0, 0]

    @property
    def flux_upsampled_numpy(self) -> np.ndarray:
        """Flux upsampled as numpy array"""
        flux_cpu = self.flux_upsampled.detach().cpu()
        return flux_cpu.numpy()[0, 0]

    @property
    def flux_upsampled_error_numpy(self) -> np.ndarray:
        """Flux error upsampled as numpy array"""
        flux_error_cpu = self.flux_upsampled_error.detach().cpu()
        return flux_error_cpu.numpy()[0, 0]

    @classmethod
    @document_io_formats(registry=_registry_read)
    def read(cls, filename, format=None):
        """Read flux component from file

        Parameters
        ----------
        filename : str or `Path`
            Input filename
        format : {formats}
            Format to use.

        Returns
        -------
        flux_component : `SpatialFluxComponent`
            Flux component
        """
        reader = get_reader(
            filename=filename, format=format, registry=cls._registry_read
        )
        return reader(filename)

    @document_io_formats(registry=_registry_write)
    def write(self, filename, format=None, overwrite=False, **kwargs):
        """Write flux component to file

        Parameters
        ----------
        filename : str or `Path`
            Output filename
        overwrite : bool
            Overwrite file.
        format : {formats}
            Format to use.
        """
        writer = get_writer(
            filename=filename, format=format, registry=self._registry_write
        )
        return writer(
            flux_component=self, filename=filename, overwrite=overwrite, **kwargs
        )

    def plot(self, ax=None, kwargs_norm=None, **kwargs):
        """Plot flux component as sky image

        Parameters
        ----------
        ax : `~matplotlib.pyplot.Axes`
            Plotting axes
        kwargs_norm: dict
            Keyword arguments passed to `~astropy.visualization.simple_norm`
        **kwargs : dict
            Keywords forwarded to `~matplotlib.pyplot.imshow`

        Returns
        -------
        ax : `~matplotlib.pyplot.Axes`
            Plotting axes
        """
        if ax is None:
            ax = plt.subplot(projection=self.wcs)

        kwargs_norm = kwargs_norm or {"min_cut": 0, "stretch": "asinh", "asinh_a": 0.01}

        flux = self.flux_upsampled_numpy
        norm = simple_norm(flux, **kwargs_norm)

        kwargs.setdefault("norm", norm)
        kwargs.setdefault("interpolation", "None")

        ax.imshow(flux, origin="lower", **kwargs)
        return ax

    def as_gp_map(self):
        """Convert to Gammapy map

        Returns
        -------
        map : `~gammapy.maps.WcsNDmap`
            Gammapy WCS map
        """
        from gammapy.maps import Map, WcsGeom

        data = self.flux_numpy

        # Gammapy expects `npix` in (x, y) order, which is the reverse of the
        # numpy data shape. Deriving it from the data that is actually passed
        # keeps geometry and data consistent for non-square images and for
        # upsampling factors larger than one.
        # NOTE(review): assumes `wcs` describes the pixel grid of `flux`
        # and not of `flux_upsampled` - confirm.
        geom = WcsGeom(wcs=self.wcs, npix=data.shape[::-1])
        return Map.from_geom(geom=geom, data=data)
class FluxComponents(nn.ModuleDict):
    """Flux components

    Dictionary of named flux components, offering summed and per component
    access to the fluxes as well as serialisation and plotting.
    """

    _registry_read = IO_FORMATS_FLUX_COMPONENTS_READ
    _registry_write = IO_FORMATS_FLUX_COMPONENTS_WRITE

    def parameters(self, recurse=True):
        """Parameter list of all components that are not frozen

        Parameters
        ----------
        recurse : bool
            Forwarded to the ``parameters`` method of the components. The
            argument is accepted to stay compatible with the signature of
            `~torch.nn.Module.parameters`.

        Returns
        -------
        parameters : list of `~torch.nn.Parameter`
            Parameters
        """
        parameters = []

        for component in self.values():
            if not component.frozen:
                parameters.extend(component.parameters(recurse))

        return parameters

    @property
    def priors(self):
        """Priors associated with the components"""
        priors = Priors()

        for name, component in self.items():
            priors[name] = component.prior

        return priors

    @property
    def flux_upsampled_total(self):
        """Total summed flux (`~torch.tensor`)"""
        values = list(self.values())
        shape = values[0].shape

        fluxes = [component.flux_upsampled for component in values]

        # Allocate the accumulator on the device of the fluxes, otherwise
        # the in-place sum fails for components living on a GPU.
        flux = torch.zeros(shape, device=fluxes[0].device)

        for flux_component in fluxes:
            flux += flux_component

        return flux

    @property
    def fluxes_numpy(self):
        """Fluxes (`dict` of `~numpy.ndarray`)"""
        fluxes = {}

        for name, component in self.items():
            fluxes[name] = component.flux_numpy

        return fluxes

    @property
    def fluxes_upsampled_numpy(self):
        """Upsampled fluxes (`dict` of `~numpy.ndarray`)"""
        return self.to_numpy()

    @property
    def flux_upsampled_total_numpy(self):
        """Upsampled total flux"""
        return np.sum(list(self.fluxes_upsampled_numpy.values()), axis=0)

    @property
    def flux_total_numpy(self):
        """Total flux"""
        return np.sum(list(self.fluxes_numpy.values()), axis=0)

    def to_dict(self, include_data=None):
        """Convert flux component configuration to dict, with simple data types.

        Parameters
        ----------
        include_data : None or {"numpy"}
            Optionally include data array in the given format

        Returns
        -------
        data : dict
            Parameter dict.
        """
        fluxes = {}

        for name, component in self.items():
            fluxes[name] = component.to_dict(include_data=include_data)

        return fluxes

    @classmethod
    def from_dict(cls, data):
        """Create flux components from dict

        Every entry is created as a `SpatialFluxComponent`.

        Parameters
        ----------
        data : dict
            Parameter dict.

        Returns
        -------
        flux_components : `FluxComponents`
            Flux components
        """
        components = []

        for name, component_data in data.items():
            component = SpatialFluxComponent.from_dict(data=component_data)
            components.append((name, component))

        return cls(components)

    def to_numpy(self):
        """Fluxes of the components (dict of `~numpy.ndarray`)"""
        fluxes = {}

        for name, component in self.items():
            flux_cpu = component.flux_upsampled.detach().cpu()
            fluxes[name] = np.squeeze(flux_cpu.numpy())

        return fluxes

    def to_flux_tuple(self):
        """Fluxes as tuple (tuple of `~torch.tensor`)"""
        return tuple([_.flux_upsampled for _ in self.values()])

    def set_flux_errors(self, flux_errors):
        """Set flux errors

        Parameters
        ----------
        flux_errors : dict
            Flux errors per component name.
        """
        for name, flux_error in flux_errors.items():
            self[name]._flux_upsampled_error = flux_error

    @classmethod
    @document_io_formats(registry=_registry_read)
    def read(cls, filename, format=None):
        """Read flux components from file

        Parameters
        ----------
        filename : str or `Path`
            Input filename
        format : {formats}
            Format to use.

        Returns
        -------
        flux_components : `FluxComponents`
            Flux components
        """
        reader = get_reader(
            filename=filename, format=format, registry=cls._registry_read
        )
        return reader(filename=filename)

    @document_io_formats(registry=_registry_write)
    def write(self, filename, overwrite=False, format=None, **kwargs):
        """Write flux components to file

        Parameters
        ----------
        filename : str or `Path`
            Output filename
        overwrite : bool
            Overwrite file.
        format : {formats}
            Format to use.
        """
        writer = get_writer(
            filename=filename, format=format, registry=self._registry_write
        )
        return writer(
            flux_components=self, filename=filename, overwrite=overwrite, **kwargs
        )

    def plot(self, figsize=None, kwargs_norm=None, **kwargs):
        """Plot images of the flux components

        The first panel shows the total flux, followed by one panel per
        component.

        Parameters
        ----------
        figsize : tuple of int
            Figure size.
        kwargs_norm: dict
            Keyword arguments passed to `~astropy.visualization.simple_norm`
        **kwargs : dict
            Keywords forwarded to `~matplotlib.pyplot.imshow`

        Returns
        -------
        axes : list of `~matplotlib.pyplot.Axes`
            Plot axes
        """
        ncols = len(self) + 1

        if figsize is None:
            figsize = (ncols * 5, 5)

        fig, axes = plt.subplots(
            nrows=1,
            ncols=ncols,
            subplot_kw={"projection": list(self.values())[0].wcs},
            figsize=figsize,
        )

        kwargs_norm = kwargs_norm or {"min_cut": 0, "stretch": "asinh", "asinh_a": 0.01}

        flux = self.flux_total_numpy
        norm = simple_norm(flux, **kwargs_norm)

        im = axes[0].imshow(flux, origin="lower", norm=norm, **kwargs)
        axes[0].set_title("Total")

        # Iterate the components directly, there is no need to evaluate all
        # fluxes just to obtain the names.
        for ax, (name, component) in zip(axes[1:], self.items()):
            component.plot(ax=ax, kwargs_norm=kwargs_norm, **kwargs)
            ax.set_title(name.title())

        # The colorbar of the total flux image is attached to the last panel
        add_cbar(im=im, ax=axes[-1], fig=fig)
        return axes

    def __str__(self):
        """String representation"""
        return format_class_str(instance=self)