"""Utility functions for working with holography data."""
# Functions freq_array, aperture_function, estimate_sideband_position and
# estimate_sideband_size are adapted from HyperSpy
# and are subject to the following copyright:
#
# Copyright 2007-2016 The HyperSpy developers
#
# HyperSpy is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# HyperSpy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with HyperSpy. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright 2019 The LiberTEM developers
#
# LiberTEM is distributed under the terms of the GNU General
# Public License as published by the Free Software Foundation,
# version 3 of the License.
# see: https://github.com/LiberTEM/LiberTEM
from __future__ import annotations
from functools import lru_cache
from typing import Any, Literal
import typing
import logging
try:
import cupy as cp
except ImportError:
cp = None
import numpy as np
from numpy.fft import fft2
from skimage.draw import polygon
from scipy.ndimage import gaussian_filter
from sparseconverter import NUMPY, for_backend
# Logger for this module, named after the module itself.
log = logging.getLogger(name=__name__)

# Type used for ``xp`` parameters: either the ``numpy`` or the ``cupy``
# module object (conceptually Union[Module("numpy"), Module("cupy")]).
XPType = Any
def freq_array(
    shape: tuple[int, int],
    sampling: tuple[float, float] = (1.0, 1.0),
) -> np.ndarray:
    """Return the magnitude of the 2D FFT frequency grid.

    Parameters
    ----------
    shape : (int, int)
        Shape (rows, columns) of the array; only the first two entries
        are used.
    sampling : (float, float), optional, (Default: (1., 1.))
        Sample spacing along (rows, columns), passed as ``d`` to
        :func:`numpy.fft.fftfreq`.

    Returns
    -------
    np.ndarray
        Array of shape ``(shape[0], shape[1])`` holding ``hypot(f_x, f_y)``
        in unshifted FFT order.
    """
    row_freqs = np.fft.fftfreq(shape[0], sampling[0])
    col_freqs = np.fft.fftfreq(shape[1], sampling[1])
    # meshgrid(x, y) yields grids indexed as [row, column].
    fx_grid, fy_grid = np.meshgrid(col_freqs, row_freqs)
    return np.hypot(fx_grid, fy_grid)
def get_slice_fft(
    out_shape: tuple[int, int],
    sig_shape: tuple[int, int],
) -> tuple[slice, slice]:
    """Get a centered slice in fourier space to achieve the given output shape.

    Parameters
    ----------
    out_shape : (int, int)
        Desired shape (rows, columns) of the cropped region.
    sig_shape : (int, int)
        Shape (rows, columns) of the array that will be sliced.

    Returns
    -------
    (slice, slice)
        Row and column slices selecting an ``out_shape`` region centered
        in an array of ``sig_shape``.

    Raises
    ------
    ValueError
        If ``out_shape`` is larger than ``sig_shape`` in any dimension.
        Previously this produced a negative start index. Python's negative
        indexing then silently selected a much smaller, wrong region.
    """
    sy, sx = sig_shape
    oy, ox = out_shape
    if oy > sy or ox > sx:
        raise ValueError(
            f"out_shape {tuple(out_shape)} must not exceed "
            f"sig_shape {tuple(sig_shape)}"
        )
    # With out <= sig both bounds are non-negative, so int() floors and
    # y_max - y_min == oy (likewise for x) for any parity combination.
    y_min = int(sy / 2 - oy / 2)
    y_max = int(sy / 2 + oy / 2)
    x_min = int(sx / 2 - ox / 2)
    x_max = int(sx / 2 + ox / 2)
    return (slice(y_min, y_max), slice(x_min, x_max))
