# Source code for libertem_holo.base.filters

"""Useful image filtering helpers."""
import numpy as np
import sparse
from scipy import ndimage
from scipy.optimize import least_squares
from scipy.signal import fftconvolve
from skimage.filters import window
from skimage.restoration import unwrap_phase


def highpass(img: np.ndarray, sigma: float = 2) -> np.ndarray:
    """Return a highpass-filtered copy of `img`.

    The highpass is obtained by subtracting a gaussian-blurred (lowpass)
    version of the image from the image itself.

    Parameters
    ----------
    img : np.ndarray
        Input image.
    sigma : float
        Standard deviation of the gaussian lowpass kernel, passed to
        ``scipy.ndimage.gaussian_filter``.
    """
    lowpass = ndimage.gaussian_filter(img, sigma=sigma)
    return img - lowpass
def exclusion_mask(img: np.ndarray, sigma: float = 6) -> np.ndarray:
    """Generate outlier mask.

    Return a boolean mask with `True` entries for pixels whose value deviates
    from the image mean by more than `sigma` standard deviations, in either
    direction (unusually bright *and* unusually dark pixels are flagged).

    Parameters
    ----------
    img : np.ndarray
        Input image.
    sigma : float
        Threshold, in units of the image's standard deviation.

    Returns
    -------
    np.ndarray
        Boolean array of the same shape as `img`.
    """
    # Measure the distance from the mean rather than comparing |img| against
    # ``mean + sigma * std``: the latter misses low outliers entirely and, for
    # images with a negative mean, yields a negative threshold that flags
    # every pixel.
    deviation = np.abs(img - img.mean())
    return deviation > sigma * img.std()
def clipped(img: np.ndarray, sigma: float = 6):
    """Mask out outliers.

    Return the pixels of `img` that are *not* flagged by ``exclusion_mask``
    for the given `sigma`, as a flattened array. Useful for plotting:

    >>> plt.imshow(img, vmax=np.max(clipped(img)))  # doctest: +SKIP

    Parameters
    ----------
    img : np.ndarray
        Input image.
    sigma : float
        Outlier threshold forwarded to ``exclusion_mask``.
    """
    keep = np.logical_not(exclusion_mask(img, sigma=sigma))
    return img[keep]
def phase_ramp_finding(img, order=1):
    """Find a phase ramp in `img`.

    The ramp is estimated as the mean finite-difference gradient
    (``np.gradient``) of `img` along each of its two array axes.

    Parameters
    ----------
    img : 2d nd array
        Real-valued image, e.g. an (unwrapped) phase image. Complex input is
        rejected.
    order : int
        Order of the ramp; only 1 (linear, the default) is supported.

    Returns
    -------
    ramp : tuple of float
        ``(ramp_y, ramp_x)``, where ``ramp_y`` is the mean gradient along
        axis 1 and ``ramp_x`` the mean gradient along axis 0.
        NOTE(review): these labels differ from the usual (row=y, col=x) image
        convention, but ``phase_ramp_removal`` applies ``ramp[1]`` along
        axis 0 and ``ramp[0]`` along axis 1, so the ordering must not be
        swapped in isolation.

    Raises
    ------
    ValueError
        If `img` is complex, or if `order` is not 1.
    """
    # TODO least-square-fitting, polynomial order
    # TODO How to find phase ramps in complex images
    if img.dtype.kind == 'c':
        raise ValueError(f"cannot handle input of type {img.dtype}")
    if order != 1:
        raise ValueError(f"can only handle `order=1` for now, not order={order}")
    slope_axis0 = np.mean(np.gradient(img, axis=0))
    slope_axis1 = np.mean(np.gradient(img, axis=1))
    return (slope_axis1, slope_axis0)
def phase_ramp_removal(size, order=1, ramp=None):
    """Build a phase-ramp image to subtract from a phase image.

    Returns the linear ramp ``ramp_x * row + ramp_y * col`` evaluated on a
    grid of shape `size`; subtract it from the image to remove the ramp.

    Parameters
    ----------
    size : 2d tuple
        Shape ``(rows, cols)`` of the image the ramp is built for. The
        returned array has exactly this shape.
    order : int
        Order of the ramp; only 1 (linear, the default) is supported.
    ramp : 2d tuple
        ``(ramp_y, ramp_x)`` slopes. ``ramp_x`` is applied along axis 0 and
        ``ramp_y`` along axis 1, which is the ordering returned by
        ``phase_ramp_finding``. Required.

    Returns
    -------
    2d nd array of shape `size` containing the ramp.

    Raises
    ------
    ValueError
        If `ramp` is None, or if `order` is not 1.
    """
    # TODO How to find phase ramps in complex images
    if ramp is None:
        # A shape tuple carries no image data, so the ramp cannot be
        # estimated here; the caller has to find it from the image first.
        raise ValueError(
            "`ramp` must be given; estimate it from the image with "
            "phase_ramp_finding(img) first"
        )
    if order != 1:
        # To be expanded.
        raise ValueError(f"cannot handle order={order}")
    ramp_y, ramp_x = ramp
    # Broadcast a column of row indices against a row of column indices so
    # the result has shape ``size`` (a default ``xy`` meshgrid would give the
    # transposed shape for non-square sizes).
    rows = np.arange(size[0])[:, np.newaxis]
    cols = np.arange(size[1])[np.newaxis, :]
    return ramp_x * rows + ramp_y * cols
