Chandra Data Analysis With Jolideco


In this tutorial we demonstrate how to use Jolideco with an example Chandra dataset to perform image deconvolution. We use 24 observations of a small region containing an elongated filament of the supernova remnant E0102. The dataset is available from Zenodo: https://zenodo.org/records/10849740

To prepare Chandra data for analysis with Jolideco you can use the following workflow: adonath/snakemake-workflow-chandra

A similar analysis from the Jolideco paper can be found in the following repository:

jolideco/jolideco-chandra-e0102-zoom-a

Let’s start with the following imports:

import tarfile
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from astropy.io import fits
from astropy.utils.data import download_file
from astropy.visualization import simple_norm

from jolideco.core import MAPDeconvolver, MAPDeconvolverResult
from jolideco.models import (
    FluxComponents,
    NPredCalibration,
    NPredCalibrations,
    SpatialFluxComponent,
)
from jolideco.priors import GaussianMixtureModel, GMMPatchPrior
from jolideco.utils.norms import IdentityImageNorm

random_state = np.random.RandomState(428723)


URL = "https://zenodo.org/records/10849740/files/chandra-e0102-filament-all.tar.gz"

# Run deconvolution or use precomputed result
RUN_DECONVOLUTION = False

device = "cuda:0" if torch.cuda.is_available() else "cpu"

First we download and extract the data files:

path = Path("").absolute() / "chandra-e0102-filament-all"

if not path.exists():
    filename = download_file(URL, cache=True)
    with tarfile.open(filename, "r:gz") as tar:
        tar.extractall(path.parent, filter=lambda member, path: member)

Next we extract the observation IDs from the filenames:

filenames_counts = path.glob("e0102-zoom-a-*-counts.fits")

obs_ids = [int(filename.stem.split("-")[3]) for filename in filenames_counts]
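Note that `Path.glob` yields files in no guaranteed order, so the order of `obs_ids` can differ between machines. As a minimal sketch of the parsing step, using made-up file names that follow the pattern above and a hypothetical helper `parse_obs_id`, one could also sort the IDs to get a reproducible order:

```python
from pathlib import Path

# Hypothetical file names that follow the pattern used in the dataset
example_filenames = [
    Path("e0102-zoom-a-8365-counts.fits"),
    Path("e0102-zoom-a-10654-counts.fits"),
    Path("e0102-zoom-a-6765-counts.fits"),
]


def parse_obs_id(filename):
    """Return the observation ID encoded in a counts file name."""
    # The stem is e.g. "e0102-zoom-a-8365-counts"; index 3 holds the ID
    return int(Path(filename).stem.split("-")[3])


# Sorting makes the order independent of the file system
example_obs_ids = sorted(parse_obs_id(f) for f in example_filenames)
print(example_obs_ids)
```

Sorting is optional here, because the datasets are stored in a dictionary keyed by observation name, but it makes the panel order of the plots below reproducible.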

The example dataset contains 24 observations. We use one of them (observation ID 8365, set below) as the reference for the calibration. By convention the reference observation has an exposure of unity, while all the other observations have an exposure that is relative to the reference observation.

obs_id_ref = 8365
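To illustrate the convention, here is a small sketch of how exposures could be expressed relative to a reference observation. The exposure values are invented, and normalising by the mean of the reference map is an assumption made for this example, not necessarily how the dataset on Zenodo was prepared:

```python
import numpy as np

# Hypothetical exposure maps (arbitrary units) for three observations
exposures = {
    "obs-8365": np.full((4, 4), 20.0),
    "obs-6765": np.full((4, 4), 10.0),
    "obs-10654": np.full((4, 4), 30.0),
}

reference = "obs-8365"

# Assumption: use the mean of the reference map as normalisation constant
norm = exposures[reference].mean()

relative_exposures = {name: exp / norm for name, exp in exposures.items()}

for name, exp in relative_exposures.items():
    print(name, exp.mean())
```

With this convention the reconstructed flux is expressed in units of counts per pixel of the reference observation.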

Next we load the data and create a dictionary with the datasets:

def read_data(filename, dtype=np.float32):
    return fits.getdata(path / filename).astype(dtype)


datasets = {}

for obs_id in obs_ids:
    dataset = {}
    dataset["counts"] = read_data(f"e0102-zoom-a-{obs_id}-counts.fits")

    psf_slice = slice(254 - 64, 254 + 64)
    psf = read_data(f"e0102-zoom-a-{obs_id}-e0102-zoom-a-marx-psf.fits")
    dataset["psf"] = {"filament-flux": psf[psf_slice, psf_slice]}

    dataset["exposure"] = read_data(f"e0102-zoom-a-{obs_id}-exposure.fits")
    datasets[f"obs-{obs_id}"] = dataset
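The PSF slice above cuts a 128 x 128 pixel window around pixel 254 out of the simulated PSF image. The following sketch shows the same indexing on a stand-in array. The 509 x 509 shape of the full image is an assumption chosen so that pixel 254 is the center:

```python
import numpy as np

# Stand-in for a simulated PSF image; the shape is an assumption
psf_full = np.zeros((509, 509), dtype=np.float32)
psf_full[254, 254] = 1.0  # put all the flux into the central pixel

center = 254
half_width = 64
psf_slice = slice(center - half_width, center + half_width)
psf_cutout = psf_full[psf_slice, psf_slice]

print(psf_cutout.shape)

# Fraction of the PSF flux that survives the crop
retained = psf_cutout.sum() / psf_full.sum()
print(retained)
```

For a real PSF it can be useful to check the retained fraction in this way, since a cutout that is too small would clip the PSF wings.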

Let’s plot the counts for each observation:

fig, axes = plt.subplots(4, 6, figsize=(12, 8))

for ax, (name, dataset) in zip(axes.flat, datasets.items()):
    ax.imshow(dataset["counts"], origin="lower")
    ax.set_title(name)

plt.tight_layout()
plt.show()
[Figure: counts image for each of the 24 observations, one panel per observation ID]

As you can see, the number of counts varies between the observations, which reflects their different exposure times.

Now we can plot the PSF images as well:

fig, axes = plt.subplots(4, 6, figsize=(12, 8))

for ax, (name, dataset) in zip(axes.flat, datasets.items()):
    psf = dataset["psf"]["filament-flux"]
    norm = simple_norm(psf, stretch="log")
    ax.imshow(psf, origin="lower", norm=norm)
    ax.set_title(name)

plt.tight_layout()
plt.show()
[Figure: PSF image (log stretch) for each of the 24 observations, one panel per observation ID]

We can see that the PSF also varies between the observations. However, this is something Jolideco can handle, because each observation is passed together with its own PSF.

In addition to the counts, PSF and exposure we will also provide a background. For now we will just use a constant background:

for dataset in datasets.values():
    dataset["background"] = 0.1 * np.ones_like(dataset["counts"])

To run Jolideco we first need to define the Gaussian Mixture Model (GMM) to be used with the patch prior. As we have sufficient data, we can use the GMM learned from the JWST Cas A image, which imposes a rather strong correlation between the pixels.

gmm = GaussianMixtureModel.from_registry("jwst-cas-a-v0.1")
gmm.meta.stride = 4
print(gmm)
GaussianMixtureModel
--------------------

  type                  : jwst-cas-a-v0.1

For illustration we can also plot the mean images of the GMM:

gmm.plot_mean_images(ncols=16, figsize=(12, 8))
plt.show()
[Figure: mean images of the 128 GMM components, indexed 0 to 127]

Now we can define the patch prior:

patch_prior = GMMPatchPrior(
    gmm=gmm,
    cycle_spin=True,
    norm=IdentityImageNorm(),
    device=device,
)

We have specified to use cycle spinning, which is a technique to reduce the impact of the fixed patch grid on the result.
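The idea behind cycle spinning can be sketched in a few lines of NumPy: before the image is split into patches, it is shifted cyclically by a random offset, so the patch borders fall at different positions in every iteration. This is a generic illustration with non-overlapping patches and a hypothetical helper `extract_patches`. It does not reproduce how Jolideco extracts patches internally:

```python
import numpy as np


def extract_patches(image, patch_size, shift=(0, 0)):
    """Split an image into non-overlapping patches after a cyclic shift."""
    shifted = np.roll(image, shift=shift, axis=(0, 1))
    ny, nx = shifted.shape
    blocks = shifted.reshape(ny // patch_size, patch_size, nx // patch_size, patch_size)
    # Reorder the axes to (n_patches, patch_size, patch_size)
    return blocks.transpose(0, 2, 1, 3).reshape(-1, patch_size, patch_size)


rng = np.random.default_rng(0)
image = rng.random((16, 16))
patch_size = 8

# A new random offset per iteration moves the patch borders around
shift = tuple(int(s) for s in rng.integers(0, patch_size, size=2))
patches = extract_patches(image, patch_size, shift=shift)

print(shift, patches.shape)
```

Averaged over many iterations, no pixel is systematically located at a patch border, which reduces grid-like artifacts in the result.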

shape = datasets[f"obs-{obs_id_ref}"]["counts"].shape
flux_init = random_state.normal(loc=3, scale=0.01, size=shape).astype(np.float32)

Now we can define the spatial flux component. More specifically, we define the initial flux, the prior and the upsampling factor. We also specify that the internal flux representation is in log space.

component = SpatialFluxComponent.from_numpy(
    flux=flux_init,
    prior=patch_prior,
    use_log_flux=True,
    upsampling_factor=2,
)

components = FluxComponents()
components["filament-flux"] = component

print(components)
FluxComponents
--------------

  filament-flux         :

    use_log_flux        : True
    upsampling_factor   : 2
    frozen              : False
    prior               :

      type              : gmm-patches
      stride            : 4
      cycle_spin        : True
      cycle_spin_subpix : False
      jitter            : False
      gmm               :

        type            : jwst-cas-a-v0.1

      norm              :

        type            : identity

      patch_norm        :

        type            : std-subtract-mean

      device            : cpu
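The two options `use_log_flux` and `upsampling_factor` can be illustrated with a short NumPy sketch. Optimising the logarithm of the flux keeps the flux positive, and an upsampling factor of 2 means the model image has twice as many pixels per axis as the counts image. How the model is mapped back to the data grid is an assumption here (block summation); Jolideco may handle this step differently:

```python
import numpy as np

rng = np.random.default_rng(0)
upsampling_factor = 2
shape_data = (4, 4)  # shape of the counts image
shape_model = tuple(upsampling_factor * n for n in shape_data)

# Unconstrained parameters in log space
log_flux = rng.normal(loc=0.0, scale=1.0, size=shape_model)
flux = np.exp(log_flux)  # strictly positive by construction


def downsample_sum(image, factor):
    """Sum blocks of factor x factor pixels to return to the data grid."""
    ny, nx = image.shape
    blocks = image.reshape(ny // factor, factor, nx // factor, factor)
    return blocks.sum(axis=(1, 3))


flux_data_grid = downsample_sum(flux, upsampling_factor)
print(flux.shape, flux_data_grid.shape)
```

Block summation conserves the total flux, which is why it is used in this sketch.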

When working with a real dataset it is important to “calibrate” the expected number of counts. This is done by defining a calibration model for each dataset. The model includes three additional parameters that adjust the expected number of counts: the background normalization and the absolute shifts in the x and y directions.

calibrations = NPredCalibrations()

for name in datasets:
    calibration = NPredCalibration(background_norm=1.0, frozen=False)
    calibrations[name] = calibration

We freeze the shift parameters for the reference observation:

calibrations[f"obs-{obs_id_ref}"].shift_xy.requires_grad = False

print(calibrations)
NPredCalibrations
-----------------

  obs-6765              :

    shift_x             : 0.000
    shift_y             : 0.000
    background_norm     : 1.000
    frozen              : False

  obs-10654             :

    shift_x             : 0.000
    shift_y             : 0.000
    background_norm     : 1.000
    frozen              : False

  obs-10656             :

    shift_x             : 0.000
    shift_y             : 0.000
    background_norm     : 1.000
    frozen              : False

  obs-13093             :

    shift_x             : 0.000
    shift_y             : 0.000
    background_norm     : 1.000
    frozen              : False

  obs-6766              :

    shift_x             : 0.000
    shift_y             : 0.000
    background_norm     : 1.000
    frozen              : False

  obs-14258             :

    shift_x             : 0.000
    shift_y             : 0.000
    background_norm     : 1.000
    frozen              : False

  obs-24577             :

    shift_x             : 0.000
    shift_y             : 0.000
    background_norm     : 1.000
    frozen              : False

  obs-9694              :

    shift_x             : 0.000
    shift_y             : 0.000
    background_norm     : 1.000
    frozen              : False

  obs-22805             :

    shift_x             : 0.000
    shift_y             : 0.000
    background_norm     : 1.000
    frozen              : False

  obs-8365              :

    shift_x             : 0.000
    shift_y             : 0.000
    background_norm     : 1.000
    frozen              : False

  obs-19850             :

    shift_x             : 0.000
    shift_y             : 0.000
    background_norm     : 1.000
    frozen              : False

  obs-17688             :

    shift_x             : 0.000
    shift_y             : 0.000
    background_norm     : 1.000
    frozen              : False

  obs-17380             :

    shift_x             : 0.000
    shift_y             : 0.000
    background_norm     : 1.000
    frozen              : False

  obs-25618             :

    shift_x             : 0.000
    shift_y             : 0.000
    background_norm     : 1.000
    frozen              : False

  obs-15467             :

    shift_x             : 0.000
    shift_y             : 0.000
    background_norm     : 1.000
    frozen              : False

  obs-20639             :

    shift_x             : 0.000
    shift_y             : 0.000
    background_norm     : 1.000
    frozen              : False

  obs-26987             :

    shift_x             : 0.000
    shift_y             : 0.000
    background_norm     : 1.000
    frozen              : False

  obs-6759              :

    shift_x             : 0.000
    shift_y             : 0.000
    background_norm     : 1.000
    frozen              : False

  obs-16589             :

    shift_x             : 0.000
    shift_y             : 0.000
    background_norm     : 1.000
    frozen              : False

  obs-18418             :

    shift_x             : 0.000
    shift_y             : 0.000
    background_norm     : 1.000
    frozen              : False

  obs-21804             :

    shift_x             : 0.000
    shift_y             : 0.000
    background_norm     : 1.000
    frozen              : False

  obs-6758              :

    shift_x             : 0.000
    shift_y             : 0.000
    background_norm     : 1.000
    frozen              : False

  obs-10655             :

    shift_x             : 0.000
    shift_y             : 0.000
    background_norm     : 1.000
    frozen              : False

  obs-11957             :

    shift_x             : 0.000
    shift_y             : 0.000
    background_norm     : 1.000
    frozen              : False
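To make the role of the calibration parameters concrete, the following sketch shows one way the predicted counts of a single observation could be assembled from flux, PSF, exposure and background. The function names, the order of operations and the restriction to integer pixel shifts are assumptions for illustration and are not taken from the Jolideco source:

```python
import numpy as np


def fft_convolve(image, kernel):
    """Circular convolution via FFT; the kernel peak is at the array center."""
    kernel_origin = np.fft.ifftshift(kernel)
    product = np.fft.fft2(image) * np.fft.fft2(kernel_origin)
    return np.real(np.fft.ifft2(product))


def predict_counts(flux, psf, exposure, background, background_norm=1.0, shift_xy=(0, 0)):
    """Expected counts for one observation (integer pixel shifts only)."""
    # Axis 0 is y and axis 1 is x
    shifted = np.roll(flux, shift=(shift_xy[1], shift_xy[0]), axis=(0, 1))
    return fft_convolve(shifted, psf) * exposure + background_norm * background


shape = (16, 16)
flux = np.zeros(shape)
flux[8, 8] = 1.0  # single point source

psf = np.zeros(shape)
psf[8, 8] = 1.0  # delta-function PSF, so the convolution is the identity

exposure = np.full(shape, 2.0)
background = np.full(shape, 0.1)

npred = predict_counts(
    flux, psf, exposure, background, background_norm=1.5, shift_xy=(1, 0)
)
print(npred[8, 9], npred[0, 0])
```

In this picture `background_norm` rescales the background level of an observation, while the shift corrects residual pointing offsets relative to the reference observation.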

And we define the deconvolver. We use a learning rate of 0.1, a beta of 1.0 and 250 epochs. The computation runs on the device selected at the top of the tutorial: “cuda:0” if a GPU is available, otherwise the CPU. You can also pass any other valid PyTorch device.

deconvolve = MAPDeconvolver(n_epochs=250, learning_rate=0.1, beta=1.0, device=device)
print(deconvolve)
MAPDeconvolver
--------------

  n_epochs              : 250
  beta                  : 1.000
  learning_rate         : 0.100
  compute_error         : False
  stop_early            : False
  stop_early_n_average  : 10
  display_progress      : True
  device                : cpu
  optimizer             : adam
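As a rough guide to what these settings control, a maximum a posteriori (MAP) estimate minimises a Poisson data term minus a weighted log-prior. The sketch below assumes that `beta` scales the log-prior term. The function names are hypothetical and the code is a generic NumPy illustration, not the Jolideco implementation:

```python
import numpy as np


def poisson_neg_log_likelihood(counts, npred):
    """Negative Poisson log-likelihood, dropping terms that depend only on counts."""
    return np.sum(npred - counts * np.log(npred))


def map_objective(counts, npred, log_prior, beta=1.0):
    """Quantity to minimise: data term minus the weighted log-prior."""
    return poisson_neg_log_likelihood(counts, npred) - beta * log_prior


counts = np.array([[1.0, 2.0], [3.0, 4.0]])
npred_good = counts.copy()  # prediction that matches the data
npred_flat = np.full_like(counts, 2.5)  # flat prediction with the same total

nll_good = poisson_neg_log_likelihood(counts, npred_good)
nll_flat = poisson_neg_log_likelihood(counts, npred_flat)
print(nll_good, nll_flat)

print(map_objective(counts, npred_good, log_prior=-2.0, beta=0.5))
```

For several observations the data terms are summed. The optimizer then takes `n_epochs` passes of gradient steps with the given learning rate on the flux and calibration parameters.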

Now we can run the deconvolution. This takes a while (about 30 minutes on an M1 CPU), so it is disabled by default in this notebook via RUN_DECONVOLUTION = False. With GPU acceleration it should not take more than a few minutes.

filename_result = path / "chandra-e0102-filament-jolideco.fits"

if RUN_DECONVOLUTION:
    result = deconvolve.run(
        components=components,
        calibrations=calibrations,
        datasets=datasets,
    )
    result.write(filename_result, overwrite=True)

It is good practice to always write the result to disk after running the deconvolution. This is especially important for large datasets, as it allows you to continue the analysis at a later time.

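Since the deconvolution is skipped here, we continue with the precomputed result and read it from disk:

result = MAPDeconvolverResult.read(filename_result)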

Now we can plot the result:

fig, axes = plt.subplots(1, 2, figsize=(12, 6))

counts_all = np.sum([_["counts"] for _ in datasets.values()], axis=0)

axes[0].imshow(counts_all, origin="lower")
axes[0].set_title("Counts")

axes[1].imshow(result.flux_total, origin="lower")
axes[1].set_title("Flux Jolideco")

plt.show()
[Figure: summed counts (left panel, "Counts") and deconvolved flux (right panel, "Flux Jolideco")]

The result looks very promising. The filament is clearly visible in the deconvolved image. However, the result is not perfect. There are still some artifacts visible in the deconvolved image, which come from the fact that each observation has an individual shift and there is currently no dedicated boundary handling for this.
