import logging
import numpy as np
import torch
import torch.nn.functional as F
# Public names exported by ``from <module> import *``; helpers such as
# ``transpose``, ``cycle_spin`` and ``_centered`` are intentionally not listed.
__all__ = [
    "convolve_fft_torch",
    "view_as_overlapping_patches_torch",
    "view_as_random_overlapping_patches_torch",
    "view_as_windows_torch",
    "TORCH_DEFAULT_DEVICE",
    "interp1d_torch",
    "grid_weights",
    "get_default_generator",
]
# Fallback torch device name; ``get_default_generator`` creates its generator
# on this device when the requested one cannot be used.
TORCH_DEFAULT_DEVICE = "cpu"
# Module-level logger (used for the device-fallback warning).
log = logging.getLogger(__name__)
def transpose(x):
    """Transpose tensor by reversing the order of all of its dimensions

    An ``(a, b, c)`` tensor becomes ``(c, b, a)``; a 2D tensor is the usual
    matrix transpose.

    Parameters
    ----------
    x : `~torch.Tensor`
        Input tensor.

    Returns
    -------
    transposed : `~torch.Tensor`
        ``x`` with its dimensions in reverse order.
    """
    # Pass the permutation as plain Python ints instead of allocating a
    # tensor just to enumerate dimension indices.
    return x.permute(tuple(reversed(range(x.ndim))))
[docs]
def grid_weights(x, y, x0, y0):
    """Compute 4-pixel weights such that centroid is preserved.

    The weight of the pixel at ``(x, y)`` for a point at ``(x0, y0)`` is the
    product of two 1D linear ("tent") weights: ``1 - |d|`` where the distance
    ``|d| < 1`` along that axis, and ``0`` otherwise.

    Parameters
    ----------
    x, y : `~torch.Tensor`
        Pixel coordinates.
    x0, y0 : `~torch.Tensor` or float
        Coordinates of the point being distributed over the pixels.

    Returns
    -------
    weights : `~torch.Tensor`
        Per-pixel weights, broadcast over the inputs.
    """
    def _tent(offset):
        distance = torch.abs(offset)
        return torch.where(distance < 1, 1 - distance, 0)

    return _tent(x - x0) * _tent(y - y0)
def cycle_spin(image, patch_shape, generator):
    """Cycle spin: roll an image by a random integer offset

    The offset along each of the last two axes is drawn uniformly from
    ``[-(n // 4), n // 4]`` (inclusive), where ``n`` is the matching entry of
    ``patch_shape``. The roll is periodic, so pixels leaving one edge re-enter
    at the opposite edge.

    Parameters
    ----------
    image : `~pytorch.Tensor`
        Image tensor; its last two axes are rolled.
    patch_shape : tuple of int
        Patch shape; ``patch_shape[0]`` bounds the shift along axis -2 and
        ``patch_shape[1]`` the shift along axis -1.
    generator : `~torch.Generator`
        Random number generator

    Returns
    -------
    image, shifts : `~pytorch.Tensor`, tuple of int
        Rolled tensor and the applied ``(shift_axis_-2, shift_axis_-1)``.
    """
    extent_rows, extent_cols = patch_shape

    def draw_shift(extent):
        # One integer in [-(extent // 4), extent // 4], both ends inclusive.
        half_width = extent // 4
        sample = torch.randint(
            -half_width,
            half_width + 1,
            (1,),
            generator=generator,
            device=image.device,
        )
        return int(sample)

    # Draw order matters for reproducibility: axis -2 first, then axis -1.
    shifts = (draw_shift(extent_rows), draw_shift(extent_cols))
    axes = (image.ndim - 2, image.ndim - 1)
    return torch.roll(image, shifts=shifts, dims=axes), shifts
def cycle_spin_subpixel(image, generator):
    """Cycle spin by a random sub-pixel offset

    A random offset ``(x_0, y_0)`` is drawn uniformly from ``[-0.5, 0.5)``
    per axis. A 3x3 kernel of `grid_weights` for that offset is built, and the
    image is filtered with it via `torch.nn.functional.conv2d`, keeping the
    input size (``padding="same"``).

    Parameters
    ----------
    image : `~pytorch.Tensor`
        Image tensor. The kernel has shape ``(1, 1, 3, 3)``, so the image must
        have a single channel (e.g. ``(N, 1, H, W)``) as required by ``conv2d``.
    generator : `~torch.Generator`
        Random number generator

    Returns
    -------
    image: `~pytorch.Tensor`
        Shifted tensor, same shape and dtype as ``image``.
    """
    grid = torch.arange(-1, 2, device=image.device)
    y, x = torch.meshgrid(grid, grid, indexing="ij")
    # Sub-pixel offsets in [-0.5, 0.5).
    x_0 = torch.rand(1, generator=generator, device=image.device) - 0.5
    y_0 = torch.rand(1, generator=generator, device=image.device) - 0.5
    kernel = grid_weights(x, y, x_0, y_0)
    # torch.rand produces the default float dtype; conv2d requires the input
    # and weight dtypes to match, so cast the kernel to the image dtype
    # (no-op for default-dtype images, fixes e.g. float64 images).
    kernel = kernel.reshape((1, 1, 3, 3)).to(dtype=image.dtype)
    return F.conv2d(image, kernel, padding="same")
