Source code for libertem.io.dataset.npy

import io
import os
import sys
import typing
from typing import Optional
import logging

import numpy as np
from numpy.lib.utils import safe_eval
from numpy.lib.format import read_magic
from numpy.compat import long, asstr
from libertem.common.messageconverter import MessageConverter

from libertem.io.dataset.base import (
    DataSet, FileSet, BasePartition, DataSetException, DataSetMeta, File, IOBackend,
)
from libertem.common import Shape
from libertem.common.math import prod

log = logging.getLogger(__name__)


class NPYDatasetParams(MessageConverter):
    """
    Message converter for the parameters of an NPY dataset.

    ``SCHEMA`` describes the accepted raw message as a JSON schema;
    :meth:`convert_to_python` turns such a message into a plain dict
    holding ``path`` and, when present, ``nav_shape``, ``sig_shape``
    and ``sync_offset``.
    """

    SCHEMA = {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "$id": "http://libertem.org/NPYDatasetParams.schema.json",
        "title": "NPYDatasetParams",
        "type": "object",
        "properties": {
            "type": {"const": "NPY"},
            "path": {"type": "string"},
            "nav_shape": {
                "type": "array",
                "items": {"type": "number", "minimum": 1},
                "minItems": 2,
                "maxItems": 2
            },
            "sig_shape": {
                "type": "array",
                "items": {"type": "number", "minimum": 1},
                "minItems": 2,
                "maxItems": 2
            },
            "sync_offset": {"type": "number"},
            "io_backend": {
                "enum": IOBackend.get_supported(),
            },
        },
        "required": ["type", "path"],
    }

    def convert_to_python(self, raw_data):
        """
        Build the converted parameter dict from ``raw_data``.

        ``path`` is mandatory and copied verbatim. The optional shapes are
        converted to tuples and the optional ``sync_offset`` to ``int``.
        Other keys of ``raw_data`` (``type``, ``io_backend``) are not
        carried over.
        """
        converted = {"path": raw_data["path"]}
        # Shapes arrive as JSON arrays; hand them on as tuples.
        for shape_key in ("nav_shape", "sig_shape"):
            if shape_key in raw_data:
                converted[shape_key] = tuple(raw_data[shape_key])
        # The schema only requires a "number", so coerce to int explicitly.
        if "sync_offset" in raw_data:
            converted["sync_offset"] = int(raw_data["sync_offset"])
        return converted


class NPYInfo(typing.NamedTuple):
    """
    Header information of an ``.npy`` file, as returned by ``read_npy_info``.
    """
    # dtype parsed from the header's 'descr' entry; read_npy_info stores
    # the np.dtype object itself here, not its string representation
    dtype: np.dtype
    # array shape exactly as stored in the header
    shape: tuple[int, ...]
    # total number of array elements (1 for a 0-d array)
    count: int
    # byte position in the file right after the header, where the data starts
    offset: int


# `_read_bytes`, `_read_array_header`, `_filter_header` are
# stolen from `numpy.lib.format`, as they don't appear to be
# part of the public API.

def _read_bytes(fp, size, error_template="ran out of data"):
    """
    Read from a file-like object until exactly ``size`` bytes are read.

    Some file-like objects (e.g. ``ZipExtFile``) can return less data than
    requested from a single ``read()``, so reading is repeated until the
    requested amount has arrived or EOF is reached.

    Parameters
    ----------
    fp : file-like object
        Binary stream providing ``read(n)``.
    size : int
        Exact number of bytes to read.
    error_template : str
        Description of what is being read; only used in the error message.

    Returns
    -------
    bytes
        Exactly ``size`` bytes.

    Raises
    ------
    ValueError
        If EOF is encountered before ``size`` bytes are read.
    """
    chunks = []
    received = 0
    while True:
        # Non-blocking io objects either raise BlockingIOError or return
        # None when no data is available yet; both mean "try again".
        # Regular files can't be non-blocking, so this never spins for them.
        try:
            chunk = fp.read(size - received)
        except BlockingIOError:
            continue
        if chunk is None:
            continue
        chunks.append(chunk)
        received += len(chunk)
        # An empty read signals EOF.
        if len(chunk) == 0 or received == size:
            break
    data = b''.join(chunks)
    if len(data) != size:
        msg = "EOF: reading %s, expected %d bytes got %d"
        raise ValueError(msg % (error_template, size, len(data)))
    return data


