# Source code for libertem.io.dataset.npy

import os
import typing
from typing import Optional
import logging

import numpy as np
from numpy.lib.format import open_memmap
from libertem.common.messageconverter import MessageConverter

from libertem.io.dataset.base import (
    DataSet, FileSet, BasePartition, DataSetException, DataSetMeta, File, IOBackend,
)
from libertem.common import Shape
from libertem.common.math import prod

# Module-level logger named after this module's import path (standard logging convention)
log = logging.getLogger(__name__)


class NPYDatasetParams(MessageConverter):
    """
    Message converter for NPY dataset parameters: validates incoming
    JSON against :attr:`SCHEMA` and converts it to Python kwargs.
    """
    # JSON schema for parameters accepted when loading an NPY dataset
    # through the web API; only "type" and "path" are mandatory.
    SCHEMA = {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "$id": "http://libertem.org/NPYDatasetParams.schema.json",
        "title": "NPYDatasetParams",
        "type": "object",
        "properties": {
            "type": {"const": "NPY"},
            "path": {"type": "string"},
            "nav_shape": {
                "type": "array",
                "items": {"type": "number", "minimum": 1},
                "minItems": 2,
                "maxItems": 2
            },
            "sig_shape": {
                "type": "array",
                "items": {"type": "number", "minimum": 1},
                "minItems": 2,
                "maxItems": 2
            },
            "sync_offset": {"type": "number"},
            "io_backend": {
                "enum": IOBackend.get_supported(),
            },
        },
        "required": ["type", "path"],
    }

    def convert_to_python(self, raw_data):
        # "path" is required by the schema; all other keys are optional.
        converted = {"path": raw_data["path"]}
        # Shape parameters arrive as JSON arrays; normalize them to tuples.
        for shape_key in ("nav_shape", "sig_shape"):
            if shape_key in raw_data:
                converted[shape_key] = tuple(raw_data[shape_key])
        # JSON "number" may be a float; the offset must be an int.
        if "sync_offset" in raw_data:
            converted["sync_offset"] = int(raw_data["sync_offset"])
        return converted


# Lightweight record of the .npy header fields needed to build the dataset:
# dtype string, full array shape, total element count, and the byte offset
# at which the raw data starts (i.e. the header size).
NPYInfo = typing.NamedTuple(
    "NPYInfo",
    [
        ("dtype", str),
        ("shape", tuple[int, ...]),
        ("count", int),
        ("offset", int),
    ],
)


def read_npy_info(path: str) -> NPYInfo:
    """
    Read metadata from the header of the .npy file at *path*.

    Parameters
    ----------
    path : str
        Path to the .npy file.

    Returns
    -------
    NPYInfo
        The dtype, array shape, total element count and the byte offset
        where the raw data begins (taken from the opened memmap).

    Raises
    ------
    DataSetException
        If the stored array is not C-contiguous (Fortran-ordered .npy
        files are not supported).
    """
    mmp = open_memmap(path, mode='r')
    try:
        if not mmp.flags['C_CONTIGUOUS']:
            raise DataSetException(
                'Unable to process NPY arrays that are not C_CONTIGUOUS, '
                'consider converting with np.ascontiguousarray().'
            )
        # These attributes are plain values (not views into the mapping),
        # so they remain valid after the memmap is closed below.
        return NPYInfo(
            dtype=mmp.dtype,
            shape=mmp.shape,
            count=mmp.size,
            offset=mmp.offset,
        )
    finally:
        # Make absolutely sure the file is closed before anything else
        # happens — including on the error path: the previous version
        # raised for non-contiguous arrays *before* closing, leaking
        # the memory map.
        mmp._mmap.close()
        del mmp