[docs]
def interp1d_torch(x, xp, fp, **kwargs):
    """Linear interpolation

    Piecewise-linear interpolation of the samples ``(xp, fp)`` at ``x``.
    Points outside ``[xp[0], xp[-1]]`` are linearly extrapolated from the
    first / last interval.

    Parameters
    ----------
    x : `~torch.Tensor`
        Coordinates at which to evaluate the interpolant.
    xp : `~torch.Tensor`
        1D sample coordinates, sorted in increasing order. Must contain at
        least two points.
    fp : `~torch.Tensor`
        Sample values at ``xp`` (same length as ``xp``).
    **kwargs : dict
        Forwarded to `torch.lerp`.

    Returns
    -------
    interp : `~torch.Tensor`
        Interpolated values at ``x``.
    """
    # searchsorted returns i with xp[i - 1] < x <= xp[i]. Clamping i to
    # [1, len(xp) - 1] keeps the bracket (i - 1, i) inside xp: interior points
    # get their own interval, points beyond either end use the nearest one.
    idx = torch.clip(torch.searchsorted(xp, x), 1, len(xp) - 1)
    x0, x1 = xp[idx - 1], xp[idx]
    y0, y1 = fp[idx - 1], fp[idx]
    weights = (x - x0) / (x1 - x0)
    return torch.lerp(y0, y1, weights, **kwargs)
[docs]
def view_as_windows_torch(image, shape, stride=None):
    """View tensor as overlapping rectangular windows

    Parameters
    ----------
    image : `~torch.Tensor`
        Image tensor; windows are taken over its last two axes.
    shape : tuple
        Shape of the patches, ``(n_rows, n_cols)``.
    stride : int
        Stride of the patches (same along both axes). By default it is half
        of ``shape[0]``.

    Returns
    -------
    windows : `~torch.Tensor`
        Tensor of overlapping windows with shape
        ``(..., n_windows_rows, n_windows_cols, shape[0], shape[1])``.
    """
    if stride is None:
        stride = shape[0] // 2

    row_axis, col_axis = image.ndim - 2, image.ndim - 1
    windows = image.unfold(row_axis, shape[0], stride)
    # unfold appends the window dimension at the end, so the column axis is
    # still at index image.ndim - 1; window it with the column size shape[1].
    return windows.unfold(col_axis, shape[1], stride)
[docs]
def view_as_overlapping_patches_torch(image, shape, stride=None):
    """View tensor as overlapping rectangular patches

    The windows produced by `view_as_windows_torch` are flattened so that each
    row of the result holds one patch.

    Parameters
    ----------
    image : `~torch.Tensor`
        Image tensor
    shape : tuple
        Shape of the patches.
    stride : int
        Stride of the patches. By default it is half of ``shape[0]``.

    Returns
    -------
    patches : `~torch.Tensor`
        Tensor of overlapping patches of shape
        ``(n_patches, shape[0] * shape[1])``.
    """
    step = shape[0] // 2 if stride is None else stride
    windows = view_as_windows_torch(image=image, shape=shape, stride=step)
    patch_size = shape[0] * shape[1]
    return windows.reshape(-1, patch_size)