def _read_array_header(fp, version):
    """
    Read and validate the header dictionary of an ``.npy`` file.

    Adapted from ``numpy.lib.format`` (see ``read_array_header_1_0`` there).

    Parameters
    ----------
    fp : file-like object
        Binary stream positioned directly after the magic string and
        version bytes.
    version : tuple of int
        ``(major, minor)`` file format version. Only ``(1, 0)`` and
        ``(2, 0)`` are understood.

    Returns
    -------
    shape : tuple of int
    fortran_order : bool
    dtype : numpy.dtype

    Raises
    ------
    ValueError
        If the version is not supported, the stream ends before the header
        is complete, the header cannot be parsed, or the header is not a
        dictionary with exactly the keys ``descr``, ``fortran_order`` and
        ``shape`` holding valid values.
    """
    import ast
    import struct

    # The header length is an unsigned, little-endian integer: a short in
    # format 1.0 and a 4-byte int in format 2.0.
    if version == (1, 0):
        length_format = '<H'
    elif version == (2, 0):
        length_format = '<I'
    else:
        # `version` is normally a tuple, so it has to be wrapped in a
        # 1-tuple; passed bare to `%` it would be unpacked as the argument
        # list and raise TypeError instead of the intended ValueError.
        raise ValueError("Invalid version %r" % (version,))
    hlength_str = _read_bytes(
        fp, struct.calcsize(length_format), "array header length"
    )
    header_length = struct.unpack(length_format, hlength_str)[0]
    header = _read_bytes(fp, header_length, "array header")

    # The header is a pretty-printed string representation of a literal
    # Python dictionary with trailing newlines padded to a 16-byte
    # boundary. The keys are strings.
    #   "shape" : tuple of int
    #   "fortran_order" : bool
    #   "descr" : dtype.descr
    header = _filter_header(header)
    try:
        # literal_eval only accepts Python literals, so the header is
        # never executed as code.
        d = ast.literal_eval(header)
    except SyntaxError as e:
        msg = "Cannot parse header: %r\nException: %r"
        raise ValueError(msg % (header, e)) from e
    if not isinstance(d, dict):
        msg = "Header is not a dictionary: %r"
        # 1-tuple for the same reason as above: `d` may itself be a tuple.
        raise ValueError(msg % (d,))
    keys = sorted(d.keys())
    if keys != ['descr', 'fortran_order', 'shape']:
        msg = "Header does not contain the correct keys: %r"
        raise ValueError(msg % (keys,))

    # Sanity-check the values.
    shape = d['shape']
    if not (isinstance(shape, tuple)
            and all(isinstance(x, int) for x in shape)):
        msg = "shape is not valid: %r"
        raise ValueError(msg % (shape,))
    if not isinstance(d['fortran_order'], bool):
        msg = "fortran_order is not a valid bool: %r"
        raise ValueError(msg % (d['fortran_order'],))
    try:
        dtype = np.dtype(d['descr'])
    except TypeError as e:
        msg = "descr is not a valid dtype descriptor: %r"
        raise ValueError(msg % (d['descr'],)) from e

    return shape, d['fortran_order'], dtype


def _filter_header(s):
    """Clean up 'L' in npz header ints.

    Cleans up the 'L' in strings representing integers. Needed to allow npz
    headers produced in Python2 to be read in Python3.

    Parameters
    ----------
    s : byte string or str
        Npy file header. Bytes are decoded as latin1.

    Returns
    -------
    header : str
        Cleaned up header.

    """
    import tokenize
    from io import StringIO

    # Same conversion as numpy.compat.asstr: bytes are decoded as latin1,
    # anything else goes through str().
    text = s.decode('latin1') if isinstance(s, bytes) else str(s)

    tokens = []
    last_token_was_number = False
    # generate_tokens expects a readline-style callable that yields one
    # line per call.
    for token in tokenize.generate_tokens(StringIO(text).readline):
        token_type, token_string = token[0], token[1]
        # A Python 2 long such as `10L` tokenizes as NUMBER `10` followed
        # by NAME `L`; dropping that NAME leaves a plain int literal.
        is_long_suffix = (
            last_token_was_number
            and token_type == tokenize.NAME
            and token_string == "L"
        )
        if not is_long_suffix:
            tokens.append(token)
        last_token_was_number = (token_type == tokenize.NUMBER)
    return tokenize.untokenize(tokens)


