import numpy as np
from libertem_blobfinder.base.utils import frame_peaks
import libertem_blobfinder.common.gridmatching as grm
from libertem_blobfinder.common.patterns import MatchPattern
from libertem_blobfinder.udf.correlation import (
FastCorrelationUDF, SparseCorrelationUDF, FullFrameCorrelationUDF
)
class RefinementMixin():
    '''
    Mixin that adds grid refinement on top of a correlation UDF.

    Combine it with a :class:`libertem_blobfinder.CorrelationUDF` through
    multiple inheritance. The mixin has to be listed *before* the UDF in the
    base class list so that its methods run first and reach the UDF through
    :code:`super()`.

    Subclasses provide a :code:`postprocess` method that derives a refinement
    of start_zero, start_a and start_b from the correlation result and writes
    that refinement into the result buffers declared here.

    Declaring an ad-hoc class that inherits from one RefinementMixin subclass
    and one CorrelationUDF subclass pairs any correlation-based matching
    implementation with any refinement implementation.
    '''

    def get_result_buffers(self):
        """
        Extend the superclass result buffers with the refinement results.

        The following buffers are added to the superclass declaration:

        :code:`zero`, :code:`a`, :code:`b`:
            Grid refinement parameters for each frame (float32, extra
            shape (2,)).
        :code:`selector`:
            Boolean mask of the peaks that were used in the fit, one entry
            per peak in :code:`self.params.peaks`.
        :code:`error`:
            Residual of the fit (float32, one value per frame).

        Returns
        -------
        dict
            The superclass buffer declaration, updated with the entries above.
        """
        buffers = super().get_result_buffers()
        peak_count = len(self.params.peaks)
        # The three lattice vectors share one per-frame 2-vector layout.
        additions = {
            vector_name: self.buffer(
                kind="nav", extra_shape=(2,), dtype="float32"
            )
            for vector_name in ('zero', 'a', 'b')
        }
        additions['selector'] = self.buffer(
            kind="nav", extra_shape=(peak_count,), dtype="bool"
        )
        additions['error'] = self.buffer(kind="nav", dtype="float32")
        buffers.update(additions)
        return buffers

    def apply_match(self, index, match):
        """
        Store one refinement result in the result buffers.

        Override this method to change how a match is saved, for example to
        support binned processing or ragged result arrays.

        Parameters
        ----------
        index : int
            Position within the result buffers that receives the match.
        match
            Refinement result exposing :code:`zero`, :code:`a`, :code:`b`,
            :code:`selector` and :code:`error` attributes.
        """
        results = self.results
        # Match values are float64; assigning them into the float32 buffers
        # performs the narrowing cast.
        for field in ('zero', 'a', 'b', 'selector', 'error'):
            getattr(results, field)[index] = getattr(match, field)
class FastmatchMixin(RefinementMixin):
    '''
    Refinement using :meth:`~libertem_blobfinder.common.gridmatching.Matcher.fastmatch`
    '''

    def __init__(self, *args, **kwargs):
        '''
        Parameters
        ----------
        matcher : libertem_blobfinder.common.gridmatching.Matcher
            Instance of :class:`~libertem_blobfinder.common.gridmatching.Matcher`
        start_zero : numpy.ndarray
            Approximate value (y, x) in px for "zero" point (origin, zero order peak)
        start_a : numpy.ndarray
            Approximate value (y, x) in px for "a" vector.
        start_b : numpy.ndarray
            Approximate value (y, x) in px for "b" vector.
        '''
        super().__init__(*args, **kwargs)

    def postprocess(self):
        """
        Refine the lattice of every frame with :code:`matcher.fastmatch`.

        The superclass postprocessing runs first. Then, for each frame, the
        correlation results are matched against the start lattice, and each
        match is stored through :meth:`apply_match`. The origin guess of each
        frame is shifted by :code:`self.get_zero_shift(frame)`, which is
        provided by the class this mixin is combined with.
        """
        super().postprocess()
        params = self.params
        results = self.results
        frame_count = len(results.centers)
        for frame in range(frame_count):
            refinement = params.matcher.fastmatch(
                centers=results.centers[frame],
                refineds=results.refineds[frame],
                peak_values=results.peak_values[frame],
                peak_elevations=results.peak_elevations[frame],
                zero=params.start_zero + self.get_zero_shift(frame),
                a=params.start_a,
                b=params.start_b,
            )
            self.apply_match(frame, refinement)