[docs]
def view_as_random_overlapping_patches_torch(image, shape, stride, generator):
    """View tensor as randomly ("jittered") overlapping rectangular patches

    Patch origins start on a regular grid with spacing ``stride``, inset by
    ``overlap = max(shape) - stride`` from the image border. Every grid column
    gets one random integer jitter drawn from ``[-overlap, overlap]`` and every
    grid row another. The result is irregularly overlapping patches.

    Parameters
    ----------
    image : `~torch.Tensor`
        Image tensor. Indexing uses ``[:, :, y, x]`` on the window view, so two
        leading axes are expected (e.g. ``(1, 1, ny, nx)``). The final reshape
        only succeeds when those leading axes have total size 1.
    shape : tuple
        Shape of the patches.
    stride : int
        Spacing of the patch-origin grid. Must not exceed ``max(shape)``
        (otherwise the jitter range is empty).
    generator : `~torch.Generator`
        Random number generator

    Returns
    -------
    patches : `~torch.Tensor`
        Tensor of overlapping patches of shape
        ``(n_patches, shape[0] * shape[1])``.
    """
    overlap = max(shape) - stride
    ny, nx = image.shape[-2:]
    idx = torch.arange(overlap, nx - stride - overlap, stride, device=image.device)
    idy = torch.arange(overlap, ny - stride - overlap, stride, device=image.device)

    jitter_x = torch.randint(
        low=-overlap,
        high=overlap + 1,
        size=(len(idx),),
        generator=generator,
        device=image.device,
    )
    jitter_y = torch.randint(
        low=-overlap,
        high=overlap + 1,
        size=(len(idy),),
        generator=generator,
        device=image.device,
    )

    # All stride-1 windows: (..., n_origins_y, n_origins_x, shape[0], shape[1]).
    patches = view_as_windows_torch(image=image, shape=shape, stride=1)

    # When the grid does not tile the image evenly, a grid point plus a
    # positive jitter can land past the last valid window origin (e.g.
    # nx=11, shape=(4, 4), stride=2 reaches 8 while the maximum is 7).
    # Clamp to the valid origin range so indexing never goes out of bounds.
    idx = torch.clip(idx + jitter_x, 0, patches.shape[-3] - 1)
    idy = torch.clip(idy + jitter_y, 0, patches.shape[-4] - 1)
    idy, idx = torch.meshgrid(idy, idx, indexing="ij")

    patches = patches[:, :, idy, idx]
    return torch.reshape(patches, (idx.numel(), shape[0] * shape[1]))
def _centered(arr, newshape):
    """Return the central ``newshape`` region of ``arr``.

    Start offsets are ``(arr.shape - newshape) / 2`` rounded toward zero. For
    an odd size difference, the extra element is therefore trimmed from the
    end of that axis.
    """
    target = torch.tensor(newshape)
    current = torch.tensor(arr.shape)
    starts = torch.div(current - target, 2, rounding_mode="trunc")
    stops = starts + target
    window = tuple(slice(lo, hi) for lo, hi in zip(starts, stops))
    return arr[window]
[docs]
def convolve_fft_torch(image, kernel):
    """Convolve FFT for torch tensors

    Computes the full linear convolution over the last two axes with real
    FFTs, zero-padded to ``image + kernel - 1`` so there is no circular
    wrap-around. The centre is then cropped back to ``image.shape``.

    Any leading (e.g. batch / channel) axes are combined by broadcasting the
    spectral product. Since the result is cropped to ``image.shape``, the
    kernel's leading axes must broadcast to the image's. 2D, 3D and 4D inputs
    are all handled.

    Parameters
    ----------
    image : `~torch.Tensor`
        Image tensor; convolution acts on its last two axes.
    kernel : `~torch.Tensor`
        Kernel tensor; convolution acts on its last two axes.

    Returns
    -------
    result : `~torch.Tensor`
        Convolution result, same shape as ``image``.
    """
    # Full linear-convolution size along the two spatial axes. Taking the
    # sizes from .shape[-2:] (rather than image[0][0]) works for any rank >= 2.
    shape = [
        n_image + n_kernel - 1
        for n_image, n_kernel in zip(image.shape[-2:], kernel.shape[-2:])
    ]
    image_ft = torch.fft.rfft2(image, s=shape)
    kernel_ft = torch.fft.rfft2(kernel, s=shape)
    result = torch.fft.irfft2(image_ft * kernel_ft, s=shape)
    return _centered(result, image.shape)
[docs]
def get_default_generator(device):
    """Get default torch generator

    Tries to create a generator on ``device``. If torch rejects that device
    with a `RuntimeError`, a warning is logged and a generator on
    ``TORCH_DEFAULT_DEVICE`` is returned instead.

    Parameters
    ----------
    device : str
        Device name

    Returns
    -------
    generator : `~torch.Generator`
        Random number generator
    """
    try:
        return torch.Generator(device=device)
    except RuntimeError:
        log.warning(
            f"Device {device} not available, falling back to {TORCH_DEFAULT_DEVICE}"
        )
    return torch.Generator(device=TORCH_DEFAULT_DEVICE)