Source code for jolideco.priors.lira
import torch
from torch.distributions import Dirichlet
from jolideco.utils.torch import cycle_spin, view_as_overlapping_patches_torch
from .core import Prior
[docs]
class LIRAPrior(Prior):
"""LIRA multiscale prior
Parameters
----------
alphas : list of float
Alpha values
"""
def __init__(self, alphas, cycle_spin=True, random_state=None, generator=None):
self.alphas = alphas
self.random_state = random_state
self.cycle_spin = cycle_spin
if generator is None:
generator = torch.Generator()
self.generator = generator
[docs]
def __call__(self, flux):
if self.cycle_spin:
flux = cycle_spin(image=flux, patch_shape=(2, 2), generator=self.generator)
log_prior = 0
for alpha in self.alphas:
# TODO: add downsampling...
patches = view_as_overlapping_patches_torch(flux, shape=(2, 2), stride=2)
patches = patches / torch.sum(patches, dim=1, keepdims=True)
dirichlet = Dirichlet(patches)
values = dirichlet.log_prob(...)
log_prior += torch.sum(values)
return log_prior
[docs]
def to_dict(self):
"""To dict"""
raise NotImplementedError