import functools
import numpy as np
import sparseconverter
from libertem.udf import UDF
from libertem.common.container import MaskContainer
from libertem_blobfinder.base import masks
from libertem_blobfinder.common.patterns import MatchPattern
import libertem_blobfinder.base.correlation as ltbc
from libertem_blobfinder.common.correlation import get_peaks
class CorrelationUDF(UDF):
'''
Base class for peak correlation implementations
'''
def __init__(self, peaks, zero_shift=None, *args, **kwargs):
'''
Parameters
----------
peaks : numpy.ndarray
Numpy array of (y, x) coordinates with peak positions in px to correlate
zero_shift : Union[AUXBufferWrapper, numpy.ndarray, None], optional
Zero shift, for example descan error. Can be :code:`None`, :code:`numpy.array((y, x))`
or AUX data with :code:`(y, x)` for each frame.
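Examples
--------
A minimal sketch of passing a per-frame zero shift as AUX data; the scan
shape and shift values are placeholders::

    import numpy as np
    from libertem.udf import UDF

    nav_shape = (16, 16)  # placeholder scan shape of the dataset
    shifts = np.zeros(nav_shape + (2,), dtype=np.float32)  # (y, x) per frame
    zero_shift = UDF.aux_data(
        data=shifts, kind='nav', extra_shape=(2,), dtype=np.float32
    )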
'''
super().__init__(peaks=np.round(peaks).astype(int), zero_shift=zero_shift, *args, **kwargs)
def get_result_buffers(self):
'''
The common buffers for all correlation methods.
:code:`centers`:
(y, x) integer positions. NOTE: the returned positions
can be out-of-frame and the user should perform bounds
checking if directly indexing into the frame array.
:code:`refineds`:
(y, x) positions with subpixel refinement.
:code:`peak_values`:
Peak height in the log scaled frame.
:code:`peak_elevations`:
Peak quality (result of :meth:`libertem_blobfinder.base.correlation.peak_elevation`).
See source code for details of the buffer declaration.
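Examples
--------
Accessing the buffers after a run; :code:`res` is assumed to be the
result of :meth:`libertem.api.Context.run_udf`::

    centers = res['centers'].data  # shape: nav shape + (num_disks, 2)
    elevations = res['peak_elevations'].data  # shape: nav shape + (num_disks,)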
'''
num_disks = len(self.params.peaks)
return {
'centers': self.buffer(
kind="nav", extra_shape=(num_disks, 2), dtype=np.int32,
),
'refineds': self.buffer(
kind="nav", extra_shape=(num_disks, 2), dtype="float32"
),
'peak_values': self.buffer(
kind="nav", extra_shape=(num_disks,), dtype="float32",
),
'peak_elevations': self.buffer(
kind="nav", extra_shape=(num_disks,), dtype="float32",
),
}
def output_buffers(self):
'''
This function allows abstraction of the result buffers from
the default implementation in :meth:`get_result_buffers`.
Override this function if you wish to redirect the results to different
buffers, for example ragged arrays or binned processing.
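Examples
--------
A sketch of an override that redirects the results to differently named
buffers; the buffer names are hypothetical and would have to be declared
in a matching :meth:`get_result_buffers` override::

    def output_buffers(self):
        r = self.results
        return (r.my_centers, r.my_refineds, r.my_values, r.my_elevations)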
'''
r = self.results
return (r.centers, r.refineds, r.peak_values, r.peak_elevations)
def postprocess(self):
pass
def get_peaks(self):
return self.params.peaks
def get_zero_shift(self, index=None):
if self.params.zero_shift is None:
result = np.array((0, 0))
elif index is None:
# Called per frame, where AUX data is presented as a view for the current frame
result = self.params.zero_shift
else:
# Called with an explicit index, for example from postprocess()
result = self.params.zero_shift[index]
return result
class FastCorrelationUDF(CorrelationUDF):
'''
Fast, Fourier-based correlation refinement of peak positions within a
search frame around each peak.
'''
def __init__(self, peaks, match_pattern, zero_shift=None, *args, **kwargs):
'''
Parameters
----------
peaks : numpy.ndarray
Numpy array of (y, x) coordinates with peak positions in px to correlate
match_pattern : MatchPattern
Instance of :class:`~libertem_blobfinder.common.patterns.MatchPattern`
zero_shift : Union[AUXBufferWrapper, numpy.ndarray, None], optional
Zero shift, for example descan error. Can be :code:`None`, :code:`numpy.array((y, x))`
or AUX data with :code:`(y, x)` for each frame.
upsample : Union[bool, int], optional
Use DFT upsampling for the refinement step; disabled by default. Supplying
True will choose a reasonable default upsampling factor, while any
positive integer > 1 will upsample the correlation peak by this factor.
DFT upsampling can provide more accurate center values, especially when
peak shifts are small, but requires more computation time.
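Examples
--------
A minimal usage sketch; the dataset path, pattern radius and peak
positions are placeholders::

    import numpy as np
    from libertem.api import Context
    from libertem_blobfinder.common.patterns import RadialGradient

    ctx = Context()
    ds = ctx.load('auto', path='data.hdf5')  # placeholder path
    pattern = RadialGradient(radius=4)
    peaks = np.array([(64, 64), (32, 96)])  # approximate (y, x) positions in px
    udf = FastCorrelationUDF(peaks=peaks, match_pattern=pattern, upsample=True)
    res = ctx.run_udf(dataset=ds, udf=udf)
    refined = res['refineds'].data  # subpixel-refined (y, x) per frame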
'''
# For testing purposes, allow injecting a different limit via
# an internal kwarg.
# It has to come through a kwarg because of how UDFs are run.
self.limit = kwargs.get('__limit', 2**19) # 1/2 MB
super().__init__(
peaks=peaks, match_pattern=match_pattern, zero_shift=zero_shift, *args, **kwargs
)
def get_task_data(self):
""
n_peaks = len(self.get_peaks())
mask = self.get_pattern()
crop_size = mask.get_crop_size()
template = self.xp.array(mask.get_template(sig_shape=(2 * crop_size, 2 * crop_size)))
dtype = np.result_type(self.meta.input_dtype, np.float32)
crop_bufs = ltbc.allocate_crop_bufs(
crop_size, n_peaks, dtype=dtype, limit=self.limit, xp=self.xp
)
if self.meta.array_backend in (
self.BACKEND_SPARSE_COO, self.BACKEND_SPARSE_GCXS, self.BACKEND_CUPY):
crop_function = ltbc.crop_disks_from_frame_slicing
elif self.meta.array_backend in (self.BACKEND_NUMPY, ):
crop_function = ltbc.crop_disks_from_frame
else: # pragma: no cover
raise RuntimeError(f"Unsupported array backend {self.meta.array_backend}")
kwargs = {
'crop_bufs': crop_bufs,
'template': template,
'crop_function': crop_function,
}
return kwargs
def get_pattern(self):
return self.params.match_pattern
def get_template(self):
return self.task_data.template
def process_frame(self, frame):
match_pattern = self.get_pattern()
(centers, refineds, peak_values, peak_elevations) = self.output_buffers()
ltbc.process_frame_fast(
template=self.get_template(), crop_size=match_pattern.get_crop_size(),
frame=frame, peaks=self.get_peaks() + np.round(self.get_zero_shift()).astype(int),
out_centers=centers, out_refineds=refineds,
out_heights=peak_values, out_elevations=peak_elevations,
crop_bufs=self.task_data.crop_bufs,
upsample=self.params.get('upsample', False),
crop_function=self.task_data.crop_function,
)
def get_backends(self):
return (
self.BACKEND_NUMPY,
self.BACKEND_CUPY,
self.BACKEND_SPARSE_COO,
self.BACKEND_SPARSE_GCXS,
)
class FullFrameCorrelationUDF(CorrelationUDF):
'''
Fourier-based correlation refinement of peak positions using a single
correlation step over the full frame. This can be faster than
:class:`FastCorrelationUDF` when correlating a large number of peaks in
small frames. However, it is more sensitive to interference from strong
peaks next to the peak of interest.

.. versionadded:: 0.3.0
'''
def __init__(self, peaks, match_pattern, zero_shift=None, *args, **kwargs):
'''
Parameters
----------
peaks : numpy.ndarray
Numpy array of (y, x) coordinates with peak positions in px to correlate
match_pattern : MatchPattern
Instance of :class:`~libertem_blobfinder.common.patterns.MatchPattern`
zero_shift : Union[AUXBufferWrapper, numpy.ndarray, None], optional
Zero shift, for example descan error. Can be :code:`None`, :code:`numpy.array((y, x))`
or AUX data with :code:`(y, x)` for each frame.
upsample : Union[bool, int], optional
Use DFT upsampling for the refinement step; disabled by default. Supplying
True will choose a reasonable default upsampling factor, while any
positive integer > 1 will upsample the correlation peak by this factor.
DFT upsampling can provide more accurate center values, especially when
peak shifts are small, but requires more computation time.
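Examples
--------
Construction mirrors :class:`FastCorrelationUDF`; a sketch with
:code:`peaks`, :code:`pattern`, :code:`ctx` and :code:`ds` as in the
example there::

    udf = FullFrameCorrelationUDF(peaks=peaks, match_pattern=pattern)
    res = ctx.run_udf(dataset=ds, udf=udf)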
'''
# For testing purposes, allow injecting a different limit via
# an internal kwarg.
# It has to come through a kwarg because of how UDFs are run.
self.limit = kwargs.get('__limit', 2**19) # 1/2 MB
super().__init__(
peaks=peaks, match_pattern=match_pattern, zero_shift=zero_shift, *args, **kwargs
)
def get_task_data(self):
""
mask = self.get_pattern()
n_peaks = len(self.params.peaks)
template = self.xp.array(mask.get_template(sig_shape=self.meta.dataset_shape.sig))
dtype = np.result_type(self.meta.input_dtype, np.float32)
frame_buf = self.xp.array(
ltbc.zeros(shape=self.meta.dataset_shape.sig, dtype=dtype)
)
crop_size = mask.get_crop_size()
if self.meta.array_backend in (
self.BACKEND_SPARSE_COO, self.BACKEND_SPARSE_GCXS, self.BACKEND_CUPY):
crop_function = ltbc.crop_disks_from_frame_slicing
elif self.meta.array_backend in (self.BACKEND_NUMPY, ):
crop_function = ltbc.crop_disks_from_frame
else: # pragma: no cover
raise RuntimeError(f"Unsupported array backend {self.meta.array_backend}")
kwargs = {
'template': template,
'frame_buf': frame_buf,
'buf_count': ltbc.get_buf_count(crop_size, n_peaks, dtype, self.limit),
'crop_function': crop_function,
}
return kwargs
def get_pattern(self):
return self.params.match_pattern
def get_template(self):
return self.task_data.template
def process_frame(self, frame):
match_pattern = self.get_pattern()
(centers, refineds, peak_values, peak_elevations) = self.output_buffers()
ltbc.process_frame_full(
template=self.get_template(),
crop_size=match_pattern.get_crop_size(),
frame=frame,
peaks=self.get_peaks() + np.round(self.get_zero_shift()).astype(int),
out_centers=centers,
out_refineds=refineds,
out_heights=peak_values,
out_elevations=peak_elevations,
frame_buf=self.task_data.frame_buf,
buf_count=self.task_data.buf_count,
upsample=self.params.get('upsample', False),
crop_function=self.task_data.crop_function,
)
def get_backends(self):
# FFT on a fully sparse frame is currently not possible, so sparse
# backends are not specified here, which triggers auto-densification
return (
self.BACKEND_NUMPY,
self.BACKEND_CUPY,
)
class SparseCorrelationUDF(CorrelationUDF):
'''
Direct correlation using sparse matrices.

This method allows adjusting the number of correlation steps independently
of the template size.
'''
def __init__(self, peaks, match_pattern, steps, *args, **kwargs):
'''
Parameters
----------
peaks : numpy.ndarray
Numpy array of (y, x) coordinates with peak positions in px to correlate
match_pattern : MatchPattern
Instance of :class:`~libertem_blobfinder.common.patterns.MatchPattern`
steps : int
The template is correlated at 2 * steps + 1 positions in the x and y
direction, symmetrically around each peak position. This defines the
maximum shift that can be detected. The number of calculations grows with
the square of this value, so keeping it as small as the data allows speeds
up the calculation.
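Examples
--------
A sketch with :code:`peaks`, :code:`pattern`, :code:`ctx` and :code:`ds`
as in the :class:`FastCorrelationUDF` example; :code:`steps=3` allows
shifts of up to 3 px to be detected::

    udf = SparseCorrelationUDF(peaks=peaks, match_pattern=pattern, steps=3)
    res = ctx.run_udf(dataset=ds, udf=udf)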
'''
super().__init__(
peaks=peaks, match_pattern=match_pattern, steps=steps, *args, **kwargs
)
if self.params.zero_shift is not None:
raise ValueError("Parameter zero_shift not supported for SparseCorrelationUDF")
def get_result_buffers(self):
"""
This method adds the :code:`corr` buffer to the result of
:meth:`CorrelationUDF.get_result_buffers`. See source code for the
exact buffer declaration.
"""
super_buffers = super().get_result_buffers()
num_disks = len(self.params.peaks)
steps = self.params.steps * 2 + 1
my_buffers = {
'corr': self.buffer(
kind="nav", extra_shape=(num_disks * steps**2,), dtype="float32"
),
}
super_buffers.update(my_buffers)
return super_buffers
def get_task_data(self):
""
match_pattern = self.params.match_pattern
crop_size = match_pattern.get_crop_size()
size = (2 * crop_size + 1, 2 * crop_size + 1)
template = match_pattern.get_mask(sig_shape=size)
steps = self.params.steps
peak_offsetY, peak_offsetX = np.mgrid[-steps:steps + 1, -steps:steps + 1]
offsetY = self.params.peaks[:, 0, np.newaxis, np.newaxis] + peak_offsetY - crop_size
offsetX = self.params.peaks[:, 1, np.newaxis, np.newaxis] + peak_offsetX - crop_size
offsetY = offsetY.flatten()
offsetX = offsetX.flatten()
stack = functools.partial(
masks.sparse_template_multi_stack,
mask_index=range(len(offsetY)),
offsetX=offsetX,
offsetY=offsetY,
template=template,
imageSizeX=self.meta.dataset_shape.sig[1],
imageSizeY=self.meta.dataset_shape.sig[0]
)
if self.meta.array_backend in sparseconverter.CPU_BACKENDS:
backend = 'numpy'
elif self.meta.array_backend in sparseconverter.CUDA_BACKENDS:
backend = 'cupy'
else: # pragma: no cover
raise ValueError("Unknown device class")
if self.meta.array_backend == self.BACKEND_SPARSE_COO:
use_sparse = 'sparse.pydata'
elif self.meta.array_backend == self.BACKEND_SPARSE_GCXS:
use_sparse = 'sparse.pydata.GCXS'
elif self.meta.array_backend in (self.BACKEND_CUPY, self.BACKEND_NUMPY):
use_sparse = 'scipy.sparse.csc'
else: # pragma: no cover
raise RuntimeError(f'Unsupported array backend {self.meta.array_backend}')
# CSC matrices in combination with transposed data are fastest
container = MaskContainer(mask_factories=stack, dtype=np.float32,
use_sparse=use_sparse, backend=backend)
kwargs = {
'mask_container': container,
'crop_size': crop_size,
}
return kwargs
def process_tile(self, tile):
tile_slice = self.meta.slice
c = self.task_data.mask_container
tile_t = ltbc.log_scale(tile.reshape((tile.shape[0], -1)).T, out=None)
sl = c.get(key=tile_slice, transpose=False)
self.results.corr[:] += self.forbuf(sl.dot(tile_t).T, self.results.corr)
def postprocess(self):
"""
The correlation results are evaluated during postprocessing since this
implementation uses tiled processing where the correlations are
incomplete in :meth:`process_tile`.
"""
steps = 2 * self.params.steps + 1
corrmaps = self.results.corr.reshape((
-1, # frames
len(self.params.peaks), # peaks
steps, # Y steps
steps, # X steps
))
peaks = self.params.peaks
(centers, refineds, peak_values, peak_elevations) = self.output_buffers()
for f in range(corrmaps.shape[0]):
ltbc.evaluate_correlations(
corrs=corrmaps[f], peaks=peaks, crop_size=self.params.steps,
out_centers=centers[f], out_refineds=refineds[f],
out_heights=peak_values[f], out_elevations=peak_elevations[f]
)
def get_backends(self):
return (
self.BACKEND_NUMPY,
self.BACKEND_CUPY,
self.BACKEND_SPARSE_COO,
self.BACKEND_SPARSE_GCXS
)
def run_fastcorrelation(
ctx, dataset, peaks, match_pattern: MatchPattern, zero_shift=None, upsample=False, **kwargs
):
"""
Wrapper function to construct and run a :class:`FastCorrelationUDF`
Parameters
----------
ctx : libertem.api.Context
dataset : libertem.io.dataset.base.DataSet
peaks : numpy.ndarray
List of peaks with (y, x) coordinates
match_pattern : libertem_blobfinder.common.patterns.MatchPattern
zero_shift : Union[AUXBufferWrapper, numpy.ndarray, None], optional
Zero shift, for example descan error. Can be :code:`None`, :code:`numpy.array((y, x))`
or AUX data with :code:`(y, x)` for each frame.
upsample : Union[bool, int], optional
Whether to use DFT upsampling for refinement. False (the default)
deactivates it; a positive integer > 1 upsamples the correlation peak by
this factor when refining peak positions. Passing True chooses a sensible
upsampling factor.
kwargs : passed through to :meth:`~libertem.api.Context.run_udf`
Returns
-------
buffers : Dict[str, libertem.common.buffers.BufferWrapper]
See :meth:`CorrelationUDF.get_result_buffers` for details.
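Examples
--------
A sketch with placeholder inputs, following the
:class:`FastCorrelationUDF` example::

    res = run_fastcorrelation(
        ctx=ctx, dataset=ds, peaks=peaks, match_pattern=pattern, upsample=True
    )
    centers = res['centers'].data  # integer (y, x) positions per frame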
"""
peaks = peaks.astype(int)
udf = FastCorrelationUDF(
peaks=peaks, match_pattern=match_pattern, zero_shift=zero_shift, upsample=upsample,
)
return ctx.run_udf(dataset=dataset, udf=udf, **kwargs)
def run_blobfinder(
ctx, dataset, match_pattern: MatchPattern, num_peaks, roi=None, upsample=False, progress=False
):
"""
Wrapper function to find peaks in a dataset and refine their positions using
:class:`FastCorrelationUDF`
Parameters
----------
ctx : libertem.api.Context
dataset : libertem.io.dataset.base.DataSet
match_pattern : libertem_blobfinder.common.patterns.MatchPattern
num_peaks : int
Number of peaks to look for
roi : numpy.ndarray, optional
Boolean mask of the navigation dimension to select region of interest (ROI)
upsample : Union[bool, int], optional
Whether to use DFT upsampling for refinement. False (the default)
deactivates it; a positive integer > 1 upsamples the correlation peak by
this factor when refining peak positions. Passing True chooses a sensible
upsampling factor.
progress : bool, optional
Show progress bar
Returns
-------
sum_result : numpy.ndarray
Log-scaled sum frame of the dataset/ROI
centers, refineds, peak_values, peak_elevations : libertem.common.buffers.BufferWrapper
See :meth:`CorrelationUDF.get_result_buffers` for details.
peaks : numpy.ndarray
List of found peaks with (y, x) coordinates
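Examples
--------
A minimal sketch; the pattern radius and number of peaks are
placeholders::

    from libertem_blobfinder.common.patterns import RadialGradient

    (sum_result, centers, refineds, peak_values,
     peak_elevations, peaks) = run_blobfinder(
        ctx=ctx, dataset=ds, match_pattern=RadialGradient(radius=4), num_peaks=7
    )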
"""
if upsample is True:
upsample = 20
sum_analysis = ctx.create_sum_analysis(dataset=dataset)
sum_result = ctx.run(sum_analysis, roi=roi)
sum_result = ltbc.log_scale(sum_result.intensity.raw_data, out=None)
peaks = get_peaks(
sum_result=sum_result,
match_pattern=match_pattern,
num_peaks=num_peaks,
)
pass_2_results = run_fastcorrelation(
ctx=ctx,
dataset=dataset,
peaks=peaks,
match_pattern=match_pattern,
roi=roi,
upsample=upsample,
progress=progress
)
return (sum_result, pass_2_results['centers'],
pass_2_results['refineds'], pass_2_results['peak_values'],
pass_2_results['peak_elevations'], peaks)