import logging
import inspect
from typing import Dict, NamedTuple, Optional, Tuple, Type, Union, TYPE_CHECKING

import numpy as np

from libertem import masks
from libertem.web.rpc import ProcedureProtocol
from .base import AnalysisResult, AnalysisResultSet
from .masks import BaseMasksAnalysis
from libertem.corrections import coordinates
from .helper import GeneratorHelper

if TYPE_CHECKING:
    from libertem.web.rpc import RPCContext
# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(__name__)
class ComTemplate(GeneratorHelper):
    """
    Template helper that emits Python source snippets for a CoM analysis
    created with ``create_com_analysis``: extra imports, documentation, the
    call parameters, plotting code and saving code.
    """

    short_name = "com"
    api = "create_com_analysis"
    temp = GeneratorHelper.temp_analysis
    temp_analysis = temp + ["print(com_result)"]
    # Result channels of the CoM analysis; `get_plot` and `get_save` iterate
    # over (parts of) this list.
    channels = [
        "field",
        "magnitude",
        "divergence",
        "curl",
        "x",
        "y"
    ]

    def __init__(self, params):
        self.params = params

    def get_dependency(self):
        """Return the extra import lines required by the generated plot code."""
        dependencies = ["from empyre.vis.colors import ColormapCubehelix"]
        return dependencies

    def get_docs(self):
        """Return the formatted docs of ``Context.create_com_analysis``."""
        from libertem.api import Context
        return self.format_docs(
            "COM Analysis",
            inspect.getdoc(Context.create_com_analysis),
        )

    def convert_params(self):
        """
        Render ``self.params`` as the keyword-argument string of the generated
        ``create_com_analysis(...)`` call.
        """
        p = self.params
        args = ['dataset=ds']
        args.extend(f'{key}={p[key]}' for key in ('cx', 'cy'))
        args.append(f"mask_radius={p['r']}")
        if p.get('flip_y', False):
            args.append("flip_y=True")
        scan_rotation = p.get('scan_rotation')
        if scan_rotation is not None:
            args.append(f"scan_rotation={scan_rotation}")
        inner_radius = p.get('ri')
        if inner_radius is not None:
            args.append(f"mask_radius_inner={inner_radius}")
        return ', '.join(args)

    def get_plot(self):
        """
        Return a one-element list containing the plotting code: the field as a
        cubehelix color wheel, then one plot per channel in ``channels[1:3]``
        (``magnitude`` and ``divergence`` with the default channel list).
        """
        lines = [
            "fig, axes = plt.subplots()",
            'axes.set_title("field")',
            "x_centers, y_centers = com_result.field.raw_data",
            "ch = ColormapCubehelix(start=1, rot=1, minLight=0.5, maxLight=0.5, sat=2)",
            "axes.imshow(ch.rgb_from_vector((x_centers, y_centers, 0)))"
        ]
        for name in self.channels[1:3]:
            lines += [
                "fig, axes = plt.subplots()",
                f'axes.set_title("{name}")',
                f'axes.imshow(com_result.{name}.raw_data)',
            ]
        return ['\n'.join(lines)]

    def get_save(self):
        """Return code that saves the raw data of every channel as ``.npy``."""
        return '\n'.join(
            f"np.save('com_result_{name}.npy', com_result['{name}'].raw_data)"
            for name in self.channels
        )
def com_masks_factory(detector_y, detector_x, cy, cx, r):
    """
    Create a CoM mask stack that selects a disk of radius ``r`` around
    (``cy``, ``cx``).

    Parameters
    ----------
    detector_y : int
        The detector height
    detector_x : int
        The detector width
    cy, cx : float
        Center of the selection disk
    r : float
        Radius of the selection disk

    Returns
    -------
    List[Function]
        Factories for the disk mask and for the y and x gradient masks
        multiplied by the disk mask
    """
    def disk_mask():
        return masks.circular(
            centerX=cx, centerY=cy,
            imageSizeX=detector_x,
            imageSizeY=detector_y,
            radius=r,
        )

    # The stack layout is identical for every selection mask; delegate to the
    # generic builder so disk and ring variants cannot drift apart.
    return com_masks_generic(
        detector_y=detector_y,
        detector_x=detector_x,
        base_mask_factory=disk_mask,
    )
