# Source code for libertem_holo.udf.reconstr

"""UDFs for hologram reconstruction.

Based on the functions available in :code:`libertem_holo.base.reconstr`.
"""
from __future__ import annotations

from typing import Any

import numpy as np
from libertem.udf import UDF

from libertem_holo.base.mask import disk_aperture
from libertem_holo.base.reconstr import get_slice_fft, reconstruct_frame


class HoloReconstructUDF(UDF):
    """Reconstruct off-axis electron holograms using a Fourier-based method.

    Running :meth:`~libertem.api.Context.run_udf` on an instance of this class
    will reconstruct a complex electron wave. Use the :code:`wave` key to
    access the raw data in the result.

    See :ref:`holography app` for a detailed application example.

    .. versionadded:: 0.3.0

    Examples
    --------
    >>> shape = tuple(dataset.shape.sig)
    >>> sb_position = [2, 3]
    >>> sb_size = 4.4
    >>> aperture = disk_aperture(out_shape=shape, radius=sb_size)
    >>> holo_udf = HoloReconstructUDF(
    ...     out_shape=shape,
    ...     sb_position=sb_position,
    ...     aperture=aperture,
    ... )
    >>> wave = ctx.run_udf(dataset=dataset, udf=holo_udf)['wave'].data
    """

    def __init__(
        self,
        *,
        out_shape: tuple[int, int],
        sb_position: tuple[float, float],
        aperture: np.ndarray,
        precision: bool = True,
    ) -> None:
        """Off-axis electron holography reconstruction.

        The aperture is built outside of the UDF, which enables flexibility
        for both adding additional filtering (like a line filter), and for
        applying smoothing.

        Parameters
        ----------
        out_shape
            Shape of the returned complex wave image. Note that the result
            should fit into the main memory.
            See :ref:`holography app` for more details.
        sb_position
            Coordinates of sideband position with respect to non-shifted FFT
            of a hologram.
        precision
            Defines precision of the reconstruction, True for complex128 for
            the resulting complex wave, otherwise results will be complex64.
        aperture
            The aperture used to mask out the sideband. Should have a shape
            equal to the `out_shape` parameter, and should be fft-shifted
            (i.e. assume that the side band is shifted to the corners of the
            image).

        Raises
        ------
        ValueError
            If the shape of `aperture` does not equal `out_shape`.
        """
        # Validate up front, in the process that builds the UDF: otherwise a
        # mismatched aperture would only be noticed later, inside
        # ``process_frame`` while the UDF is already running.
        # ``.shape`` is read directly so that cupy arrays also work; plain
        # array-likes (e.g. nested lists) fall back to ``np.shape``.
        if hasattr(aperture, "shape"):
            aperture_shape = tuple(aperture.shape)
        else:
            aperture_shape = tuple(np.shape(aperture))
        if aperture_shape != tuple(out_shape):
            msg = (
                f"aperture shape {aperture_shape} does not match "
                f"out_shape {tuple(out_shape)}"
            )
            raise ValueError(msg)
        super().__init__(
            out_shape=out_shape,
            sb_position=sb_position,
            precision=precision,
            aperture=aperture,
        )

    def get_result_buffers(self) -> dict[str, Any]:
        """Declare the ``wave`` result buffer.

        ``wave`` is a ``nav``-kind buffer holding one ``out_shape`` image per
        navigation position; its dtype is complex128 when ``precision`` is
        true and complex64 otherwise.
        """
        extra_shape = self.params.out_shape
        dtype = np.complex128 if self.params.precision else np.complex64
        return {
            "wave": self.buffer(kind="nav", dtype=dtype, extra_shape=extra_shape),
        }

    def get_task_data(self) -> dict[str, Any]:
        """Prepare data shared by all frames of a task.

        Returns the aperture converted with the active backend's
        ``self.xp.array``, and the slice produced by ``get_slice_fft`` from
        ``out_shape`` and the signal shape of the current partition.
        """
        slice_fft = get_slice_fft(
            self.params.out_shape,
            self.meta.partition_shape.sig,
        )
        return {
            "aperture": self.xp.array(self.params.aperture),
            "slice": slice_fft,
        }

    def process_frame(self, frame: np.ndarray) -> None:
        """Reconstruct the complex wave of one hologram frame.

        Delegates to ``reconstruct_frame`` with the sideband position,
        precision, the per-task aperture and slice, and the active array
        backend, then writes the result into the ``wave`` buffer.
        """
        wav = reconstruct_frame(
            frame,
            sb_pos=self.params.sb_position,
            aperture=self.task_data.aperture,
            slice_fft=self.task_data.slice,
            precision=self.params.precision,
            xp=self.xp,
        )
        # ``forbuf`` adapts the backend array so it can be assigned into the
        # result buffer.
        self.results.wave[:] = self.forbuf(wav, self.results.wave)

    def get_backends(self) -> tuple[str, ...]:
        """Return the array backends this UDF supports: NumPy and CuPy."""
        return ("numpy", "cupy")

    @classmethod
    def with_default_aperture(
        cls,
        *,
        out_shape: tuple[int, int],
        sb_size: float,
        sb_position: tuple[float, float],
        precision: bool = True,
    ) -> HoloReconstructUDF:
        """Instantiate with a default disk-shaped aperture.

        Parameters
        ----------
        out_shape
            Shape of the returned complex wave image; also used as the shape
            of the generated aperture.
        sb_size
            Radius of the disk aperture, passed to ``disk_aperture``.
        sb_position
            Coordinates of sideband position with respect to non-shifted FFT
            of a hologram.
        precision
            True for a complex128 result, otherwise complex64.

        Examples
        --------
        >>> udf = HoloReconstructUDF.with_default_aperture(
        ...     out_shape=(128, 128),
        ...     sb_size=7.6,
        ...     sb_position=(32, 32),
        ... )
        """
        aperture = disk_aperture(out_shape=out_shape, radius=sb_size)
        return cls(
            out_shape=out_shape,
            sb_position=sb_position,
            aperture=aperture,
            precision=precision,
        )