Source code for libertem_holo.base.align

from __future__ import annotations

import math
import typing
from typing import Literal, NamedTuple

try:
    import cupy as cp
except ImportError:
    cp = None
import numpy as np
import numpy.typing as npt
import matplotlib.pyplot as plt
from sparseconverter import NUMPY, for_backend
from scipy.ndimage import gaussian_filter
import logging

from libertem_holo.base.reconstr import get_slice_fft, HoloParams, get_phase, reconstruct_bf
from libertem_holo.base.filters import central_line_filter, disk_aperture

log = logging.getLogger(__name__)


def _upsampled_dft(
    corrspecs: npt.NDArray,
    frequencies: tuple[np.ndarray, np.ndarray],
    upsampled_region_size: int,
    axis_offsets: tuple[float, float],
) -> np.ndarray:
    """
    From https://github.com/LiberTEM/LiberTEM-blobfinder, which is itself
    heavily adapted from skimage.registration._phase_cross_correlation.py
    which is itself based on code by Manuel Guizar released initially under a
    BSD 3-Clause license @ https://www.mathworks.com/matlabcentral/fileexchange/18401

    :meta private:
    """
    im2pi = -1j * 2 * np.pi
    upsampled = corrspecs
    for (ax_freq, ax_offset) in zip(frequencies[::-1], axis_offsets[::-1]):
        kernel = np.linspace(
            -ax_offset,
            (-ax_offset + upsampled_region_size - 1),
            num=int(upsampled_region_size),
        )
        kernel = np.exp(kernel[:, None] * ax_freq * im2pi, dtype=np.complex64)
        # Equivalent to:
        #   data[i, j, k] = kernel[i, :] @ data[j, k].T
        upsampled = np.tensordot(kernel, upsampled, axes=(1, -1))
    return upsampled


def _plot_cross_correlate(*, shifted_corr, pos, plot_title, src, target):
    fig, ax = plt.subplots(3, sharex=True, sharey=True)
    ax[0].imshow(for_backend(shifted_corr, NUMPY))
    ax[0].plot(pos[1], pos[0], 'x', color='red')
    ax[1].imshow(for_backend(src, NUMPY))
    ax[1].plot(pos[1], pos[0], 'x', color='red')
    ax[2].imshow(for_backend(target, NUMPY))
    ax[2].plot(pos[1], pos[0], 'x', color='red')
    fig.suptitle(plot_title)