def com_masks_generic(detector_y, detector_x, base_mask_factory):
    """
    Create a CoM mask stack around an arbitrary selection mask.

    Parameters
    ----------
    detector_y : int
        The detector height
    detector_x : int
        The detector width
    base_mask_factory : () -> np.array
        A factory function for creating the selection mask

    Returns
    -------
    List[Function]
        Three factories: ``base_mask_factory`` itself, then the y gradient
        mask and the x gradient mask, each multiplied by the selection mask.
        No mask is created until a factory is called.
    """
    def gradient_y_mask():
        return masks.gradient_y(
            imageSizeX=detector_x,
            imageSizeY=detector_y,
        ) * base_mask_factory()

    def gradient_x_mask():
        return masks.gradient_x(
            imageSizeX=detector_x,
            imageSizeY=detector_y,
        ) * base_mask_factory()

    return [base_mask_factory, gradient_y_mask, gradient_x_mask]
def center_shifts(img_sum, img_y, img_x, ref_y, ref_x):
    """
    Turn mask sums into center of mass shifts relative to a reference point.

    Parameters
    ----------
    img_sum : numpy.ndarray
        Sum of the selection mask per scan position (the CoM denominator)
    img_y, img_x : numpy.ndarray
        Numerators of the y and x center of mass per scan position
    ref_y, ref_x : float
        Reference position subtracted from the center of mass

    Returns
    -------
    (y_centers, x_centers)
        Center of mass minus the reference. Positions where ``img_sum`` is
        zero are set to the reference, i.e. they report a shift of zero.
    """
    nonzero = img_sum != 0
    empty = img_sum == 0

    def shift(weighted, reference):
        # `where` skips the division by zero; the skipped entries are
        # uninitialized and overwritten right afterwards.
        center = np.divide(weighted, img_sum, where=nonzero)
        center[empty] = reference
        center -= reference
        return center

    return (shift(img_y, ref_y), shift(img_x, ref_x))
def apply_correction(y_centers, x_centers, scan_rotation, flip_y, forward=True):
    """
    Flip and rotate center of mass shift vectors.

    The transform is ``rotate_deg(scan_rotation) @ flip``, where ``flip`` is
    ``coordinates.flip_y()`` if ``flip_y`` is set and the identity otherwise.
    In other words, the flip is applied first and then the rotation.

    Parameters
    ----------
    y_centers, x_centers : numpy.ndarray
        y and x components of the shift; both must have the same number of elements
    scan_rotation : float
        Rotation angle in degrees
    flip_y : bool
        Whether to apply ``coordinates.flip_y()`` before rotating
    forward : bool
        If False, apply the inverse transform, which undoes a previous correction

    Returns
    -------
    (y_transformed, x_transformed)
        Transformed components, reshaped to the shape of ``y_centers``
    """
    original_shape = y_centers.shape
    base = coordinates.flip_y() if flip_y else coordinates.identity()
    # Transformations are applied right to left
    matrix = coordinates.rotate_deg(scan_rotation) @ base
    if not forward:
        matrix = np.linalg.inv(matrix)
    stacked = (y_centers.reshape(-1), x_centers.reshape(-1))
    y_out, x_out = matrix @ stacked
    return (y_out.reshape(original_shape), x_out.reshape(original_shape))
def divergence(y_centers, x_centers):
    """
    Discrete divergence dFy/dy + dFx/dx of a 2D vector field.

    Axis 0 is y and axis 1 is x. Derivatives come from ``np.gradient``,
    which uses central differences in the interior and one-sided
    differences at the borders.
    """
    dfy_dy = np.gradient(y_centers, axis=0)
    dfx_dx = np.gradient(x_centers, axis=1)
    return dfy_dy + dfx_dx
def curl_2d(y_centers, x_centers):
    """
    Scalar curl dFy/dx - dFx/dy of a 2D vector field.

    See https://en.wikipedia.org/wiki/Curl_(mathematics)#Usage
    Axis 0 is y and axis 1 is x; derivatives come from ``np.gradient``.
    """
    dfy_dx = np.gradient(y_centers, axis=1)
    dfx_dy = np.gradient(x_centers, axis=0)
    return dfy_dx - dfx_dy
def magnitude(y_centers, x_centers):
    """Element-wise Euclidean length of the vectors (y_centers, x_centers)."""
    squared_length = np.square(y_centers) + np.square(x_centers)
    return np.sqrt(squared_length)
