import os
import logging
from typing import Optional
import warnings
import contextlib
import numpy as np
from ncempy.io.ser import fileSER
from sparseconverter import CUDA, NUMPY, ArrayBackend
from libertem.common.math import prod, flat_nonzero
from libertem.common import Shape, Slice
from libertem.io.dataset.base.tiling_scheme import TilingScheme
from libertem.common.messageconverter import MessageConverter
from .base import (
DataSet, FileSet, BasePartition, DataSetException, DataSetMeta,
DataTile,
)

log = logging.getLogger(__name__)


class SERDatasetParams(MessageConverter):
SCHEMA = {
"$schema": "http://json-schema.org/draft-07/schema#",
"$id": "http://libertem.org/SERDatasetParams.schema.json",
"title": "SERDatasetParams",
"type": "object",
"properties": {
"type": {"const": "SER"},
"path": {"type": "string"},
"nav_shape": {
"type": "array",
"items": {"type": "number", "minimum": 1},
"minItems": 2,
"maxItems": 2
},
"sig_shape": {
"type": "array",
"items": {"type": "number", "minimum": 1},
"minItems": 2,
"maxItems": 2
},
"sync_offset": {"type": "number"},
},
"required": ["type", "path"]
}
def convert_to_python(self, raw_data):
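        # translate the JSON parameters validated against SCHEMA above into
        # keyword arguments for constructing the data set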
data = {
k: raw_data[k]
for k in ["path"]
}
if "nav_shape" in raw_data:
data["nav_shape"] = tuple(raw_data["nav_shape"])
if "sig_shape" in raw_data:
data["sig_shape"] = tuple(raw_data["sig_shape"])
if "sync_offset" in raw_data:
data["sync_offset"] = raw_data["sync_offset"]
return data


class SERFile:
def __init__(self, path, num_frames):
self._path = path
self._num_frames = num_frames
def _get_handle(self):
return fileSER(self._path)
@contextlib.contextmanager
def get_handle(self):
with self._get_handle() as f:
yield f
@property
def num_frames(self):
return self._num_frames
@property
def start_idx(self):
return 0
@property
def end_idx(self):
return self.num_frames


class SERFileSet(FileSet):
pass


class SERDataSet(DataSet):
"""
    Read TIA SER files.

    Examples
    --------
>>> ds = ctx.load("ser", path="/path/to/file.ser") # doctest: +SKIP
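
    The navigation shape and synchronization offset can also be given
    explicitly; the values below are placeholders, pick ones that match
    your acquisition:

    >>> ds = ctx.load(
    ...     "ser", path="/path/to/file.ser",
    ...     nav_shape=(16, 16), sync_offset=0,
    ... )  # doctest: +SKIP
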
Parameters
----------
path: str
Path to the .ser file
nav_shape: tuple of int, optional
        An n-tuple that specifies the size of the navigation region ((y, x), but
        can also be of length 1 for a line scan, or length 3 for a data cube,
        for example)
sig_shape: tuple of int, optional
Signal/detector size (height, width)
sync_offset: int, optional
        If positive, number of frames to skip from start.
        If negative, number of blank frames to insert at start.
"""
def __init__(self, path, emipath=None, nav_shape=None,
sig_shape=None, sync_offset=0, io_backend=None):
super().__init__(io_backend=io_backend)
if io_backend is not None:
raise ValueError("SERDataSet currently doesn't support alternative I/O backends")
self._path = path
self._meta = None
self._filesize = None
self._num_frames = None
if emipath is not None:
warnings.warn(
"emipath is not used anymore, as it was removed from ncempy", DeprecationWarning
)
self._nav_shape = tuple(nav_shape) if nav_shape else nav_shape
self._sig_shape = tuple(sig_shape) if sig_shape else sig_shape
self._sync_offset = sync_offset
def _do_initialize(self):
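        # open the file once to read the header: number of valid frames, the
        # frame dtype, the scan (navigation) dimensions and the detector shape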
self._filesize = os.stat(self._path).st_size
reader = SERFile(path=self._path, num_frames=None)
with reader.get_handle() as f1:
self._num_frames = f1.head['ValidNumberElements']
if f1.head['ValidNumberElements'] == 0:
raise DataSetException("no data found in file")
data, meta_data = f1.getDataset(0)
dtype = f1._dictDataType[meta_data['DataType']]
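            # the dimension sizes from the SER header are reversed here so that
            # nav_dims ends up in numpy's row-major order, e.g. (y, x) for a 2D scan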
nav_dims = tuple(
reversed([
int(dim['DimensionSize'])
for dim in f1.head['Dimensions']
])
)
self._image_count = int(self._num_frames)
if self._nav_shape is None:
self._nav_shape = nav_dims
if self._sig_shape is None:
self._sig_shape = tuple(data.shape)
elif int(prod(self._sig_shape)) != int(prod(data.shape)):
raise DataSetException(
"sig_shape must be of size: %s" % int(prod(data.shape))
)
self._nav_shape_product = int(prod(self._nav_shape))
self._sync_offset_info = self.get_sync_offset_info()
self._shape = Shape(self._nav_shape + self._sig_shape, sig_dims=len(self._sig_shape))
self._meta = DataSetMeta(
shape=self._shape,
raw_dtype=dtype,
sync_offset=self._sync_offset,
image_count=self._image_count,
)
return self
def initialize(self, executor):
return executor.run_function(self._do_initialize)
@classmethod
def get_msg_converter(cls):
return SERDatasetParams
@classmethod
def get_supported_extensions(cls):
return {"ser"}
@classmethod
    def get_supported_io_backends(cls):
return []
@classmethod
def detect_params(cls, path, executor):
if path.lower().endswith(".ser"):
ds = cls(path)
ds = ds.initialize(executor)
return {
"parameters": {
"path": path,
"nav_shape": tuple(ds.shape.nav),
"sig_shape": tuple(ds.shape.sig),
},
"info": {
"image_count": int(prod(ds.shape.nav)),
"native_sig_shape": tuple(ds.shape.sig),
}
}
return False
@property
def dtype(self):
return self._meta.raw_dtype
@property
def shape(self):
return self._meta.shape
def check_valid(self):
try:
with fileSER(self._path) as f1:
if f1.head['ValidNumberElements'] == 0:
raise DataSetException("no data found in file")
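                # 0x4120 and 0x4122 are the SER series types handled by this
                # reader (1D and 2D data arrays); anything else is rejected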
if f1.head['DataTypeID'] not in (0x4120, 0x4122):
raise DataSetException("unknown datatype id: %s" % f1.head['DataTypeID'])
return True
except OSError as e:
raise DataSetException("invalid dataset: %s" % e) from e
def get_cache_key(self):
return {
"path": self._path,
"shape": tuple(self.shape),
"sync_offset": self._sync_offset,
}
def _get_fileset(self):
assert self._num_frames is not None
return SERFileSet([
SERFile(
path=self._path,
num_frames=self._num_frames,
)
])
def get_base_shape(self, roi):
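        # frames are decoded one at a time from the SER file, so the base tile
        # shape is a single full frame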
return (1,) + tuple(self.shape.sig)
def adjust_tileshape(self, tileshape, roi):
# force single-frame tiles
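        # e.g. a requested tile shape of (8, h, w) becomes (1, h, w)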
return (1,) + tileshape[1:]
def get_partitions(self):
fileset = self._get_fileset()
for part_slice, start, stop in self.get_slices():
yield SERPartition(
path=self._path,
meta=self._meta,
partition_slice=part_slice,
fileset=fileset,
start_frame=start,
num_frames=stop - start,
io_backend=self.get_io_backend(),
decoder=None,
)
def __repr__(self):
return f"<SERDataSet for {self._path}>"


class SERPartition(BasePartition):
def __init__(self, path, *args, **kwargs):
self._path = path
super().__init__(*args, **kwargs)
def validate_tiling_scheme(self, tiling_scheme):
if tiling_scheme.shape.sig != self.shape.sig:
raise ValueError(
f"invalid tiling scheme ({tiling_scheme.shape!r}): sig shape must match"
)
def _preprocess(self, tile_data, tile_slice):
if self._corrections is None:
return
self._corrections.apply(tile_data, tile_slice)
def get_tiles(self, tiling_scheme: TilingScheme, dest_dtype="float32", roi=None,
array_backend: Optional[ArrayBackend] = None):
if array_backend is None:
array_backend = self.meta.array_backends[0]
assert array_backend in (NUMPY, CUDA)
sync_offset = self.meta.sync_offset
shape = Shape((1,) + tuple(self.shape.sig), sig_dims=self.shape.sig.dims)
tiling_scheme = tiling_scheme.adjust_for_partition(self)
self.validate_tiling_scheme(tiling_scheme)
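        # Each tile contains exactly one frame. `indices` below are the frame
        # indices to read from the file, while `offset` maps them to flat
        # navigation positions, accounting for sync_offset and (if given) the roi.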
start = self._start_frame
if start < self.meta.image_count:
stop = min(start + self._num_frames, self.meta.image_count)
if roi is None:
indices = np.arange(max(0, start), stop)
# in case of a negative sync_offset, 'start' can be negative
if start < 0:
offset = abs(sync_offset)
else:
offset = start - sync_offset
else:
# The following is taken (effectively) from _default_get_read_ranges
roi_nonzero = flat_nonzero(roi)
shifted_roi = roi_nonzero + sync_offset
roi_mask = np.logical_and(shifted_roi >= max(0, start),
shifted_roi < stop)
indices = shifted_roi[roi_mask]
# in case of a negative sync_offset, 'start' can be negative
if start < 0:
offset = np.sum(roi_nonzero < abs(sync_offset))
else:
offset = np.sum(roi_nonzero < start - sync_offset)
with fileSER(self._path) as f:
for num, idx in enumerate(indices):
origin = (num + offset,) + tuple([0] * self.shape.sig.dims)
tile_slice = Slice(origin=origin, shape=shape)
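                    # read a single frame (and its per-frame metadata) from the file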
data, metadata = f.getDataset(int(idx))
if data.dtype != np.dtype(dest_dtype):
data = data.astype(dest_dtype)
data = data.reshape(shape)
self._preprocess(data, tile_slice)
yield DataTile(
data,
tile_slice=tile_slice,
scheme_idx=0,
)