# Source code for libertem.io.dataset.ser

import os
import logging
from typing import Optional
import warnings
import contextlib

import numpy as np
from ncempy.io.ser import fileSER
from sparseconverter import CUDA, NUMPY, ArrayBackend

from libertem.common.math import prod, flat_nonzero
from libertem.common import Shape, Slice
from libertem.io.dataset.base.tiling_scheme import TilingScheme
from libertem.common.messageconverter import MessageConverter
from .base import (
    DataSet, FileSet, BasePartition, DataSetException, DataSetMeta,
    DataTile,
)

log = logging.getLogger(__name__)


class SERDatasetParams(MessageConverter):
    """
    JSON schema and converter for the parameters accepted by the SER dataset.

    Only ``type`` and ``path`` are required; ``nav_shape``, ``sig_shape`` and
    ``sync_offset`` are optional and only copied over when present.
    """
    # NOTE(review): nav_shape is restricted to exactly two items here, although
    # the SERDataSet docstring mentions other lengths -- confirm this is intended.
    SCHEMA = {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "$id": "http://libertem.org/SERDatasetParams.schema.json",
        "title": "SERDatasetParams",
        "type": "object",
        "properties": {
            "type": {"const": "SER"},
            "path": {"type": "string"},
            "nav_shape": {
                "type": "array",
                "items": {"type": "number", "minimum": 1},
                "minItems": 2,
                "maxItems": 2
            },
            "sig_shape": {
                "type": "array",
                "items": {"type": "number", "minimum": 1},
                "minItems": 2,
                "maxItems": 2
            },
            "sync_offset": {"type": "number"},
        },
        "required": ["type", "path"]
    }

    def convert_to_python(self, raw_data):
        """
        Turn validated raw parameters into keyword arguments for the dataset.

        Shapes are converted to tuples; ``path`` must be present (a missing
        ``path`` raises ``KeyError``). The ``type`` key is not forwarded.
        """
        converted = {"path": raw_data["path"]}
        # both shape parameters get the same list -> tuple treatment
        for shape_key in ("nav_shape", "sig_shape"):
            if shape_key in raw_data:
                converted[shape_key] = tuple(raw_data[shape_key])
        if "sync_offset" in raw_data:
            converted["sync_offset"] = raw_data["sync_offset"]
        return converted


class SERFile:
    """
    Minimal descriptor of a single SER file.

    Holds the path and the frame count; the file itself is only opened on
    demand through :meth:`get_handle`.
    """

    def __init__(self, path, num_frames):
        self._path = path
        self._num_frames = num_frames

    def _get_handle(self):
        # a fresh ncempy reader object on every call
        return fileSER(self._path)

    @contextlib.contextmanager
    def get_handle(self):
        """Yield the entered ``fileSER`` reader; its context is exited afterwards."""
        reader = self._get_handle()
        with reader as handle:
            yield handle

    @property
    def num_frames(self):
        """Frame count exactly as passed to the constructor (may be ``None``)."""
        return self._num_frames

    @property
    def start_idx(self):
        """Index of the first frame of this file; always ``0``."""
        return 0

    @property
    def end_idx(self):
        """Exclusive end index, equal to :attr:`num_frames`."""
        return self.num_frames


class SERFileSet(FileSet):
    """File set for SER datasets; adds no behaviour on top of ``FileSet``."""