def coordinate_check(y_centers, x_centers, roi=None):
    '''
    Calculate the RMS curl as a function of :code:`scan_rotation` and :code:`flip_y`.

    The curl for a purely electrostatic field is zero. That means
    the correct settings for :code:`scan_rotation` and :code:`flip_y` should
    minimize the RMS curl for atomic resolution STEM of non-magnetic specimens
    along a high-symmetry zone axis.

    Parameters
    ----------
    y_centers, x_centers : numpy.ndarray
        2D arrays with y and x component of the center of mass shift for each
        scan position, as returned by :meth:`center_shifts` or
        :meth:`apply_correction`
    roi : Optional[Union[numpy.ndarray, Tuple[slice, ...]]]
        Selector for values to consider in the statistics, compatible
        with indexing an array with the shape of y_centers and x_centers.
        By default, everything except the last row and last column is used
        since these contain artefacts.

    Returns
    -------
    (straight, flipped)
        Root mean square of the curl as a function of :code:`scan_rotation` from
        0 to 359 degrees in steps of one, with :code:`flip_y=False` (straight)
        and :code:`flip_y=True` (flipped).
    '''
    if roi is None:
        # The last row and column contain artifacts
        roi = (slice(0, -1), slice(0, -1))

    def rms_curl(angle, flip):
        y_t, x_t = apply_correction(
            y_centers, x_centers, scan_rotation=angle, flip_y=flip
        )
        curl = curl_2d(y_t, x_t)
        return np.sqrt(np.mean(curl[roi]**2))

    # Pre-allocated float64 buffers: the stored dtype does not depend on the
    # dtype of the input data.
    straight = np.zeros(360)
    flipped = np.zeros(360)
    for angle in range(360):
        flipped[angle] = rms_curl(angle, True)
        straight[angle] = rms_curl(angle, False)
    return (straight, flipped)
class GuessResult(NamedTuple):
    """
    Correction guess produced by :func:`guess_corrections`.

    All values are relative to the corrections that were already applied to
    the data passed to :func:`guess_corrections`.

    Attributes
    ----------
    scan_rotation : int
        Additional scan rotation in degrees
    flip_y : bool
        Whether the y axis should additionally be flipped
    cy, cx : float
        Mean center of mass shift in y and x, expressed in the coordinate
        system of the input data
    """
    scan_rotation: int
    flip_y: bool
    cy: float
    cx: float
def guess_corrections(
    y_centers: np.ndarray,
    x_centers: np.ndarray,
    roi: Optional[Union[np.ndarray, Tuple[slice, ...]]] = None,
) -> GuessResult:
    '''
    Guess corrections for center shift, :code:`scan_rotation` and :code:`flip_y` from CoM data

    This function can generate a CoM parameter guess for atomic resolution 4D STEM data
    by using the following assumptions:

    * The field is purely electrostatic, i.e. the RMS curl should be minimized
    * There is no net field over the field of view and no descan error,
      i.e. the mean deflection is zero.
    * Atomic resolution STEM means that the divergence will be negative at atom columns
      and consequently the histogram of divergence will have a stronger tail towards
      negative values than towards positive values.

    If any corrections were applied when generating the input data, please note that
    the corrections should be applied relative to these previous values. In particular,
    the center corrections returned by this function have to be back-transformed to the
    uncorrected coordinate system, for example with :code:`apply_correction(..., forward=False)`

    Parameters
    ----------
    y_centers, x_centers : numpy.ndarray
        2D arrays with y and x component of the center of mass shift for each
        scan position, as returned by :meth:`center_shifts` or
        :meth:`apply_correction`
    roi : Optional[Union[numpy.ndarray, Tuple[slice, ...]]]
        Selector for values to consider in the statistics, compatible
        with indexing an array with the shape of y_centers and x_centers.
        By default, everything except the last row and last column is used
        since these contain artefacts.

    Returns
    -------
    GuessResult : relative to current values
        :code:`scan_rotation` is in degrees within [-179, 180]; :code:`cy` and
        :code:`cx` are the mean shifts within :code:`roi`.

    Raises
    ------
    ValueError
        If :code:`roi` does not select any scan position.
    '''
    if roi is None:
        # The last row and column contain artefacts
        roi = (slice(0, -1), slice(0, -1))
    # Fail early: an empty selection would otherwise run the full rotation
    # scan on NaN statistics and only then fail inside np.min().
    if np.size(y_centers[roi]) == 0:
        raise ValueError("roi does not select any scan positions")

    straight, flipped = coordinate_check(y_centers, x_centers, roi=roi)

    # The one with lower minima is the correct one
    flip_y = bool(np.min(flipped) < np.min(straight))
    if flip_y:
        angle = np.argmin(flipped)
    else:
        angle = np.argmin(straight)

    corrected_y, corrected_x = apply_correction(
        y_centers, x_centers, scan_rotation=angle, flip_y=flip_y
    )

    # There are two equivalent angles that minimize RMS curl since a 180°
    # rotation inverts the coordinates and just flips the sign for divergence
    # and curl. To distinguish the two, the distribution of the divergence is
    # analyzed. With negative electrons and positive nuclei, the beam is
    # deflected towards the nuclei and the divergence is negative there. Since
    # the beam is deflected most strongly near the nuclei, the histogram should
    # have more values at the negative end of the range than at the positive
    # end.
    div = divergence(corrected_y, corrected_x)[roi]
    all_range = np.maximum(-np.min(div), np.max(div))
    hist, bins = np.histogram(div, range=(-all_range, all_range), bins=5)

    # Compare the most negative bin with the most positive bin
    polarity_off = np.sum(hist[:1]) < np.sum(hist[-1:])
    if polarity_off:
        angle += 180
    # Map the angle from [0, 539] into [-179, 180]
    if angle > 180:
        angle -= 360

    return GuessResult(
        scan_rotation=int(angle),
        flip_y=flip_y,
        cy=np.mean(y_centers[roi]),
        cx=np.mean(x_centers[roi])
    )