def estimate_sideband_position(
    holo_data: np.ndarray,
    holo_sampling: tuple[float, float],
    central_band_mask_radius: float | None = None,
    sb: Literal["lower", "upper"] = "lower",
    xp: XPType = np,
) -> tuple[float, float]:
    """Find the position of the sideband and return its position.

    The central band is masked out with an inverted disk aperture. The
    position of the largest FFT magnitude is then searched in either the
    lower (first) or the upper (second) half of the unshifted FFT along
    axis 0.

    Parameters
    ----------
    holo_data: ndarray
        The data of the hologram (numpy or cupy array, or array-like; it is
        converted with ``xp.asarray``).
    holo_sampling: tuple
        The sampling in both image directions, passed to ``freq_array``.
    central_band_mask_radius: float, optional
        The aperture radius used to mask out the centerband. Defaults to
        1/20 of the largest value returned by ``freq_array``.
    sb : str, optional
        Chooses which sideband is taken. 'lower' or 'upper'
    xp
        Pass in either the numpy or cupy module to select CPU or GPU processing

    Returns
    -------
    Tuple of the sideband position (y, x), referred to the unshifted FFT.

    Raises
    ------
    ValueError
        If ``sb`` is neither 'lower' nor 'upper'.
    """
    from .filters import disk_aperture

    if sb not in ("lower", "upper"):
        raise ValueError(f"sb must be 'lower' or 'upper', got {sb!r}")

    # Keep all arrays on the same backend as the aperture built with `xp`.
    holo_data = xp.asarray(holo_data)
    f_freq = freq_array(holo_data.shape, holo_sampling)
    if central_band_mask_radius is None:
        central_band_mask_radius = 1 / 20.0 * np.max(f_freq)
    aperture = disk_aperture(
        holo_data.shape,
        central_band_mask_radius,
        xp=xp,
    )
    # Inverted disk: suppresses the central band, keeps everything else.
    aperture_central_band = np.subtract(1.0, aperture)
    # FFT normalized by the number of pixels.
    fft_holo = xp.fft.fft2(holo_data) / np.prod(holo_data.shape)
    fft_filtered = fft_holo * aperture_central_band

    half = fft_filtered.shape[0] // 2
    if sb == "lower":
        fft_sb, row_offset = fft_filtered[:half, :], 0
    else:
        fft_sb, row_offset = fft_filtered[half:, :], half
    # int() pulls a cupy scalar to the host, so both backends take the
    # same path. The original "lower" branch called .get() on a tuple.
    flat_index = int(xp.abs(fft_sb).argmax())
    y, x = np.unravel_index(flat_index, fft_sb.shape)
    return (float(y + row_offset), float(x))
def estimate_sideband_size(
    sb_position: tuple[float, float],
    holo_shape: tuple[int, int],
    sb_size_ratio: float = 0.5,
    xp: XPType = np,
) -> float:
    """Estimate the size of sideband filter.

    The result is the shortest distance from the sideband to any of the four
    corners of the unshifted FFT, multiplied by ``sb_size_ratio``. The four
    corners are all images of the zero frequency.

    Parameters
    ----------
    sb_position : tuple
        The sideband position (y, x), referred to the non-shifted FFT.
    holo_shape : array_like
        Shape (rows, columns) of the holographic data array.
    sb_size_ratio : float, optional
        Size of sideband as a fraction of the distance to central band
    xp
        Accepted for API consistency; the computation always uses NumPy.

    Returns
    -------
    sb_size : float
        Size of sideband filter
    """
    rows, cols = holo_shape
    position = np.asarray(sb_position)
    corners = np.asarray([[0, 0], [0, cols], [rows, 0], [rows, cols]])
    scaled_offsets = (position - corners) * sb_size_ratio
    distances = np.linalg.norm(scaled_offsets, axis=1)
    return float(distances.min())