class SERDataSet(DataSet):
    """
    Read TIA SER files.

    Examples
    --------

    >>> ds = ctx.load("ser", path="/path/to/file.ser")  # doctest: +SKIP

    Parameters
    ----------

    path: str
        Path to the .ser file

    emipath: str, optional
        Deprecated and ignored; passing it only emits a DeprecationWarning,
        as it was removed from ncempy.

    nav_shape: tuple of int, optional
        A n-tuple that specifies the size of the navigation region ((y, x), but
        can also be of length 1 for example for a line scan, or length 3 for
        a data cube, for example)

    sig_shape: tuple of int, optional
        Signal/detector size (height, width)

    sync_offset: int, optional
        If positive, number of frames to skip from start
        If negative, number of blank frames to insert at start

    io_backend: optional
        Must be ``None``; any other value raises ``ValueError``.
    """

    def __init__(self, path, emipath=None, nav_shape=None, sig_shape=None,
                 sync_offset=0, io_backend=None):
        super().__init__(io_backend=io_backend)
        if io_backend is not None:
            raise ValueError("SERDataSet currently doesn't support alternative I/O backends")
        self._path = path
        self._meta = None
        self._filesize = None
        self._num_frames = None
        if emipath is not None:
            # stacklevel=2 attributes the warning to the caller instead of this module
            warnings.warn(
                "emipath is not used anymore, as it was removed from ncempy",
                DeprecationWarning,
                stacklevel=2,
            )
        self._nav_shape = tuple(nav_shape) if nav_shape else nav_shape
        self._sig_shape = tuple(sig_shape) if sig_shape else sig_shape
        self._sync_offset = sync_offset

    def _do_initialize(self):
        """
        Read the SER header and first frame to build the dataset metadata.

        nav_shape defaults to the header's ``Dimensions`` (in reversed order).
        sig_shape defaults to the first frame's shape. A user-supplied
        sig_shape must have the same number of elements as a frame.

        Raises
        ------
        DataSetException
            If the file contains no valid elements, or sig_shape has the
            wrong size.
        """
        self._filesize = os.stat(self._path).st_size
        reader = SERFile(path=self._path, num_frames=None)

        with reader.get_handle() as f1:
            num_elements = f1.head['ValidNumberElements']
            self._num_frames = num_elements
            if num_elements == 0:
                raise DataSetException("no data found in file")

            data, meta_data = f1.getDataset(0)
            # NOTE(review): relies on a private ncempy mapping from DataType to numpy dtype
            dtype = f1._dictDataType[meta_data['DataType']]
            nav_dims = tuple(
                reversed([
                    int(dim['DimensionSize'])
                    for dim in f1.head['Dimensions']
                ])
            )
            self._image_count = int(self._num_frames)
            if self._nav_shape is None:
                self._nav_shape = nav_dims
            if self._sig_shape is None:
                self._sig_shape = tuple(data.shape)
            elif int(prod(self._sig_shape)) != int(prod(data.shape)):
                raise DataSetException(
                    "sig_shape must be of size: %s" % int(prod(data.shape))
                )
            self._nav_shape_product = int(prod(self._nav_shape))
            self._sync_offset_info = self.get_sync_offset_info()
            self._shape = Shape(self._nav_shape + self._sig_shape, sig_dims=len(self._sig_shape))
            self._meta = DataSetMeta(
                shape=self._shape,
                raw_dtype=dtype,
                sync_offset=self._sync_offset,
                image_count=self._image_count,
            )
        return self

    def initialize(self, executor):
        """Run :meth:`_do_initialize` through ``executor.run_function``."""
        return executor.run_function(self._do_initialize)

    @classmethod
    def get_msg_converter(cls):
        return SERDatasetParams

    @classmethod
    def get_supported_extensions(cls):
        return {"ser"}

    @classmethod
    def get_supported_io_backends(cls):
        # no alternative I/O backends: data is always read through ncempy
        return []

    @classmethod
    def detect_params(cls, path, executor):
        """
        Return loading parameters for ``path`` if it ends in ``.ser``
        (case-insensitive), otherwise ``False``. Initializes the dataset
        to determine its shape.
        """
        if path.lower().endswith(".ser"):
            ds = cls(path)
            ds = ds.initialize(executor)
            return {
                "parameters": {
                    "path": path,
                    "nav_shape": tuple(ds.shape.nav),
                    "sig_shape": tuple(ds.shape.sig),
                },
                "info": {
                    "image_count": int(prod(ds.shape.nav)),
                    "native_sig_shape": tuple(ds.shape.sig),
                }
            }
        return False

    @property
    def dtype(self):
        return self._meta.raw_dtype

    @property
    def shape(self):
        return self._meta.shape

    def check_valid(self):
        """
        Open the file and sanity-check its header.

        Raises DataSetException if there are no valid elements or the
        DataTypeID is not 0x4120/0x4122. An OSError is wrapped in a
        DataSetException. Returns True otherwise.
        """
        try:
            with fileSER(self._path) as f1:
                if f1.head['ValidNumberElements'] == 0:
                    raise DataSetException("no data found in file")
                # NOTE(review): presumably the 1D-array / 2D-image element types
                # of the SER format -- confirm against the format spec
                if f1.head['DataTypeID'] not in (0x4120, 0x4122):
                    raise DataSetException("unknown datatype id: %s" % f1.head['DataTypeID'])
            return True
        except OSError as e:
            raise DataSetException("invalid dataset: %s" % e) from e

    def get_cache_key(self):
        return {
            "path": self._path,
            "shape": tuple(self.shape),
            "sync_offset": self._sync_offset,
        }

    def _get_fileset(self):
        # requires a prior initialize(), which sets the frame count
        assert self._num_frames is not None
        return SERFileSet([
            SERFile(
                path=self._path,
                num_frames=self._num_frames,
            )
        ])

    def get_base_shape(self, roi):
        # frames are read one at a time, so the smallest unit is a single frame
        return (1,) + tuple(self.shape.sig)

    def adjust_tileshape(self, tileshape, roi):
        # force single-frame tiles
        return (1,) + tileshape[1:]

    def get_partitions(self):
        fileset = self._get_fileset()
        for part_slice, start, stop in self.get_slices():
            yield SERPartition(
                path=self._path,
                meta=self._meta,
                partition_slice=part_slice,
                fileset=fileset,
                start_frame=start,
                num_frames=stop - start,
                io_backend=self.get_io_backend(),
                decoder=None,
            )

    def __repr__(self):
        return f"<SERDataSet for {self._path}>"