class AffineMixin(RefinementMixin):
    '''
    Refinement using :meth:`~libertem_blobfinder.common.gridmatching.Matcher.affinematch`
    '''

    def __init__(self, *args, **kwargs):
        '''
        Parameters
        ----------
        matcher : libertem_blobfinder.common.gridmatching.Matcher
            Instance of :class:`~libertem_blobfinder.common.gridmatching.Matcher`
        indices : numpy.ndarray
            List of indices [(h1, k1), (h2, k2), ...] of all peaks. The indices can be
            non-integer and relative to any base vectors, including virtual ones like
            (1, 0); (0, 1). See documentation of
            :meth:`~libertem_blobfinder.common.gridmatching.Matcher.affinematch` for details.
        '''
        super().__init__(*args, **kwargs)

    def postprocess(self):
        """
        Refine the lattice of every frame with :code:`matcher.affinematch`.

        The superclass postprocessing runs first. Then the correlation results
        of each frame are matched against :code:`self.params.indices`, and each
        match is stored through :meth:`apply_match`.
        """
        super().postprocess()
        params = self.params
        results = self.results
        frame_count = len(results.centers)
        for frame in range(frame_count):
            refinement = params.matcher.affinematch(
                centers=results.centers[frame],
                refineds=results.refineds[frame],
                peak_values=results.peak_values[frame],
                peak_elevations=results.peak_elevations[frame],
                indices=params.indices,
            )
            self.apply_match(frame, refinement)
