import time
import logging
from typing import Optional, TYPE_CHECKING
from collections.abc import Sequence
import psutil
import numpy as np
from sparseconverter import (
BACKENDS, ArrayBackend, for_backend, get_backend, SPARSE_BACKENDS, get_device_class,
conversion_cost
)
from libertem.common.math import prod, count_nonzero, flat_nonzero
from libertem.common.messageconverter import MessageConverter
from libertem.io.dataset.base import (
FileSet, BasePartition, DataSet, DataSetMeta, TilingScheme,
File, MMapBackend, DataSetException
)
from libertem.io.dataset.base.backend_mmap import MMapBackendImpl, MMapFile
from libertem.common import Shape, Slice
from libertem.io.dataset.base import DataTile

if TYPE_CHECKING:
    from libertem.common.executor import JobExecutor

log = logging.getLogger(__name__)


class FakeMMapFile(MMapFile):
    """
    Implements the same interface as MMapFile, but without filesystem backing
    """
def open(self):
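        # no file to mmap: point both the array and the "mmap" directly at
        # the in-memory array held by the MemoryFile descriptor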
self._arr = self.desc._data
self._mmap = self.desc._data
return self
def close(self):
self._arr = None
self._mmap = None


class MemBackend(MMapBackend):
    def get_impl(self):
        return MemBackendImpl()


class MemBackendImpl(MMapBackendImpl):
    FILE_CLS = FakeMMapFile
def _set_readahead_hints(self, roi, fileset):
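        # nothing to read ahead for in-memory data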
pass
def _get_tiles_roi(self, tiling_scheme, open_files, read_ranges, roi, sync_offset):
ds_sig_shape = tiling_scheme.dataset_shape.sig
sig_dims = tiling_scheme.shape.sig.dims
slices, ranges, scheme_indices = read_ranges
fh = open_files[0]
memmap = fh.array.reshape((fh.desc.num_frames,) + tuple(ds_sig_shape))
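        # build a flat roi shifted by sync_offset and clipped to the range of
        # available frames, then use it to gather the selected frames: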
flat_roi = np.full(roi.reshape((-1,)).shape, False)
roi_nonzero = flat_nonzero(roi)
offset_roi = np.clip(roi_nonzero + sync_offset, 0, flat_roi.size)
flat_roi[offset_roi] = True
data_w_roi = memmap[flat_roi]
for idx in range(slices.shape[0]):
origin, shape = slices[idx]
scheme_idx = scheme_indices[idx]
tile_slice = Slice(
origin=origin,
shape=Shape(shape, sig_dims=sig_dims)
)
if sync_offset >= 0:
data_slice = tile_slice.get()
else:
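                # negative sync_offset: the first abs(sync_offset) frames are
                # dropped, so shift the slice origin back by the number of
                # roi-selected frames in that dropped region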
frames_to_skip = count_nonzero(roi.reshape((-1,))[:abs(sync_offset)])
                data_slice = Slice(
                    origin=(origin[0] - frames_to_skip,) + tuple(origin[-sig_dims:]),
                    shape=Shape(shape, sig_dims=sig_dims),
                ).get()
data = data_w_roi[data_slice]
yield DataTile(
data,
tile_slice=tile_slice,
scheme_idx=scheme_idx,
)
def get_tiles(
self, decoder, tiling_scheme, fileset, read_ranges, roi, native_dtype, read_dtype,
sync_offset, corrections, array_backend: ArrayBackend,
):
if roi is None:
# support arbitrary tiling in case of no roi
with self.open_files(fileset) as open_files:
if sync_offset >= 0:
for tile in self._get_tiles_straight(
tiling_scheme, open_files, read_ranges, sync_offset
):
                        if tile.dtype != read_dtype or not tile.c_contiguous:
data = tile.data.astype(read_dtype)
else:
data = tile.data
self.preprocess(data, tile.tile_slice, corrections)
data = for_backend(data, array_backend)
yield DataTile(data, tile.tile_slice, tile.scheme_idx)
else:
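                    # a negative sync_offset cannot be served straight from
                    # the mmap, so go through the copying read path instead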
for tile in self._get_tiles_w_copy(
tiling_scheme=tiling_scheme,
open_files=open_files,
read_ranges=read_ranges,
read_dtype=read_dtype,
native_dtype=native_dtype,
decoder=decoder,
corrections=corrections,
):
data = for_backend(tile.data, array_backend)
yield DataTile(data, tile.tile_slice, tile.scheme_idx)
else:
with self.open_files(fileset) as open_files:
for tile in self._get_tiles_roi(
tiling_scheme=tiling_scheme,
open_files=open_files,
read_ranges=read_ranges,
roi=roi,
sync_offset=sync_offset,
):
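                    # astype always copies, so preprocess gets a writable
                    # array to apply corrections to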
data = tile.data.astype(read_dtype)
self.preprocess(data, tile.tile_slice, corrections)
data = for_backend(data, array_backend)
yield DataTile(data, tile.tile_slice, tile.scheme_idx)


class MemDatasetParams(MessageConverter):
    SCHEMA = {
"$schema": "http://json-schema.org/draft-07/schema#",
"$id": "http://libertem.org/MEMDatasetParams.schema.json",
"title": "MEMDatasetParams",
"type": "object",
"properties": {
"type": {"const": "MEMORY"},
"tileshape": {
"type": "array",
"items": {"type": "number", "minimum": 1},
},
"datashape": {
"type": "array",
"items": {"type": "number", "minimum": 1},
},
"num_partitions": {"type": "number", "minimum": 1},
"sig_dims": {"type": "number", "minimum": 1},
"check_cast": {"type": "boolean"},
"crop_frames": {"type": "boolean"},
"tiledelay": {"type": "number"},
"nav_shape": {
"type": "array",
"items": {"type": "number", "minimum": 1},
"minItems": 2,
"maxItems": 2
},
"sig_shape": {
"type": "array",
"items": {"type": "number", "minimum": 1},
"minItems": 2,
"maxItems": 2
},
"sync_offset": {"type": "number"},
"array_backend": {"type": "string"},
},
"required": ["type", "tileshape", "num_partitions"],
}
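
    # A hypothetical payload that validates against SCHEMA (all keys beyond
    # the required ones are optional):
    # {
    #     "type": "MEMORY",
    #     "tileshape": [7, 32, 32],
    #     "num_partitions": 2,
    #     "nav_shape": [16, 16],
    #     "sig_shape": [32, 32],
    # }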
def convert_to_python(self, raw_data):
data = {
k: raw_data[k]
for k in ["tileshape", "num_partitions", "sig_dims", "check_cast",
"crop_frames", "tiledelay", "datashape", "array_backend"]
if k in raw_data
}
if "nav_shape" in raw_data:
data["nav_shape"] = tuple(raw_data["nav_shape"])
if "sig_shape" in raw_data:
data["sig_shape"] = tuple(raw_data["sig_shape"])
if "sync_offset" in raw_data:
data["sync_offset"] = raw_data["sync_offset"]
return data


class MemoryFile(File):
    def __init__(self, data, check_cast=True, *args, **kwargs):
self._data = data
self._check_cast = check_cast
super().__init__(*args, **kwargs)
@property
def data(self):
return self._data


class MemoryDataSet(DataSet):
'''
This dataset is constructed from a NumPy array in memory for testing
purposes. It is not recommended for production use since it performs poorly with a
distributed executor.

    Examples
--------
>>> data = np.zeros((2, 2, 64, 64), dtype=np.float32)
>>> ds = ctx.load('memory', data=data, sig_dims=2)
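
    An empty dataset of a given shape can also be created by passing
    ``datashape`` instead of ``data``, mainly for HTTP API testing:

    >>> ds = ctx.load('memory', datashape=(2, 2, 64, 64), sig_dims=2)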
'''
def __init__(self, tileshape=None, num_partitions=None, data=None, sig_dims=None,
check_cast=True, tiledelay=None, datashape=None, base_shape=None,
force_need_decode=False, io_backend=None,
nav_shape=None, sig_shape=None, sync_offset=0, array_backends=None):
super().__init__(io_backend=io_backend)
if io_backend is not None:
raise ValueError("MemoryDataSet currently doesn't support alternative I/O backends")
# For HTTP API testing purposes: Allow to create empty dataset with given shape
if data is None:
if datashape is None:
                raise DataSetException(
                    'MemoryDataSet requires either data [np.ndarray] or '
                    'datashape [tuple | Shape]; both arguments are None'
                )
data = np.zeros(datashape, dtype=np.float32)
if num_partitions is None:
num_partitions = psutil.cpu_count(logical=False)
# if tileshape is None:
# sig_shape = data.shape[-sig_dims:]
# target = 2**20
# framesize = prod(sig_shape)
# framecount = max(1, min(prod(data.shape[:-sig_dims]), int(target / framesize)))
# tileshape = (framecount, ) + sig_shape
self.data = data
self._base_shape = base_shape
self.num_partitions = num_partitions
if sig_dims is None:
if sig_shape is not None:
sig_dims = len(sig_shape)
elif nav_shape is not None:
sig_dims = len(data.shape) - len(nav_shape)
else:
sig_dims = 2
else:
if sig_shape is not None and len(sig_shape) != sig_dims:
raise ValueError(
f"Length of sig_shape {sig_shape} not matching sig_dims {sig_dims}."
)
self.sig_dims = sig_dims
if nav_shape is None:
nav_shape = data.shape[:-sig_dims]
else:
nav_shape = tuple(nav_shape)
if sig_shape is None:
sig_shape = data.shape[-sig_dims:]
else:
sig_shape = tuple(sig_shape)
if self.data.size % prod(sig_shape) != 0:
raise ValueError("Data size is not a multiple of sig shape")
self._image_count = self.data.size // prod(sig_shape)
self._nav_shape = nav_shape
self._sig_shape = sig_shape
self._sync_offset = sync_offset
self._array_backends = array_backends
self._check_cast = check_cast
self._tiledelay = tiledelay
self._force_need_decode = force_need_decode
self._nav_shape_product = int(prod(nav_shape))
self._shape = Shape(
nav_shape + sig_shape, sig_dims=self.sig_dims
)
if tileshape is None:
self.tileshape = None
else:
assert len(tileshape) == self.sig_dims + 1
self.tileshape = Shape(tileshape, sig_dims=self.sig_dims)
self._sync_offset_info = self.get_sync_offset_info()
self._meta = DataSetMeta(
shape=self._shape,
array_backends=self.array_backends,
raw_dtype=self.data.dtype,
sync_offset=self._sync_offset,
image_count=self._image_count,
)
def initialize(self, executor):
return self
@classmethod
def get_msg_converter(cls):
return MemDatasetParams
@classmethod
def detect_params(cls, data: np.ndarray, executor: "JobExecutor"):
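        # duck-typed check: anything with a ``shape`` attribute is treated
        # as array-like data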
try:
_ = data.shape
return {
"parameters": {
"data": data,
},
"info": {}
}
except AttributeError:
return False
@property
def dtype(self):
return self.data.dtype
@property
def shape(self):
return self._shape
@property
def array_backends(self) -> Sequence[ArrayBackend]:
"""
All backends can be returned on request
.. versionadded:: 0.11.0
"""
if self._array_backends is None:
native = get_backend(self.data)
is_sparse = native in SPARSE_BACKENDS
native_device_class = get_device_class(native)
cost_metric = {}
# Sort by tuple (cost_override, conversion_cost),
# meaning preference for native backend, same sparsity
# and same device class take precedence over measured conversion cost
for backend in BACKENDS:
cost_metric[backend] = [5, conversion_cost(native, backend)]
if backend == native:
cost_metric[backend][0] -= 2
# sparse==sparse or dense==dense
if (backend in SPARSE_BACKENDS) == is_sparse:
cost_metric[backend][0] -= 2
# Same device class
if get_device_class(backend) == native_device_class:
cost_metric[backend][0] -= 1
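            # lists compare element-wise, so the cost override dominates and
            # conversion_cost only breaks ties; the native backend sorts first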
return tuple(sorted(BACKENDS, key=lambda k: cost_metric[k]))
else:
return self._array_backends
def check_valid(self):
return True
def get_cache_key(self):
        raise TypeError("memory data set is not cacheable yet")
def get_num_partitions(self):
return self.num_partitions
def get_base_shape(self, roi):
if self.tileshape is not None:
return self.tileshape
if self._base_shape is not None:
return self._base_shape
return super().get_base_shape(roi)
def adjust_tileshape(
self, tileshape: tuple[int, ...], roi: Optional[np.ndarray],
) -> tuple[int, ...]:
if self.tileshape is not None:
return tuple(self.tileshape)
return super().adjust_tileshape(tileshape, roi)
def need_decode(self, read_dtype, roi, corrections):
if self._force_need_decode:
return True
return super().need_decode(read_dtype, roi, corrections)
def get_io_backend(self):
return MemBackend()
def get_partitions(self):
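        # expose the whole dataset as a single in-memory "file" holding the
        # flattened stack of frames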
fileset = FileSet([
MemoryFile(
path=None,
start_idx=0,
end_idx=self._image_count,
native_dtype=self.data.dtype,
sig_shape=self.shape.sig,
data=self.data.reshape((-1, *self.shape.sig)),
check_cast=self._check_cast,
)
])
for part_slice, start, stop in self.get_slices():
log.debug(
"creating partition slice %s start %s stop %s",
part_slice, start, stop,
)
yield MemPartition(
meta=self._meta,
partition_slice=part_slice,
fileset=fileset,
start_frame=start,
num_frames=stop - start,
tiledelay=self._tiledelay,
tileshape=self.tileshape,
force_need_decode=self._force_need_decode,
io_backend=self.get_io_backend(),
decoder=self.get_decoder(),
)


class MemPartition(BasePartition):
    def __init__(self, tiledelay, tileshape, force_need_decode=False,
*args, **kwargs):
super().__init__(*args, **kwargs)
self._tiledelay = tiledelay
self._tileshape = tileshape
self._force_tileshape = True
self._force_need_decode = force_need_decode
def get_io_backend(self):
return MemBackend()
def get_macrotile(self, *args, **kwargs):
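        # temporarily stop forcing the configured tileshape so the macrotile
        # can span the whole partition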
self._force_tileshape = False
mt = super().get_macrotile(*args, **kwargs)
self._force_tileshape = True
return mt
def get_tiles(self, *args, **kwargs):
if args and isinstance(args[0], TilingScheme):
tiling_scheme = args[0]
args = args[1:]
intent = tiling_scheme.intent
elif 'tiling_scheme' in kwargs:
tiling_scheme = kwargs.pop('tiling_scheme')
intent = tiling_scheme.intent
else:
            # no tiling_scheme was passed in, so the forced-tileshape branch
            # below must create one
intent = None
# force our own tiling_scheme, if a tileshape is given:
if self._tileshape is not None and self._force_tileshape:
tiling_scheme = TilingScheme.make_for_shape(
tileshape=self._tileshape,
dataset_shape=self.meta.shape,
intent=intent,
)
tiles = super().get_tiles(tiling_scheme, *args, **kwargs)
if self._tiledelay:
            log.debug("delayed get_tiles, tiledelay=%.3f", self._tiledelay)
for tile in tiles:
yield tile
time.sleep(self._tiledelay)
else:
yield from tiles