class SERPartition(BasePartition):
    """
    Partition of a SER dataset that reads one frame per tile via ncempy.

    Parameters
    ----------
    path : str
        Path to the .ser file; it is opened with ``fileSER`` in :meth:`get_tiles`.
    *args, **kwargs
        Forwarded unchanged to ``BasePartition``.
    """

    def __init__(self, path, *args, **kwargs):
        self._path = path
        super().__init__(*args, **kwargs)

    def validate_tiling_scheme(self, tiling_scheme):
        """Raise ValueError unless the scheme's sig shape equals the partition's."""
        if tiling_scheme.shape.sig != self.shape.sig:
            raise ValueError(
                f"invalid tiling scheme ({tiling_scheme.shape!r}): sig shape must match"
            )

    def _preprocess(self, tile_data, tile_slice):
        """Apply configured corrections, if any, to ``tile_data``."""
        if self._corrections is None:
            return
        # the return value is ignored, so apply() is expected to work in place
        self._corrections.apply(tile_data, tile_slice)

    def get_tiles(self, tiling_scheme: TilingScheme, dest_dtype="float32", roi=None,
                  array_backend: Optional[ArrayBackend] = None):
        """
        Yield one single-frame ``DataTile`` per frame read for this partition.

        Parameters
        ----------
        tiling_scheme : TilingScheme
            Adjusted for this partition; its sig shape must match the partition's.
        dest_dtype : dtype-like
            Frames are converted to this dtype if they differ.
        roi : array-like, optional
            Navigation mask. When given, only selected frames are read and tile
            origins are counted in the compressed (roi-selected) index space.
        array_backend : optional
            Defaults to the first entry of ``meta.array_backends``; must be
            NUMPY or CUDA. The yielded data are the arrays returned by ncempy.
        """
        if array_backend is None:
            array_backend = self.meta.array_backends[0]
        assert array_backend in (NUMPY, CUDA)
        sync_offset = self.meta.sync_offset

        sig_dims = self.shape.sig.dims
        shape = Shape((1,) + tuple(self.shape.sig), sig_dims=sig_dims)
        # loop invariants: every tile starts at the signal origin, and the
        # target dtype is the same for every frame
        sig_origin = (0,) * sig_dims
        target_dtype = np.dtype(dest_dtype)

        tiling_scheme = tiling_scheme.adjust_for_partition(self)
        self.validate_tiling_scheme(tiling_scheme)

        # `start` is used directly as a frame index into the file (see the
        # `indices` below), so it is shifted by sync_offset relative to the
        # navigation position.
        start = self._start_frame
        if start >= self.meta.image_count:
            # the partition lies entirely past the frames stored in the file
            return
        stop = min(start + self._num_frames, self.meta.image_count)

        if roi is None:
            indices = np.arange(max(0, start), stop)
            # `offset` is the nav position of the first frame read.
            # in case of a negative sync_offset, 'start' can be negative
            if start < 0:
                offset = abs(sync_offset)
            else:
                offset = start - sync_offset
        else:
            # The following is taken (effectively) from _default_get_read_ranges
            roi_nonzero = flat_nonzero(roi)
            shifted_roi = roi_nonzero + sync_offset
            roi_mask = np.logical_and(shifted_roi >= max(0, start), shifted_roi < stop)
            indices = shifted_roi[roi_mask]

            # `offset` counts the roi-selected nav positions before the first frame read.
            # in case of a negative sync_offset, 'start' can be negative
            if start < 0:
                offset = np.sum(roi_nonzero < abs(sync_offset))
            else:
                offset = np.sum(roi_nonzero < start - sync_offset)

        with fileSER(self._path) as f:
            for num, idx in enumerate(indices):
                tile_slice = Slice(origin=(num + offset,) + sig_origin, shape=shape)
                data, _ = f.getDataset(int(idx))
                if data.dtype != target_dtype:
                    data = data.astype(target_dtype)
                data = data.reshape(shape)
                self._preprocess(data, tile_slice)
                yield DataTile(
                    data,
                    tile_slice=tile_slice,
                    scheme_idx=0,
                )