def phase_unwrap(image):
    """Unwrap a complex or wrapped-phase image.

    Complex input is first converted to its phase with ``np.angle``; real
    input is treated as an already-wrapped phase. The phase is then unwrapped
    with ``skimage.restoration.unwrap_phase``.

    Parameters
    ----------
    image : 2d nd array
        Complex image or wrapped phase image.

    Returns
    -------
    2d nd array of the unwrapped phase image.
    """
    wrapped = np.angle(image) if image.dtype.kind == 'c' else image
    return unwrap_phase(wrapped)
def remove_dead_pixels(img, sigma_lowpass=2.0, sigma_exclusion=6.0):
    """Remove dead pixels.

    Outlier pixels are detected on a highpass-filtered copy of `img` and then
    repaired with ``libertem.corrections.detector.correct`` (how the repaired
    values are computed is defined by libertem — presumably from neighbouring
    pixels; see its documentation).

    Parameters
    ----------
    img : np.array
        Input array.
    sigma_lowpass : float
        How much of the low frequencies should be removed before finding bad
        pixels (gaussian sigma passed to ``highpass``).
    sigma_exclusion : float
        Pixels deviating more than this value from the mean will be removed
        (threshold passed to ``exclusion_mask``).

    Returns
    -------
    np.array
        The corrected image, with singleton dimensions squeezed out.
    """
    from libertem.corrections.detector import correct

    filtered = highpass(img, sigma=sigma_lowpass)
    bad_pixels = exclusion_mask(filtered, sigma=sigma_exclusion)
    bad_coords = sparse.COO(bad_pixels).coords
    # ``correct`` works on a stack of frames, so wrap the image as a stack of one.
    frame_stack = img.reshape((1, *img.shape))
    corrected = correct(
        buffer=frame_stack,
        excluded_pixels=bad_coords,
        sig_shape=tuple(img.shape),
    )
    return corrected.squeeze()
def window_filter(input_array, window_type, window_shape):
    """Apply window-based filter.

    Return a filtered array with the same shape as the input array. The
    input is fft-shifted, convolved with a window (``mode="same"``), shifted
    back with the inverse shift, and normalised by its maximum.

    Parameters
    ----------
    input_array: array
        Input array.
    window_type : string, float or tuple
        The type of window to be created. Any window type supported by
        ``scipy.signal.get_window`` is allowed here. See notes below for a
        current list, or the SciPy documentation for the version of SciPy
        on your machine.
    window_shape : tuple of int or int
        The shape of the window. If an integer (Python or numpy) is provided,
        a square 2D window is generated.

    Notes
    -----
    This function is based on ``scipy.signal.get_window`` and thus can access
    all of the window types available to that function
    (e.g., ``"hann"``, ``"boxcar"``). Note that certain window types require
    parameters that have to be supplied with the window name as a tuple
    (e.g., ``("tukey", 0.8)``). If only a float is supplied, it is interpreted
    as the beta parameter of the Kaiser window.

    https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html

    It is recommended to check that, after fft shift, the input array has a
    value of 0 at the border.
    """
    # Accept numpy integers too; otherwise they would be forwarded as a scalar
    # shape and produce a 1D window that cannot be convolved with 2D input.
    if isinstance(window_shape, (int, np.integer)):
        window_shape = (window_shape, window_shape)
    win = window(window_type, window_shape)
    # Presumably `input_array` is in unshifted FFT layout: centre it, smooth
    # it, then undo the centring. ``ifftshift`` is the exact inverse of
    # ``fftshift``; applying ``fftshift`` twice is only the identity for
    # even-sized axes and shifts odd-sized axes by one pixel.
    centred = np.fft.fftshift(input_array)
    smoothed = fftconvolve(centred, win, mode="same")
    array_filtered = np.fft.ifftshift(smoothed)
    return array_filtered / np.max(array_filtered)
def ramp_compensation(image):
    """Compensate a linear ramp (wedge) in a 2D image.

    Fits the plane ``c + a * row + b * col`` to `image` by linear least
    squares and returns `image` with that plane subtracted, removing both the
    constant offset and the linear ramp.

    Parameters
    ----------
    image : 2D-Array
        Input array.

    Returns
    -------
    2D-Array of the same shape as `image` with the best-fit plane removed.
    """
    # The model is linear in its three coefficients, so it is solved directly
    # with ``lstsq`` instead of an iterative optimizer: no starting guess is
    # needed, and images with a length-1 axis (where ``np.gradient`` fails)
    # are handled too.
    rows, cols = np.indices(image.shape, dtype=float)
    design = np.column_stack([np.ones(image.size), rows.ravel(), cols.ravel()])
    coeffs, *_ = np.linalg.lstsq(design, np.ravel(image), rcond=None)
    plane = (design @ coeffs).reshape(image.shape)
    return image - plane