[docs] def cross_correlate( src, target, plot: bool = False, plot_title: str = "", normalization: Literal['phase'] | None = 'phase', upsample_factor=1, xp=np, ) -> tuple[np.ndarray, np.ndarray]: """Rigid image registration by cross-correlation. Supports optional phase normalization. Based on the `phase_cross_correlation` function of scikit-image, but with added GPU support via cupy, and some debugging facilities built in. Parameters ========== src The static image, either a numpy or cupy array (if cupy, you should set `xp=cp`, too) target The moving image, either a numpy or cupy array (if cupy, you should set `xp=cp`, too) normalization 'phase' or None, same as for `phase_cross_correlation` upsample_factor Subpixel scaling factor, same as for `phase_cross_correlation`. Note that we test with a factor of up to 10; with larger values the results may not improve. xp numpy or cupy """ src = xp.asarray(src) target = xp.asarray(target) src_freq = xp.fft.fftn(src) target_freq = xp.fft.fftn(target) image_product = src_freq * target_freq.conj() if normalization == 'phase': eps = np.finfo(image_product.real.dtype).eps image_product /= np.maximum(np.abs(image_product), 100 * eps) elif normalization is not None: raise ValueError(f"unknown normalization {normalization}") cross_correlation = xp.fft.ifftn(image_product) shifted_corr = xp.fft.fftshift(np.abs(cross_correlation)) maxima = xp.unravel_index( xp.argmax(shifted_corr), shifted_corr.shape ) float_dtype = image_product.real.dtype midpoint = xp.array([xp.fix(axis_size / 2) for axis_size in src_freq.shape]) shift = xp.stack(maxima).astype(float_dtype, copy=False) shift -= midpoint # estimate sublixel shifts using the upsampled DFT method: if upsample_factor > 1: frequencies = ( xp.fft.fftfreq(src.shape[0], upsample_factor), xp.fft.fftfreq(src.shape[1], upsample_factor), ) # Initial shift estimate in upsampled grid upsample_factor = xp.array(upsample_factor, dtype=float_dtype) shift = xp.round(shift * upsample_factor) / upsample_factor upsampled_region_size = xp.ceil(upsample_factor * 1.5) # Center of output array at dftshift + 1 dftshift = xp.fix(upsampled_region_size / 2.0) # Matrix multiply DFT around the current shift estimate sample_region_offset = dftshift - np.round(shift * upsample_factor) cross_correlation = _upsampled_dft( image_product.conj(), frequencies, upsampled_region_size, sample_region_offset, ).conj() # Locate maximum and map back to original pixel grid maxima = xp.unravel_index( xp.argmax(xp.abs(cross_correlation)), cross_correlation.shape ) maxima = xp.stack(maxima).astype(float_dtype, copy=False) maxima -= dftshift shift += maxima / upsample_factor if xp is np: shift = tuple(float(x) for x in shift) else: shift = tuple(float(for_backend(x, NUMPY)) for x in shift) # for "backwards compat", return correlation maxima and not shift pos = xp.array(shift) + midpoint if plot: _plot_cross_correlate( shifted_corr=shifted_corr, pos=pos, plot_title=plot_title, src=src, target=target, ) return pos, shifted_corr
def gradient(image: np.ndarray, scale=1): scale = [scale] * image.ndim gradients = np.gradient(np.asarray(image), *scale) return np.stack(gradients, axis=-1)
[docs] def get_grad_angle(image, scale=3): """From an image, get the angle of the gradient.""" grad = gradient(image, scale=scale) return np.arctan2(grad[..., 0], grad[..., 1])
[docs] def get_grad_xy(image, scale=3): """From an image, get the angle of the gradient.""" grad = gradient(image, scale=scale) return (grad[..., 0], grad[..., 1])
[docs] def is_left( a: np.ndarray, b: np.ndarray, c: np.ndarray, ) -> np.ndarray: """ Points a and b are points on a line and result is an array of True or False if c is left or right of line resp. """ return (b[1] - a[1])*(c[0] - a[0]) - (b[0] - a[0])*(c[1] - a[1]) > 0
[docs] class RegResult(NamedTuple): maximum: tuple[float, float] shift: tuple[float, float] corrmap: np.ndarray
class Correlator: def prepare_input( self, img: np.ndarray, ) -> typing.Any: raise NotImplementedError() def correlate( self, ref_image: typing.Any, moving_image: typing.Any, plot: bool = False, ) -> RegResult: raise NotImplementedError()
[docs] class ImageCorrelator(Correlator): """ Cross correlation based image registration with some pre-filtering. Assumes real-space images as input. """ def __init__( self, upsample_factor: int = 1, normalization: Literal['phase'] | None = 'phase', hanning: bool = True, binning: int = 1, xp: typing.Any = np, ) -> None: self._xp = xp self._upsample_factor = upsample_factor self._normalization = normalization self._hanning = hanning self._binning = binning self._zoom_factor = 1 / binning def prepare_input( self, img: np.ndarray, ) -> typing.Any: xp = self._xp if xp is np: import scipy.ndimage as ni else: import cupyx.scipy.ndimage as ni # apply hanning filter: if self._hanning: img = img * xp.outer(xp.hanning(img.shape[0]), xp.hanning(img.shape[1])) # apply binning: if self._zoom_factor != 1: img = ni.zoom(img, self._zoom_factor) return img def correlate( self, ref_image: typing.Any, moving_image: typing.Any, plot: bool = False, ) -> RegResult: pos, corrmap = cross_correlate( ref_image, moving_image, xp=self._xp, plot=plot, upsample_factor=self._upsample_factor, normalization=self._normalization, ) pos_rel = ( pos[0] - (moving_image.shape[0]) // 2, pos[1] - (moving_image.shape[1]) // 2, ) if self._binning != 1: pos_rel = ( pos_rel[0] / self._zoom_factor, pos_rel[1] / self._zoom_factor, ) return RegResult(maximum=pos, shift=pos_rel, corrmap=corrmap)
[docs] class BiprismDeletionCorrelator(Correlator): """ Cross correlation on low magnification while removing biprism. """ def __init__( self, mask: np.ndarray, upsample_factor: int = 1, normalization: Literal['phase'] | None = 'phase', xp: typing.Any = np, ) -> None: self._mask = mask self._xp = xp self._upsample_factor = upsample_factor self._normalization = normalization def prepare_input( self, img: np.ndarray, ) -> typing.Any: overview = np.zeros_like(img) overview[:] = img overview[self._mask] = img.mean() return overview def correlate( self, ref_image: typing.Any, moving_image: typing.Any, plot: bool = False, ) -> RegResult: pos, corrmap = cross_correlate( ref_image, moving_image, xp=self._xp, plot=plot, upsample_factor=self._upsample_factor, normalization=self._normalization, ) pos_rel = ( pos[0] - (moving_image.shape[0]) // 2, pos[1] - (moving_image.shape[1]) // 2, ) return RegResult(maximum=pos, shift=pos_rel, corrmap=corrmap)
[docs] @classmethod def plot_get_coords(cls, img, coords_out): """ At low magnification, plot image of area with biprism visible. Click on edges of biprism to create the coordinates to mask it out, for cross correlation. First, click one edge of biprism from left side, then same edge, right side. Then, click other edge of biprism from left side, then right side. -----1-------2----- -----3-------4----- """ fig, ax = plt.subplots(1) ax.imshow(img) def onclick(event): plt.plot(event.xdata, event.ydata, 'ro') coords_out.append((event.ydata, event.xdata)) if len(coords_out) == 4: fig.canvas.mpl_disconnect(cid) cid = fig.canvas.mpl_connect('button_press_event', onclick)
[docs] @classmethod def get_masked(cls, img, coords): """Uses coordinates from plot_get_coords to create a mask of biprism.""" yx = np.mgrid[0:img.shape[0], 0:img.shape[1]] mask = is_left(coords[0], coords[1], yx) & ~ is_left(coords[2], coords[3], yx) return mask
[docs] class BrightFieldCorrelator(Correlator): """ Cross correlation on bright field of hologram. """ def __init__( self, holoparams: HoloParams, upsample_factor: int = 1, normalization: Literal['phase'] | None = 'phase', xp: typing.Any = np, ) -> None: self._holoparams = holoparams self._xp = xp self._normalization = normalization self._upsample_factor = upsample_factor def prepare_input( self, img: np.ndarray, ) -> typing.Any: holoparams = self._holoparams line_filter = central_line_filter( sb_position=holoparams.sb_position_int, out_shape=holoparams.out_shape, orig_shape=img.shape, length_ratio=0.95, width=20 ) aperture = disk_aperture(out_shape=holoparams.out_shape, radius=holoparams.sb_size//3) slice_fft = get_slice_fft(out_shape=holoparams.out_shape, sig_shape=img.shape) line_filter = line_filter[slice_fft] aperture[line_filter] = 0 aperture = np.fft.fftshift(gaussian_filter(aperture, sigma=6)) holo_bf = np.abs( reconstruct_bf( frame=img, aperture=aperture, slice_fft=slice_fft, xp=self._xp ) ) holo_bf = np.gradient(holo_bf)[0] return holo_bf def correlate( self, ref_image: typing.Any, moving_image: typing.Any, plot: bool = False, ) -> RegResult: pos, corrmap = cross_correlate( ref_image, moving_image, xp=self._xp, plot=plot, upsample_factor=self._upsample_factor, normalization=self._normalization, ) pos_rel = ( pos[0] - (moving_image.shape[0]) // 2, pos[1] - (moving_image.shape[1]) // 2, ) return RegResult(maximum=pos, shift=pos_rel, corrmap=corrmap)
[docs] class PhaseImageCorrelator(Correlator): """ Cross correlation on reconstructed phase image. """ def __init__( self, holoparams: HoloParams, upsample_factor: int = 1, normalization: Literal['phase'] | None = 'phase', xp: typing.Any = np, ) -> None: self._holoparams = holoparams self._xp = xp self._normalization = normalization self._upsample_factor = upsample_factor def prepare_input( self, img: np.ndarray, ) -> typing.Any: holoparams = self._holoparams phase = get_phase(img, holoparams, xp=self._xp) return phase def correlate( self, ref_image: typing.Any, moving_image: typing.Any, plot: bool = False, ) -> RegResult: pos, corrmap = cross_correlate( ref_image, moving_image, xp=self._xp, plot=plot, normalization=self._normalization, upsample_factor=self._upsample_factor, ) pos_rel = ( pos[0] - (moving_image.shape[0]) // 2, pos[1] - (moving_image.shape[1]) // 2, ) return RegResult(maximum=pos, shift=pos_rel, corrmap=corrmap)
[docs] class GradAngleCorrelator(Correlator): """ Cross correlation on gradient angle of phase image. """ def __init__( self, holoparams: HoloParams, upsample_factor: int = 1, normalization: Literal['phase'] | None = 'phase', xp: typing.Any = np, ) -> None: self._holoparams = holoparams self._xp = xp self._normalization = normalization self._upsample_factor = upsample_factor def prepare_input( self, img: np.ndarray, ) -> np.ndarray: holoparams = self._holoparams grad_angle = get_grad_angle(get_phase(img, holoparams, xp=self._xp)) return grad_angle def correlate( self, ref_image: typing.Any, moving_image: typing.Any, plot: bool = False, ) -> RegResult: pos, corrmap = cross_correlate( ref_image, moving_image, xp=self._xp, plot=plot, upsample_factor=self._upsample_factor, normalization=self._normalization, ) pos_rel = ( pos[0] - (moving_image.shape[0]) // 2, pos[1] - (moving_image.shape[1]) // 2, ) return RegResult(maximum=pos, shift=pos_rel, corrmap=corrmap)
[docs] class GradXYCorrelator(Correlator): """ Cross correlation on gradient x and Y, correlation maps summed. """ def __init__( self, holoparams: HoloParams, xp: typing.Any = np, ) -> None: self._holoparams = holoparams self._xp = xp def prepare_input( self, img: np.ndarray, ) -> typing.Any: holoparams = self._holoparams (grad_x, grad_y) = get_grad_xy( get_phase(img, holoparams, xp=self._xp), scale=3, ) # because `gradient` interpolates at the edge, we get a nice # vertical artifact that the cross correlation latches onto, # so we need to slice the edges away. take care to slice # everything the same, so the shapes match, and we need to # slice enough such that the interpolated region is removed # completely (I think this relates to the `scale` argument # above): return (grad_x[4:-5, 4:-5], grad_y[4:-5, 4:-5]) def correlate( self, ref_image: typing.Any, moving_image: typing.Any, plot: bool = False, ) -> RegResult: xp = self._xp ref_image_x, ref_image_y = ref_image moving_image_x, moving_image_y = moving_image pos_x, corrmap_x = cross_correlate( ref_image_x, moving_image_x, xp=xp, plot=plot, ) pos_y, corrmap_y = cross_correlate( ref_image_y, moving_image_y, xp=xp, plot=plot, ) corrmap = corrmap_x + corrmap_y pos = xp.unravel_index(xp.argmax(corrmap), corrmap.shape) if xp is np: pos = tuple(float(x) for x in pos) else: pos = tuple(float(for_backend(x, NUMPY)) for x in pos) pos_rel = ( pos[0] - (moving_image_y.shape[0]) // 2, pos[1] - (moving_image_y.shape[1]) // 2, ) return RegResult(maximum=pos, shift=pos_rel, corrmap=corrmap)
[docs] class NoopCorrelator(Correlator): """Do nothing, successfully.""" def prepare_input( self, img: np.ndarray, ) -> typing.Any: return img def correlate( self, ref_image: typing.Any, moving_image: typing.Any, plot: bool = False, ) -> RegResult: return RegResult(maximum=0.0, shift=0.0, corrmap=np.zeros_like(ref_image))
[docs] def align_stack( stack: np.ndarray, wave_stack: np.ndarray, static: np.ndarray | None, correlator: Correlator | None = None, xp=np, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Align stacks of N holograms. Parameters ========== stack Stack of images with shape (N, h, w) which is used for the alignment. It should be of dtype float32 or float64, so for holograms, you may want to align on the abs(wave). You can also slice the original data, if you want to align to a region of interest. wave_stack The (complex) result of the reconstruction. This should have shape (N, h', w'), meaning it should have the same number of complex images as the stack argument, but the 2d items in the stack can have a different shape. static A reference image to align against. If this is not given, the first image of the stack will be taken. correlator A :class:`Correlator` instance. By default, an :class:`ImageCorrelator` is used, with reasonable default parameters. If you need control over these, you can construct your own :class:`Correlator` and pass it in. A requirement is that it has to work on the value given as the stack parameter. Returns ======= aligned_stack The aligned stack. Same shape and dtype as wave_stack shifts The shifts that were applied to the stack. Shape (N, 2) reference The pre-processed reference image, as used in the registration. Useful for debugging, for example if the correlator filters its input. corrs The correlation maps, as returned from the correlator. Useful for debugging. """ wave_stack = xp.asarray(wave_stack) stack = xp.asarray(stack) aligned_stack = xp.zeros_like(wave_stack) if correlator is None: correlator = ImageCorrelator( upsample_factor=10, normalization='phase', hanning=True, binning=1, xp=xp, ) if static is None: reference = stack[0] else: reference = static reference = correlator.prepare_input(reference) corrs = xp.zeros((stack.shape[0],) + reference.shape, dtype=np.float32) shifts = xp.zeros((wave_stack.shape[0], 2), dtype="float32") if xp is np: import scipy.ndimage as ni else: import cupyx.scipy.ndimage as ni for i, (reg_frame, wave_frame) in enumerate(zip(stack, wave_stack)): wave_frame = xp.asarray(wave_frame) pre_reg_frame = correlator.prepare_input(reg_frame) reg_result = correlator.correlate( reference, pre_reg_frame, plot=False, ) corrs[i] = reg_result.corrmap shifted = xp.fft.ifft2(ni.fourier_shift( xp.fft.fftn(wave_frame), xp.asarray(reg_result.shift), )) # support for non-complex data: explicitly discard imaginary part if not np.iscomplexobj(wave_stack): shifted = shifted.real aligned_stack[i] = shifted shifts[i] = xp.stack(reg_result.shift) return aligned_stack, shifts, reference, corrs
[docs] def stack_alignment_quality(wave_stack: np.ndarray, shifts): """Stack quality ,judged by standard deviation on the stacking axis. This should be mostly noise, if not, there may be issues from the alignment. """ std_image = np.std(np.abs(wave_stack), axis=0) offset = math.ceil(shifts.max()) + 1 return std_image[offset:-offset, offset:-offset]