class COMResultSet(AnalysisResultSet):
    """
    Running a :class:`COMAnalysis` via :meth:`libertem.api.Context.run` on a dataset
    returns an instance of this class.

    This analysis is usually applied to datasets with real values. If the dataset contains
    complex numbers, this result contains the keys :attr:`x_real`, :attr:`y_real`,
    :attr:`x_imag`, :attr:`y_imag` instead of the vector field.

    By default, the shift is given in pixel coordinates, i.e. positive x shift goes to the right
    and positive y shift goes to the bottom. See also :ref:`concepts`.

    .. versionchanged:: 0.6.0
        The COM analysis now supports flipping the y axis and rotating the vectors.

    .. versionadded:: 0.3.0

    Attributes
    ----------
    field : libertem.analysis.base.AnalysisResult
        Center of mass shift relative to the center given to the analysis within the given radius
        as a vector field with components (x, y). The visualized result uses a
        cubehelix color wheel.
    magnitude : libertem.analysis.base.AnalysisResult
        Magnitude of the center of mass shift.
    divergence : libertem.analysis.base.AnalysisResult
        Divergence of the center of mass vector field at a given point.
        Only present if every navigation dimension is larger than one.
    curl : libertem.analysis.base.AnalysisResult
        Curl of the center of mass 2D vector field at a given point.
        Only present if every navigation dimension is larger than one.

        .. versionadded:: 0.6.0
    x : libertem.analysis.base.AnalysisResult
        X component of the center of mass shift
    y : libertem.analysis.base.AnalysisResult
        Y component of the center of mass shift
    x_real : libertem.analysis.base.AnalysisResult
        Real part of the x component of the center of mass shift (complex dataset only)
    y_real : libertem.analysis.base.AnalysisResult
        Real part of the y component of the center of mass shift (complex dataset only)
    x_imag : libertem.analysis.base.AnalysisResult
        Imaginary part of the x component of the center of mass shift (complex dataset only)
    y_imag : libertem.analysis.base.AnalysisResult
        Imaginary part of the y component of the center of mass shift (complex dataset only)
    """