class HoloParams(typing.NamedTuple):
    """HoloParams contains all parameters necessary for reconstruction.

    Instances are typically created with :meth:`HoloParams.from_hologram`,
    which estimates the sideband position and size from a single hologram.
    """

    # Sideband filter radius in pixels: the scalar returned by
    # ``estimate_sideband_size`` (previously mis-annotated as a pair).
    sb_size: float
    # Sideband position (y, x), referred to the unshifted FFT.
    sb_position: tuple[float, float]
    # Fourier-space filter; can be a numpy or a cupy array.
    aperture: np.ndarray
    # Shape of the hologram the parameters were derived from.
    orig_shape: tuple[int, int]
    # Shape of the reconstructed (output) image.
    out_shape: tuple[int, int]
    # By how much the phase image is scaled down
    # (``from_hologram`` stores out_shape[0] / orig_shape[0]).
    scale_factor: float
    # Array module (numpy or cupy) used for processing.
    xp: XPType

    @property
    def sb_position_int(self) -> tuple[int, int]:
        """Sideband position converted from float to int (truncated)."""
        return tuple(int(c) for c in self.sb_position)

    @classmethod
    def from_hologram(
        cls,
        hologram: np.ndarray,
        *,
        central_band_mask_radius: float | None,
        out_shape: tuple[int, int] | None = None,
        line_filter_length: float = 0.9,
        line_filter_width: float | None = 20,
        xp: XPType = np,
    ) -> HoloParams:
        """Determine reconstruction parameters from a hologram.

        Automatically estimates sideband position and size, and returns
        the main parameters needed for holography reconstruction.

        Parameters
        ----------
        hologram
            A single hologram, can be either a numpy or a cupy array
        central_band_mask_radius
            When estimating the sideband position, use a mask of this size
            to remove the central band (``None`` selects the default of
            ``estimate_sideband_position``)
        out_shape
            The reconstruction shape, should be larger than the sideband size.
            Defaults to a square of side ``2 * int(sb_size) + 16``.
        line_filter_length
            Length ratio of the line filter; as a fraction of the distance
            between the central band and the sideband
        line_filter_width
            Width of the line filter, in pixels. Passing in `None` will
            disable the line filter completely
        xp
            Pass in either the numpy or cupy module to select CPU or GPU
            processing
        """
        from .filters import butterworth_line, butterworth_disk

        hologram = xp.asarray(hologram)
        sb_position = estimate_sideband_position(
            holo_data=hologram,
            holo_sampling=(1, 1),
            sb='upper',
            central_band_mask_radius=central_band_mask_radius,
            xp=xp,
        )
        sb_size = estimate_sideband_size(sb_position, hologram.shape, xp=xp)
        if out_shape is None:
            # A square crop slightly larger than the sideband diameter.
            out_side = 2 * int(sb_size) + 16
            out_shape = (out_side, out_side)
        fft_slice = get_slice_fft(out_shape, hologram.shape)

        # Disk aperture, cropped to the output region.
        aperture = butterworth_disk(hologram.shape, radius=sb_size, order=20)
        cropped = aperture[fft_slice]
        if line_filter_width is not None:
            sb_position_int = tuple(int(c) for c in sb_position)
            lf = butterworth_line(
                shape=hologram.shape,
                width=line_filter_width,
                sb_position=fft_shift_coords(
                    sb_position_int, shape=hologram.shape
                ),
                length_ratio=line_filter_length,
                order=2,
            )
            cropped = cropped * lf[fft_slice]
        aperture = xp.asarray(np.fft.fftshift(cropped))

        return cls(
            sb_size=sb_size,
            sb_position=sb_position,
            aperture=aperture,
            out_shape=out_shape,
            orig_shape=hologram.shape,
            scale_factor=out_shape[0] / hologram.shape[0],
            xp=xp,
        )

    def filter_aperture_gaussian(self, sigma: float) -> HoloParams:
        """Return a copy whose aperture is smoothed with a Gaussian filter.

        The aperture is moved to NumPy for ``scipy.ndimage.gaussian_filter``
        and converted back with ``self.xp.asarray``. All other fields are
        kept unchanged.
        """
        aperture = for_backend(self.aperture, NUMPY)
        new_aperture = self.xp.asarray(gaussian_filter(aperture, sigma=sigma))
        # _replace keeps every other field and preserves the concrete type.
        return self._replace(aperture=new_aperture)
@lru_cache
def shifted_coords_for_shape(shape):
    """Return the fftshift'ed index grid for a 2D ``shape`` (cached).

    The result has shape ``(shape[0], shape[1], 2)``. Entry ``[i, j]`` holds
    the (y, x) index, in the unshifted array, of the element that
    ``np.fft.fftshift`` places at ``[i, j]``.

    ``shape`` must be hashable (e.g. a tuple) because of ``lru_cache``. The
    returned array is shared between all callers with the same ``shape``.
    It is therefore marked read-only, so an accidental in-place write cannot
    corrupt later lookups.
    """
    grid = np.moveaxis(np.mgrid[0:shape[0], 0:shape[1]], 0, -1)
    coords = np.fft.fftshift(grid, axes=(0, 1))
    coords.flags.writeable = False
    return coords