class NPYDataSet(DataSet):
    """
    .. versionadded:: 0.10.0

    Read data stored in a NumPy .npy binary file. Dataset shape and dtype
    are inferred from the file header unless overridden by the arguments
    to this class.

    As of this time Fortran-ordered .npy files are not supported

    Parameters
    ----------
    path : str
        The path to the .npy file
    sig_dims : int, optional, by default 2
        The number of dimensions from the end of the full shape
        to interpret as signal dimensions. If None will be inferred
        from the sig_shape argument when present.
    nav_shape : Tuple[int, int], optional
        A nav_shape to apply to the dataset overriding the shape
        value read from the .npy header, by default None. This can
        be used to read a subset of the .npy file, or reshape the
        contained data. Frames are read in C-order from the beginning
        of the file.
    sig_shape : Tuple[int, int], optional
        A sig_shape to apply to the dataset overriding the shape value
        read from the .npy header, by default None. Pixels are read in
        C-order from the beginning of the file.
    sync_offset : int, optional, by default 0
        If positive, number of frames to skip from start
        If negative, number of blank frames to insert at start
    io_backend : IOBackend, optional
        The I/O backend to use, see :ref:`io backends`, by default None.

    Raises
    ------
    DataSetException
        If sig_dims is not an integer and cannot be inferred from sig_shape
    DataSetException
        If the supplied nav_shape + sig_shape describe an array larger
        than the contents of the .npy file
    DataSetException
        If the .npy file is Fortran-ordered

    Examples
    --------

    >>> ds = ctx.load("npy", path='./path_to_file.npy')  # doctest: +SKIP
    """
    def __init__(
        self,
        path: str,
        sig_dims: Optional[int] = 2,
        nav_shape: Optional[tuple[int, int]] = None,
        sig_shape: Optional[tuple[int, int]] = None,
        sync_offset: int = 0,
        io_backend: Optional[IOBackend] = None,
    ):
        super().__init__(io_backend=io_backend)
        self._meta = None
        self._nav_shape = tuple(nav_shape) if nav_shape else nav_shape
        self._sig_shape = tuple(sig_shape) if sig_shape else sig_shape
        self._sig_dims = sig_dims
        if self._sig_shape is not None:
            if self._sig_dims is None:
                # infer the number of signal dims from the explicit sig_shape
                self._sig_dims = len(self._sig_shape)
            if len(self._sig_shape) != self._sig_dims:
                raise DataSetException(f'Mismatching sig_dims (= {self._sig_dims}) '
                                       f'and sig_shape {self._sig_shape} arguments')
        if self._sig_dims is None or not isinstance(self._sig_dims, int):
            raise DataSetException('Must supply one of sig_dims or sig_shape to NPYDataSet')
        self._path = path
        self._sync_offset = sync_offset
        self._npy_info: typing.Optional[NPYInfo] = None

    def _get_filesize(self):
        # size of the whole file on disk, including the .npy header
        return os.stat(self._path).st_size

    def initialize(self, executor) -> "DataSet":
        """
        Read the .npy header (via the executor) and populate the dataset
        metadata: shape, dtype, image count and sync offset info.
        """
        self._filesize = executor.run_function(self._get_filesize)
        npyinfo = executor.run_function(read_npy_info, self._path)
        self._npy_info = npyinfo
        np_shape = Shape(npyinfo.shape, sig_dims=self._sig_dims)
        # user-supplied shapes take precedence over the header values
        sig_shape = self._sig_shape if self._sig_shape else np_shape.sig
        nav_shape = self._nav_shape if self._nav_shape else np_shape.nav
        shape = Shape(tuple(nav_shape) + tuple(sig_shape), sig_dims=self._sig_dims)
        # Trying to follow implementation of RawFileDataSet i.e. the _image_count
        # is the whole block of data interpreted as N frames of sig_shape, noting that
        # here sig_shape can be either user-supplied or from the npy metadata
        # if prod(sig_shape) is not a factor then bytes will be dropped at the end
        self._image_count = np_shape.size // prod(sig_shape)
        self._nav_shape_product = shape.nav.size
        self._sync_offset_info = self.get_sync_offset_info()
        self._meta = DataSetMeta(
            shape=shape,
            raw_dtype=np.dtype(npyinfo.dtype),
            sync_offset=self._sync_offset or 0,
            image_count=self._image_count,
        )
        return self

    @property
    def dtype(self):
        return self._meta.raw_dtype

    @property
    def shape(self):
        return self._meta.shape

    @classmethod
    def detect_params(cls, path, executor):
        """
        Try to interpret *path* as an .npy file; return GUI parameters
        on success and False when the file cannot be read as NPY.
        """
        try:
            npy_info = executor.run_function(read_npy_info, path)
            # FIXME: assumption about number of sig dims
            shape = Shape(npy_info.shape, sig_dims=2)
            return {
                "parameters": {
                    "path": path,
                    "nav_shape": tuple(shape.nav),
                    "sig_shape": tuple(shape.sig),
                },
                "info": {
                    "image_count": int(prod(shape.nav)),
                    "native_sig_shape": tuple(shape.sig),
                }
            }
        except Exception as e:
            # Log instead of printing to stdout; a detection failure only
            # means this file is not a readable .npy dataset.
            log.warning("could not detect NPY dataset parameters: %s", e)
            return False

    def get_diagnostics(self):
        return [
            {"name": "dtype", "value": str(self.dtype)}
        ]

    @classmethod
    def get_supported_extensions(cls):
        return {"npy"}

    @classmethod
    def get_msg_converter(cls):
        return NPYDatasetParams

    def _get_fileset(self):
        assert self._npy_info is not None
        return FileSet([
            File(
                path=self._path,
                start_idx=0,
                end_idx=self._meta.image_count,
                sig_shape=self.shape.sig,
                native_dtype=self._meta.raw_dtype,
                # skip the .npy header when reading raw frame data
                file_header=self._npy_info.offset,
            )
        ])

    def check_valid(self):
        try:
            fileset = self._get_fileset()
            backend = self.get_io_backend().get_impl()
            with backend.open_files(fileset):
                return True
        except (OSError, ValueError) as e:
            raise DataSetException("invalid dataset: %s" % e)

    def get_cache_key(self):
        return {
            "path": self._path,
            # nav_shape + sig_shape; included because changing nav_shape will change
            # the partition structure and cause errors
            "shape": tuple(self.shape),
            "dtype": str(self.dtype),
            "sync_offset": self._sync_offset,
        }

    def get_partitions(self):
        fileset = self._get_fileset()
        for part_slice, start, stop in self.get_slices():
            yield BasePartition(
                meta=self._meta,
                fileset=fileset,
                partition_slice=part_slice,
                start_frame=start,
                num_frames=stop - start,
                io_backend=self.get_io_backend(),
            )

    def __repr__(self):
        return f"<NPYDataSet of {self.dtype} shape={self.shape}>"