class ParameterGuessProc:
    """
    RPC procedure that guesses CoM correction parameters for the
    ``CENTER_OF_MASS`` analysis of the current compound analysis.

    The guess is returned as absolute parameter values (``cx``, ``cy``,
    ``scan_rotation``, ``flip_y``), not relative to the current settings.
    """

    @staticmethod
    def _compose_rotation(old_scan_rotation, old_flip_y, guess_scan_rotation, guess_flip_y):
        """
        Combine an already applied correction with a guess made on corrected data.

        :func:`apply_correction` applies ``R(rotation) @ F(flip)``. The guess was
        found on data already corrected with the old parameters, so the
        combined transform is ``R(guess) @ F(guess_flip) @ R(old) @ F(old_flip)``.
        A mirror reverses the sense of rotation (``F @ R(a) == R(-a) @ F``), so
        when the guess contains a flip, the old rotation enters with a negative sign.

        Returns
        -------
        (scan_rotation, flip_y)
            Absolute parameters equivalent to applying both corrections.
        """
        flip_y = bool(old_flip_y) != bool(guess_flip_y)
        if bool(guess_flip_y):
            scan_rotation = guess_scan_rotation - old_scan_rotation
        else:
            scan_rotation = guess_scan_rotation + old_scan_rotation
        return scan_rotation, flip_y

    async def __call__(self, rpc_context: "RPCContext") -> Dict:
        comp_ana = rpc_context.get_compound_analysis()
        analyses = comp_ana["details"]["analyses"]
        analysis_details = [
            rpc_context.get_analysis_details(a)
            for a in analyses
        ]
        com_analysis = next(
            (
                a for a in analysis_details
                if a["details"]["analysisType"] == "CENTER_OF_MASS"
            ),
            None,
        )
        if com_analysis is None:
            return {
                "status": "error",
                "message": "no CoM analysis found",
            }

        com_analysis_id = com_analysis["analysis"]

        if not rpc_context.have_analysis_results(com_analysis_id):
            # run with the current analysis parameters as set in the GUI:
            await rpc_context.run_analysis(com_analysis_id)

        result_info = rpc_context.get_analysis_results(com_analysis_id)
        res = result_info.results
        old_params = result_info.details["parameters"]

        # The y/x results already include the correction given by
        # `old_params`, so the guess is relative to those parameters.
        guess = await rpc_context.run_sync(guess_corrections, res.y.raw_data, res.x.raw_data)

        # NOTE: convert guess results to absolute values to make sure we don't
        # run into any nasty synchronization issues, for example, if state goes
        # stale after the guess button was clicked.
        scan_rotation, flip_y = self._compose_rotation(
            old_scan_rotation=old_params["scan_rotation"],
            old_flip_y=old_params["flip_y"],
            guess_scan_rotation=guess.scan_rotation,
            guess_flip_y=guess.flip_y,
        )

        # The center guess is a mean shift in corrected coordinates; undo the
        # old correction to express it in the uncorrected coordinate system
        # before adding it to the old reference point.
        backtransformed = apply_correction(
            y_centers=np.array((guess.cy, )),
            x_centers=np.array((guess.cx, )),
            scan_rotation=old_params["scan_rotation"],
            flip_y=old_params["flip_y"],
            forward=False,
        )

        return {
            'status': 'ok',
            'guess': {
                'cx': backtransformed[1][0] + old_params["cx"],
                'cy': backtransformed[0][0] + old_params["cy"],
                'scan_rotation': scan_rotation,
                'flip_y': flip_y,
            },
        }