def fft_shift_coords(pos, shape):
    """Map a (y, x) position in unshifted FFT layout to fftshift'ed layout.

    Returns the index that the element at ``pos`` ends up at after
    ``np.fft.fftshift``: ``((y + H // 2) % H, (x + W // 2) % W)``. The zero
    frequency ``(0, 0)`` therefore maps to ``(H // 2, W // 2)`` for both even
    and odd sizes. This matches the ``shape // 2`` center used elsewhere in
    this module.

    The previous lookup via the fftshift'ed index grid returned the inverse
    mapping, which differs by one for odd dimensions. It also rejected float
    positions such as those returned by ``estimate_sideband_position``.
    Positions are truncated with ``int()`` and wrapped modulo the shape.
    """
    y, x = int(pos[0]), int(pos[1])
    h, w = int(shape[0]), int(shape[1])
    return ((y + h // 2) % h, (x + w // 2) % w)
def other_sb(sb_position, shape):
    """
    Given the sb_position (as from the estimate function in fft coordinates),
    and the shape of the hologram, calculate the position of the other sideband
    position (also in fft coordinates)

    The two sidebands are point-symmetric about the zero frequency. In
    unshifted FFT coordinates the partner of ``(y, x)`` is therefore
    ``((-y) % H, (-x) % W)``. Computing this directly avoids several problems
    of the previous shift/reflect/shift approach:

    - float positions (as returned by ``estimate_sideband_position``) raised
      ``IndexError``;
    - odd dimensions gave a result that was off by one;
    - a sideband on the Nyquist row or column raised ``IndexError``.

    For even shapes and in-range integer positions the result is unchanged.
    Positions are truncated with ``int()``.
    """
    y, x = int(sb_position[0]), int(sb_position[1])
    h, w = int(shape[0]), int(shape[1])
    return ((-y) % h, (-x) % w)
def line_filter_coords(length_ratio, sb_position_shifted, width, orig_shape):
    """Return the 4 corners (as (y, x) rows) of the line-filter rectangle.

    The rectangle lies along the line from the center
    ``(orig_shape[0] // 2, orig_shape[1] // 2)`` to ``sb_position_shifted``.
    It ends at the sideband position, extends ``length_ratio`` times the
    center-to-sideband distance back toward the center, and is ``width``
    wide perpendicular to that line.
    """
    cy, cx = orig_shape[0] // 2, orig_shape[1] // 2
    dy = sb_position_shifted[0] - cy
    dx = sb_position_shifted[1] - cx
    sb_dist = np.linalg.norm(np.array((dy, dx)))
    length = sb_dist * length_ratio
    length_rest = sb_dist * (1 - length_ratio)

    # Unit rectangle with rows (y, x): y spans [-0.5, 0.5] (centered on the
    # axis), x spans [0, 1] (the direction of the filter's length).
    rect = np.array(
        [[-0.5, 0.0], [-0.5, 1.0], [0.5, 1.0], [0.5, 0.0]],
        dtype=np.float32,
    )
    # Stretch: the y extent becomes the filter width, x becomes its length.
    stretch = np.array([[width, 0], [0, length]])
    rect = rect @ stretch
    # Push the rectangle outward so its far end reaches the sideband.
    rect += np.array([0, length_rest])

    # Rotate the x axis onto the center -> sideband direction.
    theta = np.arctan2(dy, dx)
    cos_t, sin_t = np.cos(-theta), np.sin(-theta)
    rotation = np.array([
        [cos_t, sin_t],
        [-sin_t, cos_t],
    ])
    rect = rect @ rotation
    rect += np.array((cy, cx))
    return rect
def draw_lf_rect(dest, orig_shape, sb_position_shifted, length_ratio, width):
    """Rasterize the line-filter rectangle into ``dest`` in place.

    The rectangle from ``line_filter_coords`` starts at
    ``sb_position_shifted``. It extends ``length_ratio`` times the distance
    toward the center of ``orig_shape``. Pixels inside the polygon, clipped
    to ``dest.shape``, are set to ``True``. Nothing is returned.
    """
    corners = line_filter_coords(
        length_ratio=length_ratio,
        sb_position_shifted=sb_position_shifted,
        width=width,
        orig_shape=orig_shape,
    )
    rows, cols = polygon(corners[:, 0], corners[:, 1], shape=dest.shape)
    dest[rows, cols] = True