GMMPatchPrior

class jolideco.priors.GMMPatchPrior(gmm=None, stride=None, cycle_spin=True, cycle_spin_subpix=False, generator=None, norm=IdentityImageNorm(), patch_norm=None, jitter=False, marginalize=False, device='cpu')

Bases: Prior

Patch prior based on a Gaussian mixture model (GMM) of image patches.

gmm

Gaussian mixture model.

Type:

GaussianMixtureModel

stride

Stride of the patches. By default it is half of the patch size.

Type:

int or “random”
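The relation between stride, patch size and the overlap attribute can be illustrated with a minimal NumPy sketch. `extract_patches` is a hypothetical helper, not part of jolideco, and the overlap is assumed to be the patch size minus the stride; the actual implementation works on PyTorch tensors and may differ.

```python
import numpy as np


def extract_patches(image, patch_size, stride=None):
    """Extract square patches from a 2D image on a regular grid.

    Hypothetical helper for illustration; if ``stride`` is None it
    defaults to half the patch size, mirroring the documented default.
    """
    if stride is None:
        stride = patch_size // 2
    ny, nx = image.shape
    patches = [
        image[y : y + patch_size, x : x + patch_size]
        for y in range(0, ny - patch_size + 1, stride)
        for x in range(0, nx - patch_size + 1, stride)
    ]
    return np.stack(patches), stride


image = np.arange(64, dtype=float).reshape(8, 8)
patches, stride = extract_patches(image, patch_size=4)
overlap = 4 - stride  # pixels shared by neighbouring patches along one axis
print(patches.shape, stride, overlap)  # (9, 4, 4) 2 2
```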

cycle_spin

Whether to apply cycle spinning, i.e. random shifts of the image before patch extraction.

Type:

bool
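Cycle spinning shifts the image by a random offset before the patch grid is applied, so patch boundaries do not always fall on the same pixels. A minimal NumPy sketch, assuming integer shifts with periodic wrap-around (`np.roll`) and a NumPy generator in place of the documented `torch.Generator`; the helper `cycle_spin` is hypothetical and jolideco's own boundary handling may differ.

```python
import numpy as np


def cycle_spin(image, max_shift, rng):
    """Randomly shift an image with periodic wrap-around.

    Hypothetical helper: returns the shifted image and the shift so the
    operation can be undone later with the negative shift.
    """
    shift = tuple(int(s) for s in rng.integers(0, max_shift, size=2))
    return np.roll(image, shift, axis=(0, 1)), shift


rng = np.random.default_rng(42)
image = np.arange(16, dtype=float).reshape(4, 4)
shifted, shift = cycle_spin(image, max_shift=4, rng=rng)

# Rolling back by the negative shift recovers the original image
restored = np.roll(shifted, tuple(-s for s in shift), axis=(0, 1))
```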

cycle_spin_subpix

Whether to apply cycle spinning with subpixel shifts.

Type:

bool

generator

Random number generator.

Type:

~torch.Generator

norm

Image normalisation applied before the GMM patch prior.

Type:

~jolideco.utils.ImageNorm

patch_norm

Patch normalisation.

Type:

~jolideco.utils.PatchNorm
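A common choice of patch normalisation in GMM patch priors is to subtract the mean of each patch, so that the mixture models patch structure rather than overall brightness. The sketch below assumes that kind of normalisation; `subtract_mean` is hypothetical, and which `PatchNorm` variants jolideco actually provides is not stated here.

```python
import numpy as np


def subtract_mean(patches):
    """Remove the mean of each flattened patch (hypothetical patch norm)."""
    flat = patches.reshape(len(patches), -1)
    return flat - flat.mean(axis=1, keepdims=True)


patches = np.array([[[1.0, 2.0], [3.0, 4.0]], [[10.0, 10.0], [10.0, 10.0]]])
# The first patch (mean 2.5) becomes [-1.5, -0.5, 0.5, 1.5];
# the constant second patch becomes all zeros.
normed = subtract_mean(patches)
```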

jitter

Whether to jitter the patch positions.

Type:

bool

device

PyTorch device.

Type:

~torch.device


Attributes Summary

log_like_weight

Log likelihood weight

overlap

Patch overlap

patch_shape

Patch shape (tuple)

Methods Summary

__call__(flux[, mask])

Evaluate the prior.

from_dict(data)

Create the prior from a dictionary.

prior_image(flux)

Compute a patch image from the eigenimages of the best-fitting patches.

prior_image_average(flux[, n_average])

Compute a patch image averaged over several cycle-spinning realizations.

to_dict()

Convert the prior to a dictionary.

Attributes Documentation

log_like_weight

Log likelihood weight

overlap

Patch overlap

patch_shape

Patch shape (tuple)

Methods Documentation

__call__(flux, mask=None)

Evaluate the prior.

Parameters:

flux (~torch.Tensor) – Reconstructed flux.

Returns:

log_prior – Summed log prior over all overlapping patches.

Return type:

float
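The documented return value is the log prior summed over all overlapping patches. Assuming each normalised patch is flattened and scored under a full-covariance Gaussian mixture, the computation can be sketched in NumPy as follows; `gmm_log_like` is a hypothetical helper and the toy GMM parameters are arbitrary, not taken from a trained jolideco model.

```python
import numpy as np


def gmm_log_like(x, means, covariances, weights):
    """Log-likelihood of each row of x under a full-covariance GMM.

    x: (n_patches, d); means: (k, d); covariances: (k, d, d); weights: (k,)
    """
    d = x.shape[1]
    per_component = []
    for mean, cov, weight in zip(means, covariances, weights):
        diff = x - mean
        _, logdet = np.linalg.slogdet(cov)
        # Squared Mahalanobis distance of every patch to this component
        maha = np.einsum("ni,ij,nj->n", diff, np.linalg.inv(cov), diff)
        log_norm = -0.5 * (d * np.log(2 * np.pi) + logdet)
        per_component.append(np.log(weight) + log_norm - 0.5 * maha)
    stacked = np.stack(per_component, axis=1)  # (n_patches, k)
    # Log-sum-exp over components, stabilised by the per-patch maximum
    peak = stacked.max(axis=1, keepdims=True)
    return (peak + np.log(np.exp(stacked - peak).sum(axis=1, keepdims=True)))[:, 0]


rng = np.random.default_rng(0)
patches = rng.normal(size=(9, 16))  # 9 flattened 4x4 patches
means = np.zeros((2, 16))
covariances = np.stack([np.eye(16), 4 * np.eye(16)])
weights = np.array([0.7, 0.3])

# Summing the per-patch log-likelihoods gives the scalar log prior
log_prior = gmm_log_like(patches, means, covariances, weights).sum()
```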

classmethod from_dict(data)

Create the prior from a dictionary.

prior_image(flux)

Compute a patch image from the eigenimages of the best-fitting patches.

Parameters:

flux (~torch.Tensor) – Reconstructed flux.

Returns:

prior_image – Average prior image.

Return type:

~numpy.ndarray
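How the best-fitting patches are reconstructed from the GMM eigenimages is not spelled out here, but turning overlapping patches back into an image is commonly done by overlap-add with averaging. A minimal NumPy sketch of that final step alone, assuming a regular patch grid; `assemble_patches` is hypothetical and not a jolideco function.

```python
import numpy as np


def assemble_patches(patches, image_shape, stride):
    """Place patches back on their grid and average overlapping pixels.

    Hypothetical helper: assumes patches were extracted row by row on a
    regular grid with the given stride.
    """
    patch_size = patches.shape[-1]
    total = np.zeros(image_shape)
    counts = np.zeros(image_shape)
    ny, nx = image_shape
    positions = [
        (y, x)
        for y in range(0, ny - patch_size + 1, stride)
        for x in range(0, nx - patch_size + 1, stride)
    ]
    for patch, (y, x) in zip(patches, positions):
        total[y : y + patch_size, x : x + patch_size] += patch
        counts[y : y + patch_size, x : x + patch_size] += 1
    # Pixels not covered by any patch stay at zero instead of dividing by zero
    return np.divide(total, counts, out=np.zeros(image_shape), where=counts > 0)


image = np.arange(64, dtype=float).reshape(8, 8)
patches = np.stack(
    [image[y : y + 4, x : x + 4] for y in range(0, 5, 2) for x in range(0, 5, 2)]
)
# Unmodified patches reassemble to the original image
reassembled = assemble_patches(patches, image.shape, stride=2)
```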

prior_image_average(flux, n_average=100)

Compute a patch image averaged over several cycle-spinning realizations.

Parameters:
  • flux (~torch.Tensor) – Reconstructed flux.

  • n_average (int) – Number of images to average over using cycle spinning.

Returns:

prior_image – Average prior image.

Return type:

~numpy.ndarray
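Averaging over cycle spinning can be read as evaluating the prior image for many random shifts of the input, undoing each shift, and taking the mean, which suppresses artifacts tied to one fixed patch grid. A minimal NumPy sketch under that assumption; `prior_image_average_sketch` and `toy_prior_image` are hypothetical stand-ins, not jolideco functions, and `n_average=100` mirrors the documented default.

```python
import numpy as np


def prior_image_average_sketch(image, prior_image_fn, n_average=100, max_shift=4, seed=0):
    """Average a per-shift prior image over random cycle-spin shifts.

    Hypothetical helper: shift, evaluate, shift back, accumulate.
    """
    rng = np.random.default_rng(seed)
    total = np.zeros_like(image, dtype=float)
    for _ in range(n_average):
        shift = tuple(int(s) for s in rng.integers(0, max_shift, size=2))
        shifted = np.roll(image, shift, axis=(0, 1))
        result = prior_image_fn(shifted)
        # Undo the shift so all realizations are aligned before averaging
        total += np.roll(result, tuple(-s for s in shift), axis=(0, 1))
    return total / n_average


def toy_prior_image(img):
    # Stand-in for a patch-based prior image: 2x2 block averages, which
    # depend on where the block grid falls relative to the image content
    ny, nx = img.shape
    blocks = img.reshape(ny // 2, 2, nx // 2, 2).mean(axis=(1, 3))
    return np.repeat(np.repeat(blocks, 2, axis=0), 2, axis=1)


image = np.arange(64, dtype=float).reshape(8, 8)
averaged = prior_image_average_sketch(image, toy_prior_image, n_average=50)
```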

to_dict()

Convert the prior to a dictionary.