import typing
import logging
import time
import numpy as np
# Makes colormaps available in matplotlib
import colorcet # noqa:F401
from matplotlib import colors, cm, colormaps
from skimage.color import rgb2luv, luv2rgb
from libertem.udf.base import UDFRunner
# Import here for backwards compatibility, refs #1031
from libertem.common.viz import encode_image # NOQA: F401
if typing.TYPE_CHECKING:
from libertem.common.buffers import BufferWrapper
logger = logging.getLogger(__name__)
def _get_norm(result, norm_cls=colors.Normalize, vmin=None, vmax=None, damage=None):
if (vmin is not None) and (vmax is not None):
return norm_cls(vmin=vmin, vmax=vmax)
result = result.astype(np.float32)
if damage is None:
damage = (result != 0)
damage = damage & np.isfinite(result)
if damage.sum() == 0:
return norm_cls(vmin=1, vmax=1) # all-NaN or all-zero
if vmin is None:
vmin = np.min(result[damage])
if vmax is None:
vmax = np.max(result[damage])
return norm_cls(vmin=vmin, vmax=vmax)
[docs]
def visualize_simple(result, colormap=None, logarithmic=False, vmin=None, vmax=None, damage=None):
"""
Normalize and visualize ``result`` with ``colormap`` and return the
resulting RGBA data as an array.
Parameters
----------
result : numpy.ndarray
2d array of intensity values
colormap : matplotlib colormap or None
colormap used for visualizing intensity values, defaults to matplotlib.cm.gist_earth
Returns
-------
np.array
A numpy array of shape (Y, X, 4) containing RGBA data, suitable for
passing to `Image.fromarray` in PIL.
"""
if logarithmic:
# Convert to the smallest dtype that supports subtractions
dtype = np.result_type(result, np.int8)
result = result.astype(dtype)
cnorm = colors.LogNorm
result = result - np.min(result) + 1
else:
cnorm = colors.Normalize
if colormap is None:
colormap = cm.gist_earth
norm = _get_norm(result, norm_cls=cnorm, vmin=vmin, vmax=vmax, damage=damage)
shape = result.shape
normalized = norm(result.reshape((-1,))).reshape(shape)
colored = colormap(normalized, bytes=True)
return colored
cet_cyclic_isoluminant = colormaps.get_cmap('cet_cyclic_isoluminant')
def rgb_from_2dvector(y, x, vmax=None):
y = np.asarray(y)
x = np.asarray(x)
absval = np.sqrt(np.abs(y)**2 + np.abs(x)**2)
if vmax is None:
vmax = np.max(absval[np.isfinite(absval)])
if vmax == 0:
vmax = 1
scaled_absval = absval / vmax
scaled_absval = np.minimum(scaled_absval, 1)
# Transform coordinates to match previous circular color map
angle = np.arctan2(-y, -x)
# Get color from isoluminant cyclic map of the angle
color = np.asarray(visualize_simple(
angle, colormap=cet_cyclic_isoluminant, vmin=-np.pi, vmax=np.pi
))
# Transform to LUV space for linear axes
luv = rgb2luv(color[..., :3])
# Tweak luminance to match previous map
luv[..., 0] *= 0.8
# Adjust color intensity to roughly match previous implementation
color_scale = scaled_absval * 2.2
# Scale color
luv[..., 1] *= color_scale
luv[..., 2] *= color_scale
rgb = (luv2rgb(luv) * 255).astype(np.uint8)
return rgb
def make_colormap():
angles = np.arange(360)
angles_rad = angles/180*np.pi
yy, xx = -np.sin(angles_rad), -np.cos(angles_rad)
cols = rgb_from_2dvector(y=yy, x=xx) / 255
cmap = colors.LinearSegmentedColormap.from_list('libertem_cyclic', cols)
return cmap
libertem_cyclic = make_colormap()
def get_plottable_2D_channels(buffers):
return [
k
for k, buf in buffers.items()
# 2D data, removing axes of size 1
if len(buf.data.squeeze().shape) == 2
]
[docs]
class Live2DPlot:
"""
Base plotting class for interactive use. Please see the subclasses for
concrete details.
.. versionadded:: 0.7.0
Parameters
----------
dataset : DataSet
The dataset on which the UDF will be run. This allows
to determine the shape of the plots for initialization.
udf : UDF
The UDF instance this plot is associated to. This needs to be
the same instance that is passed to
:meth:`~libertem.api.Context.run_udf`.
roi : numpy.ndarray or None
Region of interest (ROI) that the UDF will
be run on. This is necessary for UDFs where the `extra_shape`
parameter of result buffers is a function of the ROI, such as
:class:`~libertem.udf.raw.PickUDF`.
channel : misc
Indicate the channel to be plotted.
- :code:`None`: The first plottable (2D) channel of the UDF is plotted.
- :code:`str`: The UDF result buffer name that should be plotted.
- :code:`tuple(str, function(ndarray) -> ndarray)`: The UDF result buffer name that
should be plotted together with a function that extracts a plottable result
- :code:`function(udf_result, damage) -> (ndarray, damage)`: Function that derives a
plottable 2D ndarray and damage indicator from the full
UDF results and the processed nav space. See :ref:`plotting` for more details!
title : str
The plot title. By default UDF class name and channel name.
min_delta : float
Minimum time span in seconds between updates to reduce overheads for slow plotting.
udfresult : UDFResults, optional
UDF result to initialize the plot data and determine plot shape. If None (default),
this is determined using :meth:`~libertem.udf.base.UDFRunner.dry_run` on the dataset,
UDF and ROI. This parameter allows re-using buffers to avoid unnecessary dry runs.
"""
def __init__(
self, dataset, udf, roi=None, channel=None, title=None, min_delta=0, udfresult=None
):
if udfresult is None:
udfresult = UDFRunner.dry_run([udf], dataset, roi)
eligible_channels = get_plottable_2D_channels(udfresult.buffers[0])
if channel is None:
if not eligible_channels:
raise ValueError(f"No plottable 2D channel found for {udf.__class__.__name__}")
channel = eligible_channels[0]
channel_title = channel
if callable(channel):
extract = channel
channel_title = channel.__name__
channel = None
elif isinstance(channel, (tuple, list)):
channel, func = channel
if channel not in udfresult.buffers[0]:
raise ValueError(
f"channel {channel} not found, have: {udfresult.buffers[0].keys()}"
)
kind = udfresult.buffers[0][channel].kind
if kind == 'nav':
def extract(udf_results, damage):
return (func(udf_results[channel].data), damage)
else:
def extract(udf_results, damage):
return (func(udf_results[channel].data), True)
channel_title = f"{func.__name__}({channel})"
else:
extract = None
if channel not in eligible_channels:
raise ValueError("channel {} not found or not plottable, have: {!r}".format(
channel, eligible_channels
))
channel_title = channel
self._extract = extract
self.channel = channel
if title is None:
title = f"{udf.__class__.__name__}: {channel_title}"
self.title = title
self.data, _ = self.extract(udfresult.buffers[0], udfresult.damage)
self.udf = udf
self.last_update = 0
self.min_delta = min_delta
[docs]
def get_udf(self):
"""
Returns the associated UDF instance
"""
return self.udf
[docs]
def new_data(self, udf_results: dict[str, "BufferWrapper"], damage, force: bool = False):
"""
This method is called with the raw `udf_results` any time a new
partition has finished processing.
The :code:`damage` parameter is filtered to only cover finite
values of :code:`self.data` and passed to :meth:`self.update`,
which should then be implemented by a subclass.
"""
t0 = time.time()
delta = t0 - self.last_update
if (not force) and delta < self.min_delta:
return # don't update plot if we recently updated
(self.data, damage) = self.extract(udf_results, damage)
finite = np.isfinite(self.data)
# duck typing "array-like" since damage can also be bool
if hasattr(damage, "shape"):
start_axis = len(damage.shape)
stop_axis = len(self.data.shape)
# Reduce surplus axes, for example RGB plotting #1052
if stop_axis > start_axis:
finite = np.all(finite, axis=tuple(range(start_axis, stop_axis)))
damage = damage & finite
self.update(damage, force=force)
self.last_update = time.time()
logger.debug("%s updated in %.3f seconds", self.__class__.__name__, self.last_update - t0)
[docs]
def update(self, damage, force=False):
"""
Update the plot based on `self.data`.
Parameters
----------
damage : numpy.ndarray Boolean array with the shape of
:code:`self.data`. It is :code:`True` for all positions in
:code:`self.data` that contain finite values and have potentially
been touched by the UDF. This can be used to extract the correct
plot range by ignoring invalid buffer portions.
force : bool Force an update, disabling any throttling mechanisms
"""
raise NotImplementedError()
[docs]
def display(self):
"""
Show the plot, for example in the current Jupyter cell.
"""
raise NotImplementedError()
[docs]
class Dummy2DPlot(Live2DPlot):
'''
No-op plot. This is useful for test and example code to not
attempt displaying :code:`matplotlib` plots in a headless environment
or during batch operation.
'''
[docs]
def update(self, damage, force=False):
pass
[docs]
def display(self):
pass