def read_npy_info(path: str) -> NPYInfo:
    """
    Read the header of the ``.npy`` file at ``path``.

    Returns an :class:`NPYInfo` with the dtype and shape from the header,
    the total element count and the file position right after the header.

    Raises
    ------
    DataSetException
        If the header declares the array as Fortran-ordered.
    """
    with open(path, "rb") as npy_file:
        format_version = read_magic(npy_file)
        shape, fortran_order, dtype = _read_array_header(npy_file, format_version)
        if fortran_order:
            raise DataSetException('Unable to process Fortran-ordered NPY arrays, '
                                   'consider converting with np.ascontiguousarray().')
        # An empty shape means a 0-d array, which holds a single element.
        if len(shape):
            count = int(np.multiply.reduce(shape, dtype=np.int64))
        else:
            count = 1
        # The stream now sits directly behind the header.
        return NPYInfo(
            dtype=dtype,
            shape=shape,
            count=count,
            offset=npy_file.tell(),
        )


class NPYDataSet(DataSet):
    """
    .. versionadded:: 0.10.0

    Read data stored in a NumPy .npy binary file.
    Dataset shape and dtype are inferred from the file header
    unless overridden by the arguments to this class.

    As of this time Fortran-ordered .npy files are not supported.

    Parameters
    ----------

    path : str
        The path to the .npy file

    sig_dims : int, optional, by default 2
        The number of dimensions from the end of the full
        shape to interpret as signal dimensions. If None
        will be inferred from the sig_shape argument when present.

    nav_shape : Tuple[int, int], optional
        A nav_shape to apply to the dataset overriding the shape
        value read from the .npy header, by default None. This can
        be used to read a subset of the .npy file, or reshape the
        contained data. Frames are read in C-order from the beginning
        of the file.

    sig_shape : Tuple[int, int], optional
        A sig_shape to apply to the dataset overriding the shape
        value read from the .npy header, by default None.
        Pixels are read in C-order from the beginning of the file.

    sync_offset : int, optional, by default 0
        If positive, number of frames to skip from start.
        If negative, number of blank frames to insert at start.

    io_backend : IOBackend, optional
        The I/O backend to use, see :ref:`io backends`, by default None.

    Raises
    ------
    DataSetException
        If sig_dims is not an integer and cannot be inferred from sig_shape

    DataSetException
        If the supplied nav_shape + sig_shape describe an array larger
        than the contents of the .npy file

    DataSetException
        If the .npy file is Fortran-ordered

    Examples
    --------

    >>> ds = ctx.load("npy", path='./path_to_file.npy')  # doctest: +SKIP
    """

    def __init__(
        self,
        path: str,
        sig_dims: Optional[int] = 2,
        nav_shape: Optional[tuple[int, int]] = None,
        sig_shape: Optional[tuple[int, int]] = None,
        sync_offset: int = 0,
        io_backend: Optional[IOBackend] = None,
    ):
        super().__init__(io_backend=io_backend)
        self._meta = None
        self._nav_shape = tuple(nav_shape) if nav_shape else nav_shape
        self._sig_shape = tuple(sig_shape) if sig_shape else sig_shape
        self._sig_dims = sig_dims
        if self._sig_shape is not None:
            # Without an explicit sig_dims, take it from the given sig_shape.
            if self._sig_dims is None:
                self._sig_dims = len(self._sig_shape)
            if len(self._sig_shape) != self._sig_dims:
                raise DataSetException(f'Mismatching sig_dims (= {self._sig_dims}) '
                                       f'and sig_shape {self._sig_shape} arguments')
        if self._sig_dims is None or not isinstance(self._sig_dims, int):
            raise DataSetException('Must supply one of sig_dims or sig_shape to NPYDataSet')
        self._path = path
        self._sync_offset = sync_offset
        # Filled in by initialize() once the file header has been read.
        self._npy_info: typing.Optional[NPYInfo] = None

    def _get_filesize(self):
        """Return the size in bytes of the file at ``self._path``."""
        return os.stat(self._path).st_size

    def initialize(self, executor) -> "DataSet":
        """
        Read file size and header via ``executor`` and build the metadata.

        User-supplied ``nav_shape`` / ``sig_shape`` take precedence over
        the shape stored in the file header. Returns ``self``.
        """
        self._filesize = executor.run_function(self._get_filesize)
        npyinfo = executor.run_function(read_npy_info, self._path)
        self._npy_info = npyinfo
        np_shape = Shape(npyinfo.shape, sig_dims=self._sig_dims)
        sig_shape = self._sig_shape if self._sig_shape else np_shape.sig
        nav_shape = self._nav_shape if self._nav_shape else np_shape.nav
        shape = Shape(tuple(nav_shape) + tuple(sig_shape), sig_dims=self._sig_dims)
        # Trying to follow implementation of RawFileDataSet i.e. the _image_count
        # is the whole block of data interpreted as N frames of sig_shape, noting that
        # here sig_shape can be either user-supplied or from the npy metadata
        # if prod(sig_shape) is not a factor then bytes will be dropped at the end
        self._image_count = np_shape.size // prod(sig_shape)
        self._nav_shape_product = shape.nav.size
        self._sync_offset_info = self.get_sync_offset_info()
        self._meta = DataSetMeta(
            shape=shape,
            raw_dtype=np.dtype(npyinfo.dtype),
            sync_offset=self._sync_offset or 0,
            image_count=self._image_count,
        )
        return self

    @property
    def dtype(self):
        """The raw dtype recorded in the dataset metadata."""
        return self._meta.raw_dtype

    @property
    def shape(self):
        """The dataset shape recorded in the dataset metadata."""
        return self._meta.shape

    @classmethod
    def detect_params(cls, path, executor):
        """
        Try to derive loading parameters for ``path`` from its header.

        Returns a dict with ``"parameters"`` and ``"info"`` entries, or
        ``False`` if the header could not be read or interpreted.
        """
        try:
            npy_info = executor.run_function(read_npy_info, path)
            # FIXME: assumption about number of sig dims
            shape = Shape(npy_info.shape, sig_dims=2)
            return {
                "parameters": {
                    "path": path,
                    "nav_shape": tuple(shape.nav),
                    "sig_shape": tuple(shape.sig),
                },
                "info": {
                    "image_count": int(prod(shape.nav)),
                    "native_sig_shape": tuple(shape.sig),
                }
            }
        except Exception as e:
            # Failure just means "not loadable as NPY": report it through
            # the module logger instead of writing to stdout.
            log.debug("NPY parameter detection failed for %r: %s", path, e)
            return False

    def get_diagnostics(self):
        """Return diagnostics entries; currently only the dtype."""
        return [
            {"name": "dtype", "value": str(self.dtype)}
        ]

    @classmethod
    def get_supported_extensions(cls):
        """Return the set of file extensions handled by this dataset."""
        return {"npy"}

    @classmethod
    def get_msg_converter(cls):
        """Return the message converter class for this dataset's parameters."""
        return NPYDatasetParams

    def _get_fileset(self):
        """Build a single-file ``FileSet`` covering all frames of the file."""
        assert self._npy_info is not None
        return FileSet([
            File(
                path=self._path,
                start_idx=0,
                end_idx=self._meta.image_count,
                sig_shape=self.shape.sig,
                native_dtype=self._meta.raw_dtype,
                # The array data begins right after the .npy header.
                file_header=self._npy_info.offset,
            )
        ])

    def check_valid(self):
        """
        Check that the file can be opened with the configured I/O backend.

        Returns ``True`` on success.

        Raises
        ------
        DataSetException
            If opening the file set raises ``OSError`` or ``ValueError``.
        """
        try:
            fileset = self._get_fileset()
            backend = self.get_io_backend().get_impl()
            with backend.open_files(fileset):
                return True
        except (OSError, ValueError) as e:
            raise DataSetException("invalid dataset: %s" % e) from e

    def get_cache_key(self):
        """Return the dict identifying this dataset configuration for caching."""
        return {
            "path": self._path,
            # nav_shape + sig_shape; included because changing nav_shape will change
            # the partition structure and cause errors
            "shape": tuple(self.shape),
            "dtype": str(self.dtype),
            "sync_offset": self._sync_offset,
        }

    def get_partitions(self):
        """Yield one ``BasePartition`` per slice returned by ``get_slices()``."""
        fileset = self._get_fileset()
        for part_slice, start, stop in self.get_slices():
            yield BasePartition(
                meta=self._meta,
                fileset=fileset,
                partition_slice=part_slice,
                start_frame=start,
                num_frames=stop - start,
                io_backend=self.get_io_backend(),
            )

    def __repr__(self):
        return f"<NPYDataSet of {self.dtype} shape={self.shape}>"