def run_refine(
        ctx, dataset, zero, a, b, match_pattern: MatchPattern, matcher: grm.Matcher,
        correlation='fast', match='fast', indices=None, steps=5, zero_shift=None,
        upsample=False, **kwargs):
    '''
    Wrapper function to refine the given lattice for each frame by calculating
    approximate peak positions and refining them for each frame using a
    combination of :class:`libertem_blobfinder.CorrelationUDF` and
    :class:`libertem_blobfinder.RefinementMixin`.

    .. versionchanged:: 0.3.0
        Support for :class:`FullFrameCorrelationUDF`
        through parameter :code:`correlation = 'fullframe'`

    Parameters
    ----------
    ctx : libertem.api.Context
        Instance of a LiberTEM :class:`~libertem.api.Context`
    dataset : libertem.io.dataset.base.DataSet
        Instance of a :class:`~libertem.io.dataset.base.DataSet`
    zero : numpy.ndarray
        Approximate value for "zero" point (y, x) in px (origin, zero order
        peak)
    a : numpy.ndarray
        Approximate value for "a" vector (y, x) in px.
    b : numpy.ndarray
        Approximate value for "b" vector (y, x) in px.
    match_pattern : MatchPattern
        Instance of :class:`~MatchPattern`
    matcher : libertem_blobfinder.common.gridmatching.Matcher
        Instance of :class:`~libertem_blobfinder.common.gridmatching.Matcher`
        to perform the matching
    correlation : {'fast', 'sparse', 'fullframe'}, optional
        'fast', 'sparse' or 'fullframe' to select :class:`~FastCorrelationUDF`,
        :class:`~SparseCorrelationUDF` or :class:`~FullFrameCorrelationUDF`
    match : {'fast', 'affine'}, optional
        'fast' or 'affine' to select
        :class:`~FastmatchMixin` or :class:`~AffineMixin`
    indices : numpy.ndarray, optional
        Indices to refine. This is trimmed down to
        positions within the frame. As a convenience, for the indices parameter
        this function accepts both shape (n, 2) and (2, n, m) so that
        numpy.mgrid[h:k, i:j] works directly to specify indices. This saves
        boilerplate code when using this function.
        Default: numpy.mgrid[-10:11, -10:11].
    steps : int, optional
        Only for correlation == 'sparse': Correlation steps. See
        :meth:`~SparseCorrelationUDF.__init__` for details.
    zero_shift : Union[AUXBufferWrapper, numpy.ndarray, None], optional
        Zero shift, for example descan error. Can be :code:`None`, :code:`numpy.array((y, x))`
        or AUX data with :code:`(y, x)` for each frame. Only supported for correlation methods
        :code:`fast` and :code:`fullframe`.
    upsample : Union[bool, int], optional
        Use DFT upsampling for the refinement step, by default False. Supplying
        True will choose a reasonable default upsampling factor (20), while any
        positive integer > 1 will upsample the correlation peak by this factor.
        DFT upsampling can provide more accurate center values, especially when
        peak shifts are small, but does require more computation time.
    kwargs
        Passed through to :meth:`~libertem.api.Context.run_udf`

    Returns
    -------
    result : Dict[str, BufferWrapper]
        Result buffers of the UDF. See
        :meth:`libertem_blobfinder.correlation.CorrelationUDF.get_result_buffers` and
        :meth:`RefinementMixin.get_result_buffers` for details on the available
        buffers.
    used_indices : numpy.ndarray
        The peak indices that were within the frame.

    Both values are returned as the tuple :code:`(result, used_indices)`.

    Raises
    ------
    ValueError
        If :code:`correlation` or :code:`match` is not one of the supported
        values.

    Examples
    --------
    >>> dataset = ctx.load(
    ...     filetype="memory",
    ...     data=np.zeros(shape=(2, 2, 128, 128), dtype=np.float32)
    ... )
    >>> (result, used_indices) = run_refine(
    ...     ctx, dataset,
    ...     zero=(64, 64), a=(1, 0), b=(0, 1),
    ...     match_pattern=libertem_blobfinder.common.patterns.RadialGradient(radius=4),
    ...     matcher=grm.Matcher()
    ... )
    >>> result['centers'].data  #doctest: +ELLIPSIS
    array(...)
    '''
    # Validate the method selections before any peak positions are computed,
    # so that an invalid argument fails fast.
    if correlation == 'fast':
        method = FastCorrelationUDF
    elif correlation == 'sparse':
        method = SparseCorrelationUDF
    elif correlation == 'fullframe':
        method = FullFrameCorrelationUDF
    else:
        raise ValueError(
            "Unknown correlation method %s. "
            "Supported are 'fast', 'sparse' and 'fullframe'" % correlation
        )

    if match == 'affine':
        mixin = AffineMixin
    elif match == 'fast':
        mixin = FastmatchMixin
    else:
        raise ValueError(
            "Unknown match method %s. Supported are 'fast' and 'affine'" % match
        )

    if indices is None:
        indices = np.mgrid[-10:11, -10:11]

    # True selects the default upsampling factor.
    if upsample is True:
        upsample = 20

    (fy, fx) = tuple(dataset.shape.sig)
    # frame_peaks returns the subset of indices that fall within the frame,
    # together with their peak positions.
    indices, peaks = frame_peaks(
        fy=fy, fx=fx, zero=zero, a=a, b=b,
        r=match_pattern.search, indices=indices
    )
    peaks = peaks.astype('int')

    # The inheritance order matters: FIRST the mixin, which calls
    # the super class methods.
    class MyUDF(mixin, method):
        pass

    udf = MyUDF(
        peaks=peaks,
        indices=indices,
        start_zero=zero,
        start_a=a,
        start_b=b,
        match_pattern=match_pattern,
        matcher=matcher,
        steps=steps,
        zero_shift=zero_shift,
        upsample=upsample,
    )

    result = ctx.run_udf(
        dataset=dataset,
        udf=udf,
        **kwargs
    )
    return (result, indices)