import json
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import numpy as np
from astropy.table import Table
from astropy.utils import lazyproperty
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from jolideco.utils.misc import format_class_str
from jolideco.utils.norms import PatchNorm, SubtractMeanPatchNorm
from jolideco.utils.numpy import compute_precision_cholesky, get_pixel_weights
# Names exported by a star-import of this module.
__all__ = ["GaussianMixtureModel", "GMM_REGISTRY"]
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
@dataclass
class GaussianMixtureModelMeta:
    """Meta data attached to a Gaussian mixture model

    Attributes
    ----------
    stride : int, optional
        Stride of the patch. Will be used to compute a correction factor for
        overlapping patches. Overlapping pixels are down-weighted in the
        log-likelihood computation.
    patch_norm : `PatchNorm`
        Patch normalization, by default the "subtract-mean" norm.
    """

    stride: Optional[int] = None
    # NOTE: this default is evaluated once, when the class is created, so all
    # instances relying on the default share the same `PatchNorm` object.
    patch_norm: PatchNorm = PatchNorm.from_dict({"type": "subtract-mean"})

    @classmethod
    def from_table(cls, table):
        """Create meta data from a table

        The patch norm type is taken from the "PNPTYPE" entry of the table
        meta data (falling back to "subtract-mean"), the stride is set to half
        of the patch side length derived from the "means" column.

        Parameters
        ----------
        table : `~astropy.table.Table`
            Table with meta data

        Returns
        -------
        meta : `GaussianMixtureModelMeta`
            Meta data
        """
        norm_config = {"type": table.meta.get("PNPTYPE", "subtract-mean")}
        norm = PatchNorm.from_dict(norm_config)

        # the flattened patches are assumed to be square
        n_pixels = table["means"].shape[-1]
        side_length = int(n_pixels**0.5)
        return cls(stride=side_length // 2, patch_norm=norm)
[docs]
class GaussianMixtureModel(nn.Module):
    """Gaussian mixture model

    All arrays are registered as buffers of the module.

    Attributes
    ----------
    means : `~torch.Tensor`
        Means
    covariances : `~torch.Tensor`
        Covariances
    weights : `~torch.Tensor`
        Weights
    precisions_cholesky : `~torch.Tensor`
        Cholesky factors of the precision matrices
    meta: `GaussianMixtureModelMeta`
        Meta data, a default instance is created if none is given.
    """

    def __init__(self, means, covariances, weights, precisions_cholesky, meta=None):
        super().__init__()

        buffers = {
            "means": means,
            "covariances": covariances,
            "weights": weights,
            "precisions_cholesky": precisions_cholesky,
        }

        for buffer_name, tensor in buffers.items():
            self.register_buffer(buffer_name, tensor)

        self.meta = meta if meta else GaussianMixtureModelMeta()
@lazyproperty
def means_numpy(self):
    """Means as `~numpy.ndarray` (detached, on CPU)"""
    means_cpu = self.means.detach().cpu()
    return means_cpu.numpy()
@lazyproperty
def covariances_numpy(self):
    """Covariances as `~numpy.ndarray` (detached, on CPU)"""
    covariances_cpu = self.covariances.detach().cpu()
    return covariances_cpu.numpy()
@lazyproperty
def weights_numpy(self):
    """Weights as `~numpy.ndarray` (detached, on CPU)"""
    weights_cpu = self.weights.detach().cpu()
    return weights_cpu.numpy()
@lazyproperty
def precisions_cholesky_numpy(self):
    """Precisions Cholesky as `~numpy.ndarray` (detached, on CPU)"""
    prec_chol_cpu = self.precisions_cholesky.detach().cpu()
    return prec_chol_cpu.numpy()
@lazyproperty
def log_weights_numpy(self):
    """Natural logarithm of the weights (~numpy.ndarray)"""
    weights = self.weights_numpy
    return np.log(weights)
@lazyproperty
def log_weights(self):
    """Natural logarithm of the weights (~torch.Tensor)"""
    weights = self.weights
    return torch.log(weights)
[docs]
@classmethod
def from_numpy(cls, means, covariances, weights, meta=None):
    """Create a Gaussian mixture model from numpy arrays

    The Cholesky factors of the precision matrices are computed from the
    covariances as given, afterwards all arrays are converted to float32
    tensors.

    Parameters
    ----------
    means : `~numpy.ndarray`
        Means
    covariances : `~numpy.ndarray`
        Covariances
    weights : `~numpy.ndarray`
        Weights
    meta : `GaussianMixtureModelMeta`
        Meta data

    Returns
    -------
    gmm : `GaussianMixtureModel`
        Gaussian mixture model.
    """

    def as_float_tensor(array):
        """Convert numpy array to float32 tensor"""
        return torch.from_numpy(array.astype(np.float32))

    precisions_cholesky = compute_precision_cholesky(covariances=covariances)

    tensors = {
        "means": as_float_tensor(means),
        "covariances": as_float_tensor(covariances),
        "weights": as_float_tensor(weights),
        "precisions_cholesky": as_float_tensor(precisions_cholesky),
    }
    return cls(meta=meta, **tensors)
@lazyproperty
def patch_shape(self):
    """Patch shape (tuple), derived from the last axis of the means

    The flattened patch is taken to be square.
    """
    n_pixels = self.means.shape[-1]
    side_length = int(n_pixels**0.5)
    return (side_length,) * 2
@lazyproperty
def n_features(self):
    """Number of features, read from the second axis of the covariances"""
    n_components, n_rows, n_cols = self.covariances.shape
    return n_rows
@lazyproperty
def n_components(self):
    """Number of components, read from the first axis of the covariances"""
    n_matrices, n_rows, n_cols = self.covariances.shape
    return n_matrices
@lazyproperty
def eigen_images(self):
    """Eigen images (~numpy.ndarray)

    For every component the eigenvectors of the covariance matrix are
    summed, weighted by their eigenvalues, and reshaped to the patch shape.
    The images of all components are stacked along the first axis.
    """
    from scipy import linalg

    def weighted_eigen_image(covariance):
        """Eigenvalue weighted sum of the eigenvectors as image"""
        eigenvalues, eigenvectors = linalg.eigh(covariance)
        return (eigenvectors @ eigenvalues).reshape(self.patch_shape)

    images = [
        weighted_eigen_image(self.covariances_numpy[component])
        for component in range(self.n_components)
    ]
    return np.stack(images)
[docs]
def plot_eigen_images(self, ncols=20, figsize=None):
    """Plot the eigen images of all components on a grid of axes

    Parameters
    ----------
    ncols : int
        Maximum number of columns of the grid. Fewer columns are used if the
        model has fewer components.
    figsize : tuple of float, optional
        Figure size. By default the width is 12 and the height is chosen
        according to the ratio of rows to columns.
    """
    n_images = self.n_components
    ncols = max(1, min(ncols, n_images))
    # ceiling division, such that a partially filled last row is not dropped
    nrows = max(1, -(-n_images // ncols))

    if figsize is None:
        width = 12
        figsize = (width, width * nrows / ncols)

    # squeeze=False always returns a 2d array of axes, also for a single
    # row, a single column or a single image
    _, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=figsize, squeeze=False)

    for idx, ax in enumerate(axes.flat):
        ax.set_axis_off()

        # axes beyond the number of components stay empty
        if idx >= n_images:
            continue

        ax.imshow(self.eigen_images[idx])
        ax.set_title(f"{idx}")
[docs]
def plot_mean_images(self, ncols=20, figsize=None):
    """Plot the mean images of all components on a grid of axes

    Parameters
    ----------
    ncols : int
        Maximum number of columns of the grid. Fewer columns are used if the
        model has fewer components.
    figsize : tuple of float, optional
        Figure size. By default the width is 12 and the height is chosen
        according to the ratio of rows to columns.
    """
    n_images = self.n_components
    ncols = max(1, min(ncols, n_images))
    # ceiling division, such that a partially filled last row is not dropped
    nrows = max(1, -(-n_images // ncols))

    if figsize is None:
        width = 12
        figsize = (width, width * nrows / ncols)

    # squeeze=False always returns a 2d array of axes, also for a single
    # row, a single column or a single image
    _, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=figsize, squeeze=False)

    for idx, ax in enumerate(axes.flat):
        ax.set_axis_off()

        # axes beyond the number of components stay empty
        if idx >= n_images:
            continue

        ax.imshow(self.means_numpy[idx].reshape(self.patch_shape))
        ax.set_title(f"{idx}")
@lazyproperty
def means_precisions_cholesky(self):
    """Means multiplied with the precision Cholesky factors

    Returns a list with one `~torch.Tensor` per component.
    """
    components = zip(self.means, self.precisions_cholesky)
    return [torch.matmul(mean, prec_chol) for mean, prec_chol in components]
@lazyproperty
def log_det_cholesky_numpy(self):
    """Log-determinant of the precision Cholesky factors (~numpy.ndarray)"""
    log_det_cpu = self.log_det_cholesky.detach().cpu()
    return log_det_cpu.numpy()
@lazyproperty
def log_det_cholesky(self):
    """Log-determinant of the precision Cholesky factors (~torch.Tensor)

    Computed per component as the sum of the logarithms of the diagonal
    elements of the Cholesky factor.
    """
    flattened = self.precisions_cholesky.reshape(self.n_components, -1)
    # every (n_features + 1)-th element of a flattened square matrix
    # is a diagonal element
    step = self.n_features + 1
    diagonals = flattened[:, ::step]
    return torch.sum(torch.log(diagonals), axis=1)
[docs]
def estimate_log_prob_numpy(self, x):
    """Compute the weighted log likelihood per component using numpy

    Parameters
    ----------
    x : `~numpy.ndarray`
        Feature vectors with shape (n_samples, n_features)

    Returns
    -------
    log_prob : `~numpy.ndarray`
        Log likelihood including the log weights of the components, with
        shape (n_samples, n_components)
    """
    n_samples, n_features = x.shape
    quadratic = np.empty((n_samples, self.n_components))

    components = zip(self.means_numpy, self.precisions_cholesky_numpy)

    for idx, (mean, prec_chol) in enumerate(components):
        residual = np.dot(x, prec_chol) - np.dot(mean, prec_chol)
        weighted = np.square(residual) * self.pixel_weights_numpy
        quadratic[:, idx] = np.sum(weighted, axis=1)

    # Since we are using the precision of the Cholesky decomposition,
    # `- 0.5 * log_det_precision` becomes `+ log_det_precision_chol`
    log_norm = n_features * np.log(2 * np.pi)
    log_gauss = -0.5 * (log_norm + quadratic)
    return log_gauss + self.log_det_cholesky_numpy + self.log_weights_numpy
[docs]
def estimate_log_prob(self, x):
    """Compute the weighted log likelihood per component using pytorch

    Parameters
    ----------
    x : `~torch.Tensor`
        Feature vectors with shape (n_samples, n_features)

    Returns
    -------
    log_prob : `~torch.Tensor`
        Log likelihood including the log weights of the components, with
        shape (n_samples, n_components)
    """
    n_samples, n_features = x.shape
    quadratic = torch.empty(
        (n_samples, self.n_components), device=self.means.device
    )

    components = zip(self.means_precisions_cholesky, self.precisions_cholesky)

    for idx, (mean_prec, prec_chol) in enumerate(components):
        residual = torch.matmul(x, prec_chol) - mean_prec
        weighted = torch.square(residual) * self.pixel_weights
        quadratic[:, idx] = torch.sum(weighted, axis=1)

    # Since we are using the precision of the Cholesky decomposition,
    # `- 0.5 * log_det_precision` becomes `+ log_det_precision_chol`
    log_two_pi = torch.log(torch.tensor(2 * np.pi))
    log_gauss = -0.5 * (n_features * log_two_pi + quadratic)
    return log_gauss + self.log_det_cholesky + self.log_weights
@lazyproperty
def pixel_weights(self):
    """Pixel weights as float32 `~torch.Tensor` on the device of the means"""
    weights = self.pixel_weights_numpy.astype(np.float32)
    device = self.means.device
    return torch.from_numpy(weights).to(device)
@lazyproperty
def pixel_weights_numpy(self):
    """Pixel weights as `~numpy.ndarray`, flattened to a single row

    If the meta data defines no stride all pixels get unit weight, otherwise
    the weights are computed by `get_pixel_weights`.
    """
    stride = self.meta.stride

    if stride is not None:
        weights = get_pixel_weights(patch_shape=self.patch_shape, stride=stride)
    else:
        weights = np.ones(self.patch_shape)

    return weights.reshape((1, -1))
[docs]
@classmethod
def from_sklearn_gmm(cls, gmm):
    """Create from sklearn GMM

    Parameters
    ----------
    gmm : object
        Fitted mixture model providing the attributes ``means_``,
        ``covariances_`` and ``weights_``.

    Returns
    -------
    gmm : `GaussianMixtureModel`
        Gaussian mixture model.
    """
    parameters = {
        "means": gmm.means_,
        "covariances": gmm.covariances_,
        "weights": gmm.weights_,
    }
    return cls.from_numpy(**parameters)
[docs]
@classmethod
def from_registry(cls, name, **kwargs):
    """Create GMM from registry

    Parameters
    ----------
    name : str
        Name of the registered GMM.
    **kwargs : dict
        Keyword arguments passed to `read`. Entries of the registry take
        precedence over the given keyword arguments.

    Returns
    -------
    gmm : `GaussianMixtureModel`
        Gaussian mixture model.

    Raises
    ------
    ValueError
        If ``name`` is not part of the registry.
    """
    from jolideco.priors.patches.gmm import GMM_REGISTRY

    available_names = list(GMM_REGISTRY.keys())

    if name in available_names:
        read_kwargs = {**kwargs, **GMM_REGISTRY[name]}
        return cls.read(**read_kwargs)

    raise ValueError(f"Not a supported GMM {name}, choose from {available_names}")
[docs]
@classmethod
def read(cls, filename, format="epll-matlab", **kwargs):
    """Read a Gaussian mixture model from file

    Parameters
    ----------
    filename : str or Path
        Filename, environment variables in the name are expanded.
    format : {"epll-matlab", "epll-matlab-16x16", "table"}
        Format
    **kwargs : dict
        Keyword arguments passed to `GaussianMixtureModel.from_numpy`

    Returns
    -------
    gmm : `GaussianMixtureModel`
        Gaussian mixture model.

    Raises
    ------
    ValueError
        If ``format`` is not one of the supported formats.
    """
    filename = str(Path(os.path.expandvars(filename)))

    if format == "epll-matlab":
        # scipy is only required for the Matlab based formats
        import scipy.io as sio

        gmm_data = sio.loadmat(filename)["GS"]
        means = gmm_data["means"][0][0].T
        covariances = gmm_data["covs"][0][0].T
        weights = gmm_data["mixweights"][0][0][:, 0]
        meta = GaussianMixtureModelMeta(stride=4, patch_norm=SubtractMeanPatchNorm())
    elif format == "epll-matlab-16x16":
        import scipy.io as sio

        gmm_data = sio.loadmat(filename)["GMM"]
        covariances = gmm_data["covs"][0][0].T
        weights = gmm_data["mixweights"][0][0][:, 0]
        # The means are not read from the file for this format, zero means
        # are used instead. The shape (n_components, n_features) follows
        # from the covariances rather than being hard-coded.
        means = np.zeros(covariances.shape[:2])
        meta = GaussianMixtureModelMeta(stride=8, patch_norm=SubtractMeanPatchNorm())
    elif format == "table":
        table = Table.read(filename)
        means = table["means"].data
        weights = table["weights"].data
        covariances = table["covariances"].data
        meta = GaussianMixtureModelMeta.from_table(table=table)
    else:
        raise ValueError(f"Not a supported format {format}")

    return cls.from_numpy(
        means=means, covariances=covariances, weights=weights, meta=meta, **kwargs
    )
@lazyproperty
def covariance_det(self):
    """Determinant of the covariance matrix of the first component"""
    first_covariance = self.covariances_numpy[0]
    return np.linalg.det(first_covariance)
[docs]
def kl_divergence(self, other):
    """Compute KL divergence with respect to another GMM

    Evaluates the closed form of the KL divergence between two multivariate
    Gaussians, KL(self || other). See
    https://mr-easy.github.io/2020-04-16-kl-divergence-between-2-gaussian-distributions/

    Parameters
    ----------
    other : `~GaussianMixtureModel`
        Other GMM

    Returns
    -------
    value : float
        KL divergence

    Raises
    ------
    ValueError
        If one of the GMMs has more than a single component.
    """
    if not (self.n_components == 1 and other.n_components == 1):
        raise ValueError(
            "KL divergence can onlyy be computed for single component GMM"
        )

    k = self.means_numpy.shape[1]
    covariance = self.covariances_numpy[0]
    covariance_other = other.covariances_numpy[0]

    # The Cholesky factor P of the precision is applied as `x @ P` in the
    # log-likelihood, so the precision matrix itself is `P @ P.T`
    prec_chol_other = other.precisions_cholesky_numpy[0]
    precision_other = prec_chol_other @ prec_chol_other.T

    diff = self.means_numpy[0] - other.means_numpy[0]
    term_mean = diff @ precision_other @ diff
    term_trace = np.trace(precision_other @ covariance)

    # The difference of log-determinants avoids the under- / overflow of the
    # ratio of plain determinants for high dimensional covariances
    _, log_det = np.linalg.slogdet(covariance)
    _, log_det_other = np.linalg.slogdet(covariance_other)
    term_log = log_det_other - log_det

    return float(0.5 * (term_log - k + term_mean + term_trace))
[docs]
def is_equal(self, other):
    """Check whether the covariances of both models agree

    Only the covariances are compared: first their shapes, then their values
    using `numpy.allclose`.

    Parameters
    ----------
    other : `~GaussianMixtureModel`
        Other GMM

    Returns
    -------
    is_equal : bool
        Whether the covariances agree.
    """
    # TODO: improve check here?
    same_shape = self.covariances.shape == other.covariances.shape

    if same_shape:
        return np.allclose(self.covariances_numpy, other.covariances_numpy)

    return False
[docs]
def symmetric_kl_divergence(self, other):
    """Symmetric KL divergence

    Sum of the KL divergence of ``other`` with respect to this model and
    of this model with respect to ``other``.

    Parameters
    ----------
    other : `~GaussianMixtureModel`
        Other GMM

    Returns
    -------
    value : float
        Symmetric KL divergence
    """
    backward = other.kl_divergence(other=self)
    forward = self.kl_divergence(other=other)
    return backward + forward
[docs]
def to_dict(self):
    """Convert to dict

    The model is identified by the name of the first registry entry that
    compares equal according to `is_equal`. Every candidate is read via
    `from_registry` until a match is found.

    Returns
    -------
    data : dict
        Dictionary with the key "type", holding the name of the matching
        registry entry.

    Raises
    ------
    ValueError
        If no entry of the registry matches this model.
    """
    from jolideco.priors.patches.gmm import GMM_REGISTRY

    for name in GMM_REGISTRY:
        gmm = GaussianMixtureModel.from_registry(name=name)
        if gmm.is_equal(self):
            return {"type": name}

    # Without a match there is no valid name to store; falling through
    # would store the name of an unrelated GMM.
    raise ValueError(
        "GMM does not match any of the registered GMMs: "
        f"{list(GMM_REGISTRY)}"
    )
[docs]
@classmethod
def from_dict(cls, data):
    """Create from dict

    Parameters
    ----------
    data : dict
        Data dictionary, the entry "type" holds the name of the
        registered GMM.

    Returns
    -------
    gmm : `~GaussianMixtureModel`
        Gaussian mixture model
    """
    registry_name = data["type"]
    return cls.from_registry(name=registry_name)
def __str__(self):
    """String representation as formatted by `format_class_str`"""
    description = format_class_str(instance=self)
    return description
def get_gmm_registry():
    """Get GMM registry

    Reads the index file "jolideco-gmm-library-index.json" from the directory
    given by the environment variable ``JOLIDECO_GMM_LIBRARY``.

    Returns
    -------
    data : dict
        Parsed content of the index file.
    """
    # TODO: automatically download and cache stuff from
    # "https://raw.githubusercontent.com/adonath/jolideco-gmm-library/main/"
    filename = "$JOLIDECO_GMM_LIBRARY/jolideco-gmm-library-index.json"
    path = Path(os.path.expandvars(filename))

    log.debug(f"Reading GMM registry from {path}")

    # JSON is exchanged as UTF-8, do not depend on the locale encoding
    with path.open(encoding="utf-8") as f:
        return json.load(f)
# Registry of the available GMMs. It is read eagerly when this module is
# imported, so the import fails if the index file cannot be opened.
GMM_REGISTRY = get_gmm_registry()