Source code for libertem.udf.masks

from typing_extensions import Literal
from libertem.common.math import prod
import numpy as np

from libertem.common.udf import UDFMethod
from libertem.udf import UDF, UDFMeta
from libertem.common.buffers import AuxBufferWrapper
from libertem.common.container import MaskContainer
from libertem.common.numba import rmatmul


class ApplyMasksEngine:
    def __init__(self, masks: MaskContainer, meta: UDFMeta, use_torch: bool):
        self.masks = masks
        self.meta = meta

        try:
            import torch
        except ImportError:
            torch = None

        torch_incompatible = (
            torch is None
            or self.meta.input_dtype.kind != 'f'
            or self.meta.input_dtype != self.masks.dtype
            or self.meta.device_class != 'cpu'
            or self.meta.array_backend != UDF.BACKEND_NUMPY
            or self.masks.use_sparse
        )

        # With transpose=True the masks are fetched in shape (sig_size, num_masks),
        # so they can be applied as flat_tile @ masks
        self.needs_transpose = True
        if use_torch and (not torch_incompatible):
            # Prefer the torch back-end when data and masks are compatible with it
            # (dense float CPU data on the NumPy backend, masks of matching dtype)
            self.process_flat = self._process_flat_torch
        elif (
            self.meta.array_backend == UDF.BACKEND_NUMPY
            and isinstance(self.masks.use_sparse, str)
            and self.masks.use_sparse.startswith('scipy.sparse')
        ):
            # Dense tiles with scipy.sparse masks: due to
            # https://github.com/scipy/scipy/issues/13211 use the
            # numba-based rmatmul instead of the @ operator
            self.process_flat = self._process_flat_spsp
        elif (
            self.meta.array_backend in (
                UDF.BACKEND_SCIPY_COO,
                UDF.BACKEND_SCIPY_CSR,
                UDF.BACKEND_SCIPY_CSC
            ) and isinstance(self.masks.use_sparse, str)
            and self.masks.use_sparse.startswith('sparse.pydata')
        ):
            # scipy.sparse tiles with sparse.pydata masks: keep the masks in
            # (num_masks, sig_size) layout and put them on the left of the
            # product (see _process_flat_sparsepyd)
            self.process_flat = self._process_flat_sparsepyd
            self.needs_transpose = False
        else:
            self.process_flat = self._process_flat_standard

    def _get_masks(self):
        return self.masks.get_for_sig_slice(
            self.meta.sig_slice, transpose=self.needs_transpose
        )

    def _process_flat_torch(self, flat_tile, masks):
        import torch
        # CuPy back-end disables torch in get_task_data
        # FIXME use GPU torch with CuPy array?
        return torch.mm(
            torch.from_numpy(flat_tile),
            torch.from_numpy(masks),
        ).numpy()

    def _process_flat_spsp(self, flat_tile, masks):
        return rmatmul(flat_tile, masks)

    def _process_flat_sparsepyd(self, flat_tile, masks):
        # Make sure the sparse.pydata mask comes first
        # to choose the right multiplication method
        return (masks @ flat_tile.T).T

    def _process_flat_standard(self, flat_tile, masks):
        return flat_tile @ masks

    def process_tile(self, tile):
        flat_shape = (tile.shape[0], prod(tile.shape[1:]))
        # Only reshape when necessary, e.g. scipy.sparse tiles are already 2D;
        # older versions of scipy.sparse don't support reshape
        flat_data = tile.reshape(flat_shape) if tile.shape != flat_shape else tile
        return self.process_flat(flat_data, self._get_masks())

    def process_frame_shifted(self, frame, shifts: np.ndarray):
        sig_shape = self.meta.dataset_shape.sig
        masks = self._get_masks()
        num_masks = len(self.masks)
        shifted_slice = self.meta.sig_slice.shift_by(shifts)
        inverse_shifted_slice = self.meta.sig_slice.shift_by(-1 * shifts)
        left = self.meta.sig_slice.intersection_with(shifted_slice)
        right = self.meta.sig_slice.intersection_with(inverse_shifted_slice)
        if left.is_null():
            # Zero overlap after shifts, shortcut return
            return np.zeros((num_masks,), dtype=np.float32)
        mask_slice = right.get()
        if self.needs_transpose:
            # expects masks in shape (sig_size, num_masks)
            mask_slice = mask_slice + (slice(None), )
            masks = masks.reshape((*sig_shape, num_masks))
            final_mask_shape = (-1, num_masks)
        else:
            # expects masks in shape (num_masks, sig_size)
            # NOTE: unexpectedly, sig_shape and mask_slice don't need to be reversed here?
            masks = masks.reshape((num_masks, *sig_shape))
            mask_slice = (slice(None), ) + mask_slice
            final_mask_shape = (num_masks, -1)

        sliced_masks = masks[mask_slice].reshape(final_mask_shape)
        # shift slicing requires sig_shape frames
        # sparse array backend can provide flat frame
        frame = frame.reshape(sig_shape)
        try:
            data = left.get(frame)
        except TypeError as e:
            # frame is in a form which doesn't support slicing
            # the only recognized case is scipy.sparse.coo
            if not hasattr(frame, 'getformat'):
                raise e  # pragma: no cover
            assert frame.getformat() == 'coo'
            frame = frame.tocsr()
            data = left.get(frame)
        flat_data = data.reshape((1, -1))
        return self.process_flat(flat_data, sliced_masks).reshape((num_masks,))
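

# Illustrative sketch, not part of libertem: for dense numpy input and
# non-negative shifts, the overlap logic implemented above with Slice
# intersections reduces to pairing frame[dy:, dx:] with
# mask[:sy - dy, :sx - dx]. The helper name below is hypothetical and only
# documents the shift semantics (positive shifts move the mask 'down'/'right',
# non-overlapping elements are discarded, and no overlap gives 0).
def _example_shifted_mask_sum(frame, mask, dy, dx):
    sy, sx = frame.shape
    if dy >= sy or dx >= sx:
        # No overlap left after the shift
        return 0.0
    return float(np.sum(frame[dy:, dx:] * mask[:sy - dy, :sx - dx]))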


class ApplyMasksUDF(UDF):
    '''
    Apply masks to signals/frames in the dataset.

    This is not limited to integrating over regions with a binary mask: the
    integration can also be weighted by using float or complex valued masks.

    The result will be returned in a single nav-shaped buffer called intensity,
    with one entry per mask. Its shape will be :code:`(*nav_shape, len(masks))`.

    Parameters
    ----------

    mask_factories : Union[Callable[[], array_like], Iterable[Callable[[], array_like]]]
        Function or list of functions that take no arguments and create masks. The
        returned masks can be numpy arrays, scipy.sparse or sparse https://sparse.pydata.org/
        matrices. The mask factories should not reference large objects because they can
        create significant overheads when they are pickled and unpickled. Each factory
        function should, when called, return a numpy array with the same shape as frames
        in the dataset (so dataset.shape.sig).
    use_torch : bool, optional
        Use pytorch back-end if available. Default True
    use_sparse : Union[None, False, True, 'scipy.sparse', 'scipy.sparse.csc', \
            'sparse.pydata'], optional
        Which sparse back-end to use.

        * None (default): Where possible use sparse matrix multiplication if all factory \
          functions return a sparse mask, otherwise convert all masks to dense matrices \
          and use dense matrix multiplication
        * True: Convert all masks to sparse matrices and use default sparse back-end.
        * False: Convert all masks to dense matrices.
        * 'scipy.sparse': Use scipy.sparse.csr_matrix (default sparse)
        * 'scipy.sparse.csc': Use scipy.sparse.csc_matrix
        * 'sparse.pydata': Use sparse.pydata COO matrix
    mask_count : int, optional
        Specify the number of masks if a single factory function is used so that the
        number of masks can be determined without calling the factory function.
    mask_dtype : numpy.dtype, optional
        Specify the dtype of the masks so that mask dtype can be determined without
        calling the mask factory functions. This can be used to override the mask
        dtype in the result dtype determination. As an example, setting this to
        np.float32 means that masks of type float64 will not switch the calculation
        and result dtype to float64 or complex128.
    preferred_dtype : numpy.dtype, optional
        Let :meth:`get_preferred_input_dtype` return the specified type instead of the
        default `float32`. This can perform the calculation with integer types if both
        input data and mask data are compatible with this.
    backends : Iterable containing strings "numpy" and/or "cupy", or None
        Control which back-ends are used. Default is numpy and cupy
    shifts : Union[Tuple[int, int], AuxBufferWrapper], optional
        (Y/X)-shifts to apply to all masks before multiplying with each frame. Can be
        either a length-2 array-like for a constant :code:`(Y, X)` shift, or an
        :class:`~libertem.common.buffers.AuxBufferWrapper` of
        :code:`(kind='nav', extra_shape=(2,), dtype=int)` defining a per-frame shift
        to apply.

        A positive y-shift moves the mask 'down' relative to the frame, while a
        positive x-shift moves the mask 'right' relative to the frame. Elements of
        the mask and frame which do not overlap after the shift are discarded. A
        shift resulting in no overlap at all will return a sum of :code:`0.` for
        that frame.

        .. note::
            Float shift values are cast to integers internally; round values before
            passing the shifts argument to better control the exact shifts used.

        .. note::
            The :code:`shifts` parameter requires frame-by-frame processing to
            function, and so adds a performance penalty compared to unshifted mask
            application. If applying a constant shift it may be worthwhile to
            manually create a new, pre-shifted mask rather than relying on this
            feature.

            Shifting is also currently incompatible with :code:`scipy.sparse` masks.
            If sparse processing is required then where possible :code:`scipy.sparse`
            masks are converted to :code:`sparse.pydata` equivalents. A consequence
            of this is that sparse processing is not yet supported through CuPy when
            shifts are enabled, as :code:`sparse.pydata` has no current CuPy
            implementation. If sparse masks are supplied on a CuPy backend when
            :code:`use_sparse=None` (the default) they will be densified to allow
            the calculation to take place.

    Examples
    --------
    >>> dataset.shape
    (16, 16, 32, 32)
    >>> def my_masks():
    ...     return [np.ones((32, 32)), np.zeros((32, 32))]
    >>> udf = ApplyMasksUDF(mask_factories=my_masks)
    >>> res = ctx.run_udf(dataset=dataset, udf=udf)['intensity']
    >>> res.data.shape
    (16, 16, 2)
    >>> np.allclose(res.data[..., 1], 0)  # same order as in the mask factory
    True

    Mask factories can also return all masks as a single array, stacked on the
    first axis:

    >>> def my_masks_2():
    ...     masks = np.zeros((2, 32, 32))
    ...     masks[1, ...] = 1
    ...     return masks
    >>> udf = ApplyMasksUDF(mask_factories=my_masks_2)
    >>> res_2 = ctx.run_udf(dataset=dataset, udf=udf)['intensity']
    >>> np.allclose(res_2.data, res.data)
    True

    Masks can be shifted relative to the data using the :code:`shifts` parameter,
    which can either be a constant shift for all frames:

    >>> udf = ApplyMasksUDF(mask_factories=my_masks, shifts=(2, -5))
    >>> res_shift_constant = ctx.run_udf(dataset=dataset, udf=udf)['intensity']

    or a per-frame shift supplied using an
    :class:`~libertem.common.buffers.AuxBufferWrapper` created using
    :meth:`~libertem.udf.base.UDF.aux_data`:

    >>> shifts = np.random.randint(-8, 8, size=(16, 16, 2)).ravel()
    >>> udf = ApplyMasksUDF(
    ...     mask_factories=my_masks,
    ...     shifts=ApplyMasksUDF.aux_data(
    ...         shifts,
    ...         kind='nav',
    ...         extra_shape=(2,),
    ...         dtype=shifts.dtype,
    ...     )
    ... )
    >>> res_shift_aux = ctx.run_udf(dataset=dataset, udf=udf)['intensity']

    .. versionadded:: 0.4.0

    .. versionchanged:: 0.13.0
        Added the :code:`shifts` parameter
    '''
    def __init__(self, mask_factories, use_torch=True, use_sparse=None, mask_count=None,
                 mask_dtype=None, preferred_dtype=None, backends=None, shifts=None, **kwargs):
        _backends = backends
        not_supported = (
            self.BACKEND_SCIPY_COO_ARRAY,
            self.BACKEND_SCIPY_CSR_ARRAY,
            self.BACKEND_SCIPY_CSC_ARRAY,
        )
        supported_backends = tuple(b for b in self.BACKEND_ALL if b not in not_supported)
        if backends is None:
            backends = supported_backends
        backends = tuple(b for b in backends if b in supported_backends)
        if shifts is not None:
            if isinstance(use_sparse, str) and use_sparse.startswith('scipy.sparse'):
                # This is 'unsupported' because we need to slice the mask stack
                # in the signal dimensions to shift it, and sig is flattened
                # to give a 2D matrix in scipy.sparse
                raise ValueError(
                    f'Sparse backend {use_sparse} not supported for '
                    'shifts, use sparse.pydata instead.'
                )
            if not isinstance(shifts, AuxBufferWrapper):
                shifts = np.asarray(shifts)
            backends = tuple(
                b for b in backends
                if b not in (
                    # Here we are doing frame-by-frame processing, so we can
                    # accept scipy.sparse-style frames, however we need to
                    # perform a reshape into sig-shaped frames which normally
                    # casts the frame into coo_matrix form, which then
                    # has to be re-cast into csr_matrix form to be sliced
                    # (i.e. shifted)
                    self.BACKEND_SCIPY_COO,  # cannot be sliced
                    self.BACKEND_CUPY_SCIPY_COO,  # cannot be sliced
                )
            )
        if len(backends) == 0:
            raise ValueError(f'No compatible backend found in {_backends}')
        self._mask_container = None
        super().__init__(
            mask_factories=mask_factories, use_torch=use_torch, use_sparse=use_sparse,
            mask_count=mask_count, mask_dtype=mask_dtype, preferred_dtype=preferred_dtype,
            backends=backends, shifts=shifts, **kwargs
        )

    def get_preferred_input_dtype(self):
        ''
        if self.params.preferred_dtype is None:
            return super().get_preferred_input_dtype()
        else:
            return self.params.preferred_dtype

    def get_mask_dtype(self):
        if self.params.mask_dtype is None:
            return self.masks.dtype
        else:
            return self.params.mask_dtype

    def get_mask_count(self):
        if self.params.mask_count is None:
            return len(self.masks)
        else:
            return self.params.mask_count

    @property
    def masks(self):
        if self._mask_container is None:
            self._mask_container = self._make_mask_container()
        return self._mask_container

    def _make_mask_container(self):
        p = self.params
        if self.meta.array_backend in self.CUPY_BACKENDS:
            backend = self.BACKEND_CUPY
        else:
            backend = self.BACKEND_NUMPY
        # Select the default sparse format passed on to MaskContainer:
        # scipy.sparse normally, sparse.pydata when shifts are used,
        # since scipy.sparse masks cannot be shifted
        default_sparse = {}
        if p.shifts is None:
            default_sparse['default_sparse'] = 'scipy.sparse'
        else:
            default_sparse['default_sparse'] = 'sparse.pydata'
        return MaskContainer(
            p.mask_factories, dtype=p.mask_dtype, use_sparse=p.use_sparse,
            count=p.mask_count, backend=backend, **default_sparse,
        )

    def get_task_data(self):
        ''
        engine = ApplyMasksEngine(self.masks, self.meta, self.params.use_torch)
        return {
            'engine': engine,
        }

    def get_result_buffers(self):
        ''
        dtype = np.result_type(
            self.meta.input_dtype, self.get_mask_dtype()
        )
        count = self.get_mask_count()
        return {
            'intensity': self.buffer(
                kind='nav', extra_shape=(count, ), dtype=dtype, where='device'
            )
        }

    def get_backends(self):
        ''
        return self.params.backends

    def get_method(self) -> Literal[UDFMethod.FRAME, UDFMethod.TILE]:
        """
        :meta private:
        """
        if self.params.get('shifts') is not None:
            return UDFMethod.FRAME
        else:
            return UDFMethod.TILE

    def process_tile(self, tile):
        """
        Used for simple mask application, without shifts

        :meta private:
        """
        self.results.intensity[:] += self.forbuf(
            self.task_data.engine.process_tile(tile),
            self.results.intensity,
        )

    def process_frame(self, frame):
        """
        Apply shifted masks to a frame

        :meta private:
        """
        shifts = self.params.shifts.astype(int)
        self.results.intensity[:] += self.forbuf(
            self.task_data.engine.process_frame_shifted(frame, shifts),
            self.results.intensity,
        )
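

# Illustrative sketch, not part of libertem: for dense data and without
# shifts, the 'intensity' result of ApplyMasksUDF is equivalent to a matrix
# product between flattened frames and flattened masks. The helper name below
# is hypothetical and only intended as a plain-numpy reference.
def _example_reference_apply_masks(data, mask_stack):
    # data: (*nav_shape, sy, sx); mask_stack: (num_masks, sy, sx)
    # returns (*nav_shape, num_masks), matching the 'intensity' buffer layout
    nav_shape = data.shape[:-2]
    num_masks = mask_stack.shape[0]
    flat_frames = data.reshape(-1, data.shape[-2] * data.shape[-1])
    flat_masks = mask_stack.reshape(num_masks, -1)
    return (flat_frames @ flat_masks.T).reshape(*nav_shape, num_masks)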