class COMAnalysis(BaseMasksAnalysis, id_="CENTER_OF_MASS"):
    """
    Center of mass analysis based on three masks: the selection mask and the
    selection mask multiplied with the y and x gradient masks.
    """
    TYPE = 'UDF'

    def get_udf_results(self, udf_results, roi, damage):
        """
        Split the ``intensity`` UDF result into the three mask channels
        (last axis: sum, y, x) and build the result set from them.
        """
        stack = udf_results['intensity'].data
        return self.get_generic_results(
            stack[..., 0], stack[..., 1], stack[..., 2], damage=damage
        )

    def get_generic_results(self, img_sum, img_y, img_x, damage):
        """
        Build a :class:`COMResultSet` from the per-scan-position mask results.

        Parameters
        ----------
        img_sum, img_y, img_x : numpy.ndarray
            Results of the selection mask and of the y/x gradient masks
        damage : numpy.ndarray
            Boolean selection of scan positions used to scale the
            visualizations (combined with a finiteness check for real data)
        """
        from libertem.viz import CMAP_CIRCULAR_DEFAULT, visualize_simple

        params = self.parameters
        y_raw, x_raw = center_shifts(
            img_sum, img_y, img_x, params["cy"], params["cx"]
        )
        nav_shape = y_raw.shape
        y_centers, x_centers = apply_correction(
            y_raw, x_raw,
            scan_rotation=params["scan_rotation"],
            flip_y=params["flip_y"],
        )

        if img_sum.dtype.kind == 'c':
            # Complex data: no vector field, real and imaginary parts separately
            components = (
                ("x_real", "x [real]", "x component of the center", np.real(x_centers)),
                ("y_real", "y [real]", "y component of the center", np.real(y_centers)),
                ("x_imag", "x [imag]", "x component of the center", np.imag(x_centers)),
                ("y_imag", "y [imag]", "y component of the center", np.imag(y_centers)),
            )
            return COMResultSet([
                AnalysisResult(
                    raw_data=part,
                    visualized=visualize_simple(part, damage=damage),
                    key=key, title=title, desc=desc,
                )
                for key, title, desc, part in components
            ])

        damage = damage & np.isfinite(x_centers) & np.isfinite(y_centers)
        # np.max() raises on an empty selection ("zero-size array to reduction
        # operation maximum which has no identity"), so an all-False `damage`
        # falls back to a fixed scale of 1.
        if np.count_nonzero(damage):
            vmax = np.sqrt(np.max(x_centers[damage]**2 + y_centers[damage]**2))
        else:
            vmax = 1
        field_rgb = CMAP_CIRCULAR_DEFAULT.rgb_from_vector(
            (x_centers, y_centers, 0), vmax=vmax
        )
        mag = magnitude(y_centers, x_centers)

        def simple(key, title, desc, data):
            return AnalysisResult(
                raw_data=data,
                visualized=visualize_simple(data, damage=damage),
                key=key, title=title, desc=desc,
            )

        # Results that are valid for any nav_shape
        results = [
            AnalysisResult(
                raw_data=(x_centers, y_centers),
                visualized=field_rgb,
                key="field", title="field", desc="cubehelix colorwheel visualization",
                include_in_download=False
            ),
            simple("magnitude", "magnitude", "magnitude of the vector field", mag),
            simple("x", "x", "x component of the center", x_centers),
            simple("y", "y", "y component of the center", y_centers),
        ]

        # divergence and curl rely on np.gradient, which needs more than one
        # sample along every navigation axis
        if all(n > 1 for n in nav_shape):
            div = divergence(y_centers, x_centers)
            curl = curl_2d(y_centers, x_centers)
            # Inserted at position 2 for backwards compatibility/tests; this
            # could later become a plain extend().
            results[2:2] = [
                simple("divergence", "divergence", "divergence of the vector field", div),
                simple("curl", "curl", "curl of the 2D vector field", curl),
            ]
        return COMResultSet(results)

    def get_mask_factories(self):
        """
        Return the three mask factories for the CoM UDF: a ring mask if ``ri``
        is truthy, otherwise a disk of radius ``r``.

        Raises
        ------
        ValueError
            If the signal is not two-dimensional.
        """
        sig = self.dataset.shape.sig
        if sig.dims != 2:
            raise ValueError("can only handle 2D signals currently")

        if not self.parameters.get('ri'):
            # CoM with radius cut-off:
            return com_masks_factory(
                detector_y=sig[0],
                detector_x=sig[1],
                cx=self.parameters['cx'],
                cy=self.parameters['cy'],
                r=self.parameters['r'],
            )

        def ring_mask():
            return masks.ring(
                imageSizeY=self.dataset.shape.sig[0],
                imageSizeX=self.dataset.shape.sig[1],
                centerY=self.parameters['cy'],
                centerX=self.parameters['cx'],
                radius=self.parameters['r'],
                radius_inner=self.parameters['ri'],
            )

        # annular CoM:
        return com_masks_generic(
            detector_y=sig[0],
            detector_x=sig[1],
            base_mask_factory=ring_mask,
        )

    def get_parameters(self, parameters: Dict) -> Dict:
        """
        Complete ``parameters`` with defaults.

        Defaults: detector center for ``cx``/``cy``, infinite ``r``, ``ri=0.0``,
        no rotation or flip, and ``use_sparse=False``. The result always
        describes three float32 masks.
        """
        detector_y, detector_x = self.dataset.shape.sig
        defaults = {
            'cx': detector_x / 2,
            'cy': detector_y / 2,
            'r': float('inf'),
            'ri': 0.0,
            'scan_rotation': 0.,
            'flip_y': False,
            'use_sparse': False,
        }
        completed = {
            key: parameters.get(key, default)
            for key, default in defaults.items()
        }
        completed['mask_count'] = 3
        completed['mask_dtype'] = np.float32
        return completed

    @classmethod
    def get_template_helper(cls) -> Type[GeneratorHelper]:
        """Return the code generation helper for this analysis."""
        return ComTemplate

    @classmethod
    def get_rpc_definitions(cls) -> Dict[str, Type[ProcedureProtocol]]:
        """Return the RPC procedures offered by this analysis."""
        return {
            "guess_parameters": ParameterGuessProc,
        }

    def need_rerun(self, old_params: Dict, new_params: Dict) -> bool:
        """
        Don't need to re-run UDF if only `flip_y` or `scan_rotation`
        have changed.
        """
        ignore_keys = {"flip_y", "scan_rotation"}

        def relevant(params):
            return {k: v for k, v in params.items() if k not in ignore_keys}

        return relevant(old_params) != relevant(new_params)