Source code for libertem.io.dataset.raw_csr

import typing
import os

import scipy.sparse
import numpy as np
import numba
import tomli
from sparseconverter import SCIPY_CSR, ArrayBackend, for_backend, NUMPY

from libertem.common import Slice, Shape
from libertem.common.math import prod, count_nonzero
from libertem.io.corrections.corrset import CorrectionSet
from libertem.io.dataset.base import (
    DataTile, DataSet
)
from libertem.io.dataset.base.meta import DataSetMeta
from libertem.io.dataset.base.partition import Partition
from libertem.io.dataset.base.tiling_scheme import TilingScheme
from libertem.common.messageconverter import MessageConverter
from libertem.common.numba import numba_dtypes

if typing.TYPE_CHECKING:
    from libertem.io.dataset.base.backend import IOBackend
    from libertem.common.executor import JobExecutor
    import numpy.typing as nt


def load_toml(path: str):
    with open(path, "rb") as f:
        return tomli.load(f)


class RawCSRDatasetParams(MessageConverter):
    SCHEMA = {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "$id": "http://libertem.org/RawCSRDatasetParams.schema.json",
        "title": "RawCSRDatasetParams",
        "type": "object",
        "properties": {
            "type": {"const": "RAW_CSR"},
            "path": {"type": "string"},
            "nav_shape": {
                "type": "array",
                "items": {"type": "number", "minimum": 1},
                "minItems": 2,
                "maxItems": 2
            },
            "sig_shape": {
                "type": "array",
                "items": {"type": "number", "minimum": 1},
                "minItems": 2,
                "maxItems": 2
            },
            "sync_offset": {"type": "number"},
        },
        "required": ["type", "path"]
    }

    def convert_to_python(self, raw_data):
        data = {
            k: raw_data[k]
            for k in ["path"]
        }
        if "nav_shape" in raw_data:
            data["nav_shape"] = tuple(raw_data["nav_shape"])
        if "sig_shape" in raw_data:
            data["sig_shape"] = tuple(raw_data["sig_shape"])
        if "sync_offset" in raw_data:
            data["sync_offset"] = raw_data["sync_offset"]
        return data
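
# Illustrative example (not part of the original module): a message that
# validates against SCHEMA above, and the Python parameters it converts to.
# The path is a placeholder.
#
#   raw = {
#       "type": "RAW_CSR",
#       "path": "/data/sidecar.toml",
#       "nav_shape": [512, 512],
#       "sync_offset": 0,
#   }
#   # convert_to_python(raw) == {
#   #     "path": "/data/sidecar.toml",
#   #     "nav_shape": (512, 512),
#   #     "sync_offset": 0,
#   # }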


class CSRDescriptor(typing.NamedTuple):
    indptr_file: str
    indptr_dtype: np.dtype
    indices_file: str
    indices_dtype: np.dtype
    data_file: str
    data_dtype: np.dtype


class CSRTriple(typing.NamedTuple):
    indptr: np.ndarray
    indices: np.ndarray
    data: np.ndarray
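
# Illustrative note (not part of the original module): a CSRTriple maps
# directly onto scipy.sparse.csr_matrix, one row per flattened nav position:
#
#   triple = get_triple(descriptor)  # memory-mapped arrays, see below
#   n_frames = len(triple.indptr) - 1
#   matrix = scipy.sparse.csr_matrix(
#       (triple.data, triple.indices, triple.indptr),
#       shape=(n_frames, sig_size),  # sig_size: flattened signal size
#   )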


class RawCSRDataSet(DataSet):
    """
    Read sparse data in compressed sparse row (CSR) format from a triple of
    files that contain the index pointers, the coordinates and the values. See
    the `Wikipedia article on the CSR format
    <https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format)>`_
    for more information on the format.

    The necessary parameters are specified in a TOML file like this:

    .. code-block::

        [params]

        filetype = "raw_csr"
        nav_shape = [512, 512]
        sig_shape = [516, 516]

        [raw_csr]

        indptr_file = "rowind.dat"
        indptr_dtype = "<i4"
        indices_file = "coords.dat"
        indices_dtype = "<i4"
        data_file = "values.dat"
        data_dtype = "<i4"

    Both the navigation and signal axes are flattened in the file, so that
    existing CSR libraries like scipy.sparse can be used directly by
    memory-mapping or reading the file contents.

    Parameters
    ----------
    path : str
        Path to the TOML file with file names and other parameters for the
        sparse dataset.
    nav_shape : Tuple[int, int], optional
        A nav_shape to apply to the dataset, overriding the shape value read
        from the TOML file, by default None. This can be used to read a subset
        of the data, or to reshape the contained data.
    sig_shape : Tuple[int, int], optional
        A sig_shape to apply to the dataset, overriding the shape value read
        from the TOML file, by default None.
    sync_offset : int, optional, by default 0
        If positive, number of frames to skip from the start.
        If negative, number of blank frames to insert at the start.
    io_backend : IOBackend, optional
        The I/O backend to use, see :ref:`io backends`, by default None.

    Examples
    --------
    >>> ds = ctx.load("raw_csr", path='./path_to.toml')  # doctest: +SKIP
    """

    def __init__(
        self,
        path: str,
        nav_shape: typing.Optional[tuple[int, ...]] = None,
        sig_shape: typing.Optional[tuple[int, ...]] = None,
        sync_offset: int = 0,
        io_backend: typing.Optional["IOBackend"] = None,
    ):
        if io_backend is not None:
            raise NotImplementedError()
        super().__init__(io_backend=io_backend)
        self._path = path
        if nav_shape is not None:
            nav_shape = tuple(nav_shape)
        self._nav_shape = nav_shape
        if sig_shape is not None:
            sig_shape = tuple(sig_shape)
        self._sig_shape = sig_shape
        self._sync_offset = sync_offset
        self._conf = None
        self._descriptor = None

    def initialize(self, executor: "JobExecutor") -> "DataSet":
        self._conf = conf = executor.run_function(load_toml, self._path)
        assert conf is not None
        if conf['params']['filetype'].lower() != 'raw_csr':
            raise ValueError(f"Filetype is not CSR, found {conf['params']['filetype']}")
        nav_shape = tuple(conf['params']['nav_shape'])
        sig_shape = tuple(conf['params']['sig_shape'])
        if self._nav_shape is None:
            self._nav_shape = nav_shape
        if self._sig_shape is None:
            self._sig_shape = sig_shape
        else:
            if prod(self._sig_shape) != prod(sig_shape):
                raise ValueError(
                    f"Sig size mismatch between {self._sig_shape} and {sig_shape}."
                )
        shape = Shape(self._nav_shape + self._sig_shape, sig_dims=len(self._sig_shape))
        self._descriptor = descriptor = executor.run_function(get_descriptor, self._path)
        executor.run_function(
            check,
            descriptor=descriptor,
            nav_shape=self._nav_shape,
            sig_shape=self._sig_shape,
        )
        image_count = executor.run_function(get_nav_size, descriptor=descriptor)
        self._image_count = image_count
        self._nav_shape_product = int(prod(self._nav_shape))
        self._sync_offset_info = self.get_sync_offset_info()
        self._meta = DataSetMeta(
            shape=shape,
            array_backends=[SCIPY_CSR],
            image_count=image_count,
            raw_dtype=descriptor.data_dtype,
            dtype=None,
            metadata=None,
            sync_offset=self._sync_offset,
        )
        return self

    @property
    def dtype(self) -> "nt.DTypeLike":
        assert self._meta is not None
        return self._meta.raw_dtype

    @property
    def shape(self) -> Shape:
        assert self._meta is not None
        return self._meta.shape

    @property
    def array_backends(self) -> typing.Sequence[ArrayBackend]:
        assert self._meta is not None
        return self._meta.array_backends

    def get_base_shape(self, roi):
        return (1, ) + tuple(self.shape.sig)

    def get_max_io_size(self):
        # High value since referring to dense for the time being.
        # Compromise between memory use during densification and
        # performance with native sparse.
        return int(1024*1024*20)

    def check_valid(self) -> bool:
        return True  # TODO

    @staticmethod
    def _get_filesize(path):
        return os.stat(path).st_size

    def supports_correction(self):
        return False

    @classmethod
    def detect_params(cls, path: str, executor: "JobExecutor"):
        try:
            _, extension = os.path.splitext(path)
            has_extension = extension.lstrip('.') in cls.get_supported_extensions()
            under_size_lim = executor.run_function(cls._get_filesize, path) < 2**20  # 1 MiB
            if not (has_extension or under_size_lim):
                return False
            conf = executor.run_function(load_toml, path)
            if "params" not in conf:
                return False
            if "filetype" not in conf["params"]:
                return False
            if conf["params"]["filetype"].lower() != "raw_csr":
                return False
            descriptor = executor.run_function(get_descriptor, path)
            image_count = executor.run_function(get_nav_size, descriptor=descriptor)
            return {
                "parameters": {
                    'path': path,
                    "nav_shape": conf["params"]["nav_shape"],
                    "sig_shape": conf["params"]["sig_shape"],
                    "sync_offset": 0,
                },
                "info": {
                    "image_count": image_count,
                }
            }
        except (TypeError, UnicodeDecodeError, tomli.TOMLDecodeError, OSError):
            return False

    @classmethod
    def get_msg_converter(cls) -> type["MessageConverter"]:
        return RawCSRDatasetParams

    def get_diagnostics(self):
        return [
            {"name": "data dtype", "value": str(self._descriptor.data_dtype)},
            {"name": "indptr dtype", "value": str(self._descriptor.indptr_dtype)},
            {"name": "indices dtype", "value": str(self._descriptor.indices_dtype)},
        ]  # TODO: nonzero elements?

    @classmethod
    def get_supported_extensions(cls) -> set[str]:
        return {"toml"}

    def get_cache_key(self) -> str:
        raise NotImplementedError()  # TODO

    @classmethod
    def get_supported_io_backends(cls) -> list[str]:
        return []  # FIXME: we may want to read using a backend in the future

    def adjust_tileshape(
        self, tileshape: tuple[int, ...], roi: typing.Optional[np.ndarray]
    ) -> tuple[int, ...]:
        return (tileshape[0],) + tuple(self._sig_shape)

    def need_decode(
        self,
        read_dtype: "nt.DTypeLike",
        roi: typing.Optional[np.ndarray],
        corrections: typing.Optional[CorrectionSet],
    ) -> bool:
        return super().need_decode(read_dtype, roi, corrections)

    def get_partitions(self) -> typing.Generator[Partition, None, None]:
        assert self._meta is not None
        for part_slice, start, stop in self.get_slices():
            yield RawCSRPartition(
                descriptor=self._descriptor,
                meta=self._meta,
                partition_slice=part_slice,
                start_frame=start,
                num_frames=stop - start,
                io_backend=None,
                decoder=None,
            )
class RawCSRPartition(Partition):
    def __init__(
        self,
        descriptor: CSRDescriptor,
        start_frame: int,
        num_frames: int,
        *args,
        **kwargs,
    ):
        self._descriptor = descriptor
        self._start_frame = start_frame
        self._num_frames = num_frames
        self._corrections = CorrectionSet()
        self._worker_context = None
        super().__init__(*args, **kwargs)

    def set_corrections(self, corrections: typing.Optional[CorrectionSet]):
        if corrections is not None and corrections.have_corrections():
            raise NotImplementedError("corrections not implemented for raw CSR data set")

    def validate_tiling_scheme(self, tiling_scheme: TilingScheme):
        if len(tiling_scheme) != 1:
            raise ValueError("Cannot slice CSR data in sig dimensions")

    def get_locations(self):
        # Allow using any worker by default
        return None

    def get_tiles(
        self,
        tiling_scheme: TilingScheme,
        dest_dtype="float32",
        roi=None,
        array_backend: typing.Optional[ArrayBackend] = None,
    ):
        assert array_backend == SCIPY_CSR or array_backend is None
        tiling_scheme = tiling_scheme.adjust_for_partition(self)
        self.validate_tiling_scheme(tiling_scheme)
        triple = get_triple(self._descriptor)
        if self._corrections is not None and self._corrections.have_corrections():
            raise NotImplementedError(
                "corrections are not yet supported for raw CSR"
            )
        if roi is None:
            yield from read_tiles_straight(
                triple, self.slice, self.meta.sync_offset, tiling_scheme, dest_dtype
            )
        else:
            yield from read_tiles_with_roi(
                triple, self.slice, self.meta.sync_offset, tiling_scheme, roi, dest_dtype
            )


def sliced_indptr(triple: CSRTriple, partition_slice: Slice, sync_offset: int):
    assert len(partition_slice.shape.nav) == 1
    # skip is negative if sync_offset shifts the partition start before the
    # beginning of the file, and 0 otherwise; the indptr window is clamped
    # to the file bounds accordingly.
    skip = min(0, partition_slice.origin[0] + sync_offset)
    indptr_start = max(0, partition_slice.origin[0] + sync_offset)
    indptr_stop = max(
        0,
        partition_slice.origin[0] + partition_slice.shape.nav[0] + 1 + sync_offset
    )
    return skip, triple.indptr[indptr_start:indptr_stop]


def get_triple(descriptor: CSRDescriptor) -> CSRTriple:
    # Memory-map the three raw files read-only
    data: np.ndarray = np.memmap(
        descriptor.data_file, dtype=descriptor.data_dtype, mode='r'
    )
    indices: np.ndarray = np.memmap(
        descriptor.indices_file, dtype=descriptor.indices_dtype, mode='r'
    )
    indptr: np.ndarray = np.memmap(
        descriptor.indptr_file, dtype=descriptor.indptr_dtype, mode='r'
    )
    return CSRTriple(
        indptr=indptr,
        indices=indices,
        data=data,
    )


def check(descriptor: CSRDescriptor, nav_shape, sig_shape, debug=False):
    triple = get_triple(descriptor)
    if triple.indices.shape != triple.data.shape:
        raise RuntimeError('Shape mismatch between data and indices.')
    if debug:
        assert np.min(triple.indices) >= 0
        assert np.max(triple.indices) < prod(sig_shape)
        assert np.min(triple.indptr) >= 0
        assert np.max(triple.indptr) == len(triple.indices)
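
# Illustrative sketch (not part of the original module) of the CSR invariant
# that sliced_indptr and the tile readers rely on: the nonzero entries of
# frames start..stop occupy data[indptr[start]:indptr[stop]]. All values here
# are made up.
def _demo_csr_frame_slicing():
    # Three flattened frames with a signal size of 4
    frames = np.array([
        [0, 3, 0, 0],
        [0, 0, 0, 0],
        [5, 0, 0, 7],
    ])
    csr = scipy.sparse.csr_matrix(frames)
    start, stop = 1, 3
    # Rebuild frames 1..2 from raw indptr/indices/data slices, shifting
    # indptr to start at zero, just like read_tiles_straight below does:
    indptr = csr.indptr[start:stop + 1] - csr.indptr[start]
    values = csr.data[csr.indptr[start]:csr.indptr[stop]]
    coords = csr.indices[csr.indptr[start]:csr.indptr[stop]]
    tile = scipy.sparse.csr_matrix((values, coords, indptr), shape=(stop - start, 4))
    assert np.array_equal(tile.toarray(), frames[start:stop])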
def get_descriptor(path: str) -> CSRDescriptor:
    """
    Get a CSRDescriptor from the path to a TOML sidecar file
    """
    conf = load_toml(path)
    assert conf is not None
    if conf['params']['filetype'].lower() != 'raw_csr':
        raise ValueError(f"Filetype is not CSR, found {conf['params']['filetype']}")
    base_path = os.path.dirname(path)
    # Make sure the key is not case sensitive to follow the convention of
    # the Context.load() function.
    csr_key = conf['params']['filetype']
    csr_conf = conf[csr_key]
    return CSRDescriptor(
        indptr_file=os.path.join(base_path, csr_conf['indptr_file']),
        indptr_dtype=csr_conf['indptr_dtype'],
        indices_file=os.path.join(base_path, csr_conf['indices_file']),
        indices_dtype=csr_conf['indices_dtype'],
        data_file=os.path.join(base_path, csr_conf['data_file']),
        data_dtype=csr_conf['data_dtype'],
    )


def get_nav_size(descriptor: CSRDescriptor) -> int:
    '''
    To run efficiently on a remote worker for dataset initialization
    '''
    indptr = np.memmap(
        descriptor.indptr_file,
        dtype=descriptor.indptr_dtype,
        mode='r',
    )
    return len(indptr) - 1


def read_tiles_straight(
    triple: CSRTriple,
    partition_slice: Slice,
    sync_offset: int,
    tiling_scheme: TilingScheme,
    dest_dtype: np.dtype,
):
    assert len(tiling_scheme) == 1
    skip, indptr = sliced_indptr(
        triple, partition_slice=partition_slice, sync_offset=sync_offset
    )
    sig_shape = tuple(partition_slice.shape.sig)
    sig_size = partition_slice.shape.sig.size
    sig_dims = len(sig_shape)

    # Technically, one could use the slicing implementation of csr_matrix here.
    # However, it is slower, presumably because it takes a copy.
    # Furthermore it provides a template to use an actual I/O backend here
    # instead of memory mapping.
    for indptr_start in range(0, len(indptr) - 1, tiling_scheme.depth):
        tile_start = indptr_start - skip  # skip is a negative value or 0
        indptr_stop = min(indptr_start + tiling_scheme.depth, len(indptr) - 1)
        if indptr_stop - indptr_start <= 0:
            continue
        indptr_slice = indptr[indptr_start:indptr_stop + 1]
        start = indptr[indptr_start]
        stop = indptr[indptr_stop]
        data = triple.data[start:stop]
        if dest_dtype != data.dtype:
            data = data.astype(dest_dtype)
        indices = triple.indices[start:stop]
        indptr_slice = indptr_slice - indptr_slice[0]
        arr = scipy.sparse.csr_matrix(
            (data, indices, indptr_slice),
            shape=(indptr_stop - indptr_start, sig_size)
        )
        tile_slice = Slice(
            origin=(partition_slice.origin[0] + tile_start, ) + (0, ) * sig_dims,
            shape=Shape((arr.shape[0], ) + sig_shape, sig_dims=sig_dims),
        )
        yield DataTile(
            data=arr,
            tile_slice=tile_slice,
            scheme_idx=0,
        )


def populate_tile(
    indptr_tile_start: "np.ndarray",
    indptr_tile_stop: "np.ndarray",
    orig_data: "np.ndarray",
    orig_indices: "np.ndarray",
    data_out: "np.ndarray",
    indices_out: "np.ndarray",
    indptr_out: "np.ndarray",
):
    offset = 0
    indptr_out[0] = 0
    for i, (start, stop) in enumerate(zip(indptr_tile_start, indptr_tile_stop)):
        chunk_size = stop - start
        data_out[offset:offset + chunk_size] = orig_data[start:stop]
        indices_out[offset:offset + chunk_size] = orig_indices[start:stop]
        offset += chunk_size
        indptr_out[i + 1] = offset


populate_tile_numba = numba.njit(populate_tile)


def can_use_numba(triple: CSRTriple) -> bool:
    return all(
        d in numba_dtypes
        for d in (triple.data.dtype, triple.indices.dtype, triple.indptr.dtype)
    )
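
# Illustrative sketch (not part of the original module): populate_tile gathers
# possibly non-contiguous per-frame ranges of data/indices into one compact
# CSR tile, rebuilding indptr as it goes. All values here are made up.
def _demo_populate_tile():
    # A 3-frame CSR with per-frame nonzero counts 1, 0, 2; select frames 0 and 2
    indptr = np.array([0, 1, 1, 3])
    data = np.array([3, 5, 7])
    indices = np.array([1, 0, 3])
    starts = indptr[:-1][[0, 2]]  # [0, 1]
    stops = indptr[1:][[0, 2]]    # [1, 3]
    size = int(np.sum(stops - starts))
    data_out = np.zeros(size, dtype=data.dtype)
    indices_out = np.zeros(size, dtype=indices.dtype)
    indptr_out = np.zeros(len(starts) + 1, dtype=indptr.dtype)
    populate_tile(starts, stops, data, indices, data_out, indices_out, indptr_out)
    assert list(data_out) == [3, 5, 7]
    assert list(indptr_out) == [0, 1, 3]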
def read_tiles_with_roi(
    triple: CSRTriple,
    partition_slice: Slice,
    sync_offset: int,
    tiling_scheme: TilingScheme,
    roi: np.ndarray,
    dest_dtype: np.dtype,
):
    assert len(tiling_scheme) == 1
    roi = roi.reshape((-1, ))
    part_start = max(0, partition_slice.origin[0])
    tile_offset = count_nonzero(roi[:part_start])
    part_roi = partition_slice.get(roi, nav_only=True)
    skip, indptr = sliced_indptr(
        triple, partition_slice=partition_slice, sync_offset=sync_offset
    )
    if skip < 0:
        skipped_part_roi = part_roi[-skip:]
    else:
        skipped_part_roi = part_roi
    roi_overhang = max(0, len(skipped_part_roi) - len(indptr) + 1)
    if roi_overhang:
        real_part_roi = skipped_part_roi[:-roi_overhang]
    else:
        real_part_roi = skipped_part_roi
    real_part_roi = for_backend(real_part_roi, NUMPY)
    sig_shape = tuple(partition_slice.shape.sig)
    sig_size = partition_slice.shape.sig.size
    sig_dims = len(sig_shape)
    start_values = indptr[:-1][real_part_roi]
    stop_values = indptr[1:][real_part_roi]

    # Implementing this "by hand" instead of fancy indexing to provide a
    # template to use an actual I/O backend here instead of memory mapping.
    # The native scipy.sparse.csr_matrix implementation of fancy indexing
    # with a boolean mask for nav is very fast.
    if can_use_numba(triple):
        my_populate_tile = populate_tile_numba
    else:
        my_populate_tile = populate_tile

    for indptr_start in range(0, len(part_roi), tiling_scheme.depth):
        indptr_stop = min(indptr_start + tiling_scheme.depth, len(start_values))
        indptr_start = min(indptr_start, indptr_stop)
        # Don't read empty slices
        if indptr_stop - indptr_start <= 0:
            continue
        # Cast to int64 to avoid later upcasting to float64 in case of uint64.
        # We can safely assume that files have less than 2**63 entries, so
        # casting from uint64 to int64 should be safe.
        indptr_tile_start = start_values[indptr_start:indptr_stop].astype(np.int64)
        indptr_tile_stop = stop_values[indptr_start:indptr_stop].astype(np.int64)
        size = sum(indptr_tile_stop - indptr_tile_start)
        data = np.zeros(dtype=dest_dtype, shape=size)
        indices = np.zeros(dtype=triple.indices.dtype, shape=size)
        indptr_slice = np.zeros(
            dtype=indptr.dtype, shape=indptr_stop - indptr_start + 1
        )
        my_populate_tile(
            indptr_tile_start=indptr_tile_start,
            indptr_tile_stop=indptr_tile_stop,
            orig_data=triple.data,
            orig_indices=triple.indices,
            data_out=data,
            indices_out=indices,
            indptr_out=indptr_slice,
        )
        arr = scipy.sparse.csr_matrix(
            (data, indices, indptr_slice),
            shape=(indptr_stop - indptr_start, sig_size)
        )
        tile_slice = Slice(
            origin=(tile_offset + indptr_start, ) + (0, ) * sig_dims,
            shape=Shape((indptr_stop - indptr_start, ) + sig_shape, sig_dims=sig_dims),
        )
        yield DataTile(
            data=arr,
            tile_slice=tile_slice,
            scheme_idx=0,
        )
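
# Illustrative sketch (not part of the original module) of the equivalence the
# comment in read_tiles_with_roi refers to: gathering rows via the raw triple
# yields the same values as scipy.sparse's own boolean fancy indexing of the
# csr_matrix. All values here are made up.
def _demo_roi_equivalence():
    rng = np.random.default_rng(seed=42)
    dense = rng.integers(0, 2, size=(6, 4)) * rng.integers(1, 9, size=(6, 4))
    csr = scipy.sparse.csr_matrix(dense)
    roi = np.array([True, False, True, True, False, True])
    # Row selection via the raw triple, as read_tiles_with_roi does it...
    starts = csr.indptr[:-1][roi]
    stops = csr.indptr[1:][roi]
    gathered = np.concatenate([csr.data[a:b] for a, b in zip(starts, stops)])
    # ...matches the data of scipy's boolean row indexing:
    assert np.array_equal(gathered, csr[roi].data)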