LiberTEM UDFs

[1]:
%matplotlib nbagg
[2]:
import os
import matplotlib.pyplot as plt
import libertem.api as lt
import numpy as np
[3]:
ctx = lt.Context()

Specifying the dataset

Most formats can be loaded using the "auto" type, but some may need additional parameters.

See the loading data section of the LiberTEM docs for details.

[4]:
data_base_path = os.environ.get("TESTDATA_BASE_PATH", "/home/alex/Data/")
[5]:
ds = ctx.load("auto", path=os.path.join(data_base_path, "01_ms1_3p3gK.hdr"))

After loading, some information is available in the diagnostics attribute:

[6]:
ds.diagnostics
[6]:
[{'name': 'Data type', 'value': 'u16'},
 {'name': 'Partition shape', 'value': '(2075, 256, 256)'},
 {'name': 'Number of partitions', 'value': '33'},
 {'name': 'Number of frames skipped at the beginning', 'value': 0},
 {'name': 'Number of frames ignored at the end', 'value': 0},
 {'name': 'Number of blank frames inserted at the beginning', 'value': 0},
 {'name': 'Number of blank frames inserted at the end', 'value': 0}]

Standard analyses: virtual detector

A standard analysis to run on 4D STEM data is to apply a virtual detector. Here, we define a ring detector, with the inner radius ri and the outer radius ro given in pixels:

[7]:
ring = ctx.create_ring_analysis(dataset=ds, ri=60, ro=70)
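
To make the idea of a ring detector concrete, here is a minimal NumPy-only sketch of what such a virtual detector computes conceptually: a boolean ring mask is applied to every frame and the selected pixels are summed, giving one intensity value per scan position. This is not LiberTEM's implementation. The helper `ring_mask`, the centre coordinates `cy` and `cx`, the small array sizes and the boundary convention (inner radius inclusive, outer radius exclusive) are all assumptions made for illustration.

```python
import numpy as np


def ring_mask(shape, cy, cx, ri, ro):
    """Boolean mask that is True where ri <= distance to (cy, cx) < ro.

    Hypothetical helper; the boundary convention is an assumption.
    """
    y, x = np.ogrid[:shape[0], :shape[1]]
    r = np.hypot(y - cy, x - cx)
    return (r >= ri) & (r < ro)


# Synthetic 4D data: scan (navigation) shape (2, 3), frame shape (16, 16)
frames = np.ones((2, 3, 16, 16), dtype=np.float32)

mask = ring_mask((16, 16), cy=8, cx=8, ri=3, ro=5)

# Sum the masked pixels of every frame: one value per scan position
intensity = (frames * mask).sum(axis=(-2, -1))
```

Because every pixel of the synthetic frames is 1, each entry of `intensity` equals the number of pixels inside the ring.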

The analysis can be run with the Context.run method:

[8]:
ring_res = ctx.run(ring, progress=True)
ring_res
100%|██████████| 33/33 [00:01<00:00, 32.00it/s]
[8]:
[<AnalysisResult: intensity>, <AnalysisResult: intensity_log>]

Because the analysis mirrors what the web GUI does, we access the data through the raw_data attribute; otherwise we would get an already visualized result. Here we do the visualization ourselves using matplotlib:

[9]:
plt.figure()
plt.imshow(ring_res.intensity.raw_data)
[9]:
<matplotlib.image.AxesImage at 0x7fa50ac67730>

Simple UDF definition

User-defined functions (UDFs) let you implement your own data processing functionality. As a very simple example, we define a function that sums up the pixels of each frame:

[10]:
def sum_of_pixels(frame):
    return np.sum(frame)

The easiest way to run this on the data is to use the Context.map function:

[11]:
res_pixelsum_1 = ctx.map(dataset=ds, f=sum_of_pixels, progress=True)
res_pixelsum_1
100%|██████████| 33/33 [00:00<00:00, 80.59it/s]
[11]:
<BufferWrapper kind=nav dtype=uint64 extra_shape=()>

The result is of type BufferWrapper, but it can be used by any function that expects a NumPy array, for example for plotting:

[12]:
plt.figure()
plt.imshow(res_pixelsum_1)
[12]:
<matplotlib.image.AxesImage at 0x7fa50ac1b820>
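
The frame-by-frame mapping described above can be sketched with plain NumPy on a small in-memory array. The sketch only shows the semantics (call the function once per frame, collect one value per scan position); it says nothing about how LiberTEM splits the dataset into partitions or runs them in parallel. The helper `map_frames` and the synthetic data are hypothetical and not part of the LiberTEM API.

```python
import numpy as np


def sum_of_pixels(frame):
    return np.sum(frame)


def map_frames(data, f):
    """Apply f to each frame of a 4D array of shape (ny, nx, sy, sx).

    Hypothetical stand-in that illustrates frame-by-frame mapping.
    """
    nav_shape = data.shape[:2]
    # Flatten the scan dimensions so we can loop over individual frames
    flat = data.reshape((-1,) + data.shape[2:])
    results = np.array([f(frame) for frame in flat])
    # Restore the scan shape: one result per scan position
    return results.reshape(nav_shape)


# Synthetic data: scan shape (2, 3), frame shape (4, 4)
data = np.arange(2 * 3 * 4 * 4, dtype=np.uint16).reshape(2, 3, 4, 4)
res = map_frames(data, sum_of_pixels)
```

The result has the scan shape `(2, 3)`, which matches the `kind=nav` shown for the buffer above.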

The Context.map function is a shortcut for simple frame-by-frame mapping over the data. The longer way of writing this is as follows:

[13]:
from libertem.udf import UDF


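class SumOfPixels(UDF):
    def get_result_buffers(self):
        # Declare one result buffer with a single value per scan position
        return {
            'sum_of_pixels': self.buffer(kind='nav', dtype='float32')
        }

    def process_frame(self, frame):
        # Inside process_frame, the buffer refers to the current frame only
        self.results.sum_of_pixels[:] = np.sum(frame)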

This can now be run using the Context.run_udf method:

[14]:
res_pixelsum_2 = ctx.run_udf(dataset=ds, udf=SumOfPixels(), progress=True)
res_pixelsum_2
100%|██████████| 33/33 [00:00<00:00, 84.38it/s]
[14]:
{'sum_of_pixels': <BufferWrapper kind=nav dtype=float32 extra_shape=()>}

The result is now a dict, which maps buffer names, as defined in get_result_buffers, to the BufferWrapper result, so we can use the following to plot the results:

[15]:
plt.figure()
plt.imshow(res_pixelsum_2['sum_of_pixels'])
[15]:
<matplotlib.image.AxesImage at 0x7fa50ab119a0>
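
Note that the two results differ in dtype: Context.map returned uint64, while the UDF above declares its buffer as float32. The following NumPy-only sketch shows why the choice can matter for sums. float32 cannot represent every integer above 2**24, and the sum of a 256×256 frame of 16-bit values can exceed that limit. The frame is synthetic and its values were chosen to hit the limit exactly; whether the rounding is relevant for real data is an open question here, and a wider dtype such as float64 would be one way to avoid it.

```python
import numpy as np

# Synthetic 256x256 frame of unsigned 16-bit values
frame = np.full((256, 256), 256, dtype=np.uint16)
frame[0, 0] = 257

# Integer sum: 256 * 256 * 256 + 1 = 2**24 + 1
exact = np.sum(frame)

# Storing the sum in a float32 buffer rounds it to the nearest
# representable value, which is 2**24
buf = np.zeros((), dtype=np.float32)
buf[...] = exact

# The same sum stored in a float64 buffer stays exact
buf64 = np.zeros((), dtype=np.float64)
buf64[...] = exact
```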

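extra_shape: more than one result per scan position

A buffer can hold more than one value per scan position. The extra_shape parameter of self.buffer specifies the additional dimensions; here we store four statistics for each frame: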

[16]:
class StatsUDF(UDF):
    def get_result_buffers(self):
        return {
            'all_stats': self.buffer(kind='nav', dtype='float32', extra_shape=(4,)),
        }

    def process_frame(self, frame):
        self.results.all_stats[:] = (np.mean(frame), np.min(frame), np.max(frame), np.std(frame))
[17]:
res_stats = ctx.run_udf(dataset=ds, udf=StatsUDF(), progress=True)
100%|██████████| 33/33 [00:01<00:00, 32.95it/s]

The result now has an extra dimension, as specified by extra_shape above:

[18]:
res_stats['all_stats'].data.shape
[18]:
(186, 357, 4)

Let’s plot the standard deviation of each frame, which is stored at index 3 of the last axis:

[19]:
plt.figure()
plt.imshow(res_stats['all_stats'].data[..., 3])
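
The relationship between the scan shape, extra_shape and the indexing used above can be reproduced with a small NumPy-only sketch. It builds a result array of shape scan shape + (4,) by hand and fills it frame by frame in the same order as StatsUDF (mean, min, max, standard deviation). The synthetic data and the explicit Python loop are assumptions for illustration and do not reflect how LiberTEM executes a UDF.

```python
import numpy as np

# Synthetic data: scan shape (2, 3), frame shape (4, 4)
data = np.arange(2 * 3 * 4 * 4, dtype=np.float32).reshape(2, 3, 4, 4)
nav_shape = data.shape[:2]

# Result array: scan shape plus an extra axis of length 4
all_stats = np.zeros(nav_shape + (4,), dtype=np.float32)

for idx in np.ndindex(*nav_shape):
    frame = data[idx]
    # Same order as in StatsUDF: mean, min, max, std
    all_stats[idx] = (np.mean(frame), np.min(frame),
                      np.max(frame), np.std(frame))

# Select the standard deviation (index 3 of the last axis)
stddev_map = all_stats[..., 3]
```

The ellipsis in `all_stats[..., 3]` keeps all scan dimensions, so `stddev_map` again has the scan shape and can be passed directly to `plt.imshow`.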