Source code for libertem.io.dataset.raw_csr

import typing
import os

import scipy.sparse
import numpy as np
import numba
import tomli
from sparseconverter import SCIPY_CSR, ArrayBackend, for_backend, NUMPY

from libertem.common import Slice, Shape
from libertem.common.math import prod, count_nonzero
from libertem.io.corrections.corrset import CorrectionSet
from libertem.io.dataset.base import (
    DataTile, DataSet
)
from libertem.io.dataset.base.meta import DataSetMeta
from libertem.io.dataset.base.partition import Partition
from libertem.io.dataset.base.tiling_scheme import TilingScheme
from libertem.common.messageconverter import MessageConverter
from libertem.common.numba import numba_dtypes

if typing.TYPE_CHECKING:
    from libertem.io.dataset.base.backend import IOBackend
    from libertem.common.executor import JobExecutor
    import numpy.typing as nt


def load_toml(path: str):
    with open(path, "rb") as f:
        return tomli.load(f)


class RawCSRDatasetParams(MessageConverter):
    SCHEMA = {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "$id": "http://libertem.org/RawCSRDatasetParams.schema.json",
        "title": "RawCSRDatasetParams",
        "type": "object",
        "properties": {
            "type": {"const": "RAW_CSR"},
            "path": {"type": "string"},
            "nav_shape": {
                "type": "array",
                "items": {"type": "number", "minimum": 1},
                "minItems": 2,
                "maxItems": 2
            },
            "sig_shape": {
                "type": "array",
                "items": {"type": "number", "minimum": 1},
                "minItems": 2,
                "maxItems": 2
            },
            "sync_offset": {"type": "number"},
        },
        "required": ["type", "path"]
    }

    def convert_to_python(self, raw_data):
        data = {
            k: raw_data[k]
            for k in ["path"]
        }
        if "nav_shape" in raw_data:
            data["nav_shape"] = tuple(raw_data["nav_shape"])
        if "sig_shape" in raw_data:
            data["sig_shape"] = tuple(raw_data["sig_shape"])
        if "sync_offset" in raw_data:
            data["sync_offset"] = raw_data["sync_offset"]
        return data


class CSRDescriptor(typing.NamedTuple):
    indptr_file: str
    indptr_dtype: np.dtype
    indices_file: str
    indices_dtype: np.dtype
    data_file: str
    data_dtype: np.dtype


class CSRTriple(typing.NamedTuple):
    indptr: np.ndarray
    indices: np.ndarray
    data: np.ndarray


class RawCSRDataSet(DataSet):
    """
    Read sparse data in compressed sparse row (CSR) format from a triple of files
    that contain the index pointers, the coordinates and the values. See
    `Wikipedia article on the CSR format
    <https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format)>`_
    for more information on the format.

    The necessary parameters are specified in a TOML file like this:

    .. code-block::

        [params]

        filetype = "raw_csr"
        nav_shape = [512, 512]
        sig_shape = [516, 516]

        [raw_csr]

        indptr_file = "rowind.dat"
        indptr_dtype = "<i4"

        indices_file = "coords.dat"
        indices_dtype = "<i4"

        data_file = "values.dat"
        data_dtype = "<i4"

    Both the navigation and signal axes are flattened in the file, so that existing
    CSR libraries like scipy.sparse can be used directly by memory-mapping or
    reading the file contents.

    Parameters
    ----------

    path : str
        Path to the TOML file with file names and other parameters for the sparse dataset.

    nav_shape : Tuple[int, int], optional
        A nav_shape to apply to the dataset overriding the shape value read from
        the TOML file, by default None. This can be used to read a subset of the
        data, or reshape the contained data.

    sig_shape : Tuple[int, int], optional
        A sig_shape to apply to the dataset overriding the shape value read from
        the TOML file, by default None.

    sync_offset : int, optional, by default 0
        If positive, number of frames to skip from start.
        If negative, number of blank frames to insert at start.

    io_backend : IOBackend, optional
        The I/O backend to use, see :ref:`io backends`, by default None.

    num_partitions : int, optional
        Override the number of partitions. This is useful if the default number of
        partitions, chosen based on common workloads, creates partitions which are
        too large (or small) for the UDFs being run on this dataset.
    Examples
    --------

    >>> ds = ctx.load("raw_csr", path='./path_to.toml')  # doctest: +SKIP
    """

    def __init__(
        self,
        path: str,
        nav_shape: typing.Optional[tuple[int, ...]] = None,
        sig_shape: typing.Optional[tuple[int, ...]] = None,
        sync_offset: int = 0,
        io_backend: typing.Optional["IOBackend"] = None,
        num_partitions: typing.Optional[int] = None,
    ):
        if io_backend is not None:
            raise NotImplementedError()
        super().__init__(
            io_backend=io_backend,
            num_partitions=num_partitions,
        )
        self._path = path
        if nav_shape is not None:
            nav_shape = tuple(nav_shape)
        self._nav_shape = nav_shape
        if sig_shape is not None:
            sig_shape = tuple(sig_shape)
        self._sig_shape = sig_shape
        self._sync_offset = sync_offset
        self._conf = None
        self._descriptor = None

    def initialize(self, executor: "JobExecutor") -> "DataSet":
        self._conf = conf = executor.run_function(load_toml, self._path)
        assert conf is not None
        if conf['params']['filetype'].lower() != 'raw_csr':
            raise ValueError(f"Filetype is not CSR, found {conf['params']['filetype']}")
        nav_shape = tuple(conf['params']['nav_shape'])
        sig_shape = tuple(conf['params']['sig_shape'])
        if self._nav_shape is None:
            self._nav_shape = nav_shape
        if self._sig_shape is None:
            self._sig_shape = sig_shape
        else:
            if prod(self._sig_shape) != prod(sig_shape):
                raise ValueError(f"Sig size mismatch between {self._sig_shape} and {sig_shape}.")
        shape = Shape(self._nav_shape + self._sig_shape, sig_dims=len(self._sig_shape))

        self._descriptor = descriptor = executor.run_function(get_descriptor, self._path)
        executor.run_function(
            check, descriptor=descriptor, nav_shape=self._nav_shape, sig_shape=self._sig_shape
        )

        image_count = executor.run_function(get_nav_size, descriptor=descriptor)
        self._image_count = image_count
        self._nav_shape_product = int(prod(self._nav_shape))
        self._sync_offset_info = self.get_sync_offset_info()

        self._meta = DataSetMeta(
            shape=shape,
            array_backends=[SCIPY_CSR],
            image_count=image_count,
            raw_dtype=descriptor.data_dtype,
            dtype=None,
            metadata=None,
            sync_offset=self._sync_offset,
        )
        return self

    @property
    def dtype(self) -> "nt.DTypeLike":
        assert self._meta is not None
        return self._meta.raw_dtype

    @property
    def shape(self) -> Shape:
        assert self._meta is not None
        return self._meta.shape

    @property
    def array_backends(self) -> typing.Sequence[ArrayBackend]:
        assert self._meta is not None
        return self._meta.array_backends

    def get_base_shape(self, roi):
        return (1, ) + tuple(self.shape.sig)

    def get_max_io_size(self):
        # High value since referring to dense for the time being
        # Compromise between memory use during densification and
        # performance with native sparse
        return int(1024*1024*20)

    def check_valid(self) -> bool:
        return True  # TODO

    @staticmethod
    def _get_filesize(path):
        return os.stat(path).st_size

    def supports_correction(self):
        return False

    @classmethod
    def detect_params(cls, path: str, executor: "JobExecutor"):
        try:
            _, extension = os.path.splitext(path)
            has_extension = extension.lstrip('.') in cls.get_supported_extensions()
            under_size_lim = executor.run_function(cls._get_filesize, path) < 2**20  # 1 MB
            if not (has_extension or under_size_lim):
                return False
            conf = executor.run_function(load_toml, path)
            if "params" not in conf:
                return False
            if "filetype" not in conf["params"]:
                return False
            if conf["params"]["filetype"].lower() != "raw_csr":
                return False
            descriptor = executor.run_function(get_descriptor, path)
            image_count = executor.run_function(get_nav_size, descriptor=descriptor)
            return {
                "parameters": {
                    'path': path,
                    "nav_shape": conf["params"]["nav_shape"],
                    "sig_shape": conf["params"]["sig_shape"],
"sync_offset": 0, }, "info": { "image_count": image_count, } } except (TypeError, UnicodeDecodeError, tomli.TOMLDecodeError, OSError): return False @classmethod def get_msg_converter(cls) -> type["MessageConverter"]: return RawCSRDatasetParams def get_diagnostics(self): return [ {"name": "data dtype", "value": str(self._descriptor.data_dtype)}, {"name": "indptr dtype", "value": str(self._descriptor.indptr_dtype)}, {"name": "indices dtype", "value": str(self._descriptor.indices_dtype)}, ] # TODO: nonzero elements? @classmethod def get_supported_extensions(cls) -> set[str]: return {"toml"} def get_cache_key(self) -> str: raise NotImplementedError() # TODO @classmethod def get_supported_io_backends(cls) -> list[str]: return [] # FIXME: we may want to read using a backend in the future def adjust_tileshape( self, tileshape: tuple[int, ...], roi: typing.Optional[np.ndarray] ) -> tuple[int, ...]: return (tileshape[0],) + tuple(self._sig_shape) def need_decode( self, read_dtype: "nt.DTypeLike", roi: typing.Optional[np.ndarray], corrections: typing.Optional[CorrectionSet] ) -> bool: return super().need_decode(read_dtype, roi, corrections) def get_partitions(self) -> typing.Generator[Partition, None, None]: assert self._meta is not None for part_slice, start, stop in self.get_slices(): yield RawCSRPartition( descriptor=self._descriptor, meta=self._meta, partition_slice=part_slice, start_frame=start, num_frames=stop - start, io_backend=None, decoder=None, )
class RawCSRPartition(Partition):
    def __init__(
        self,
        descriptor: CSRDescriptor,
        start_frame: int,
        num_frames: int,
        *args,
        **kwargs,
    ):
        self._descriptor = descriptor
        self._start_frame = start_frame
        self._num_frames = num_frames
        self._corrections = CorrectionSet()
        self._worker_context = None
        super().__init__(*args, **kwargs)

    def set_corrections(self, corrections: typing.Optional[CorrectionSet]):
        if corrections is not None and corrections.have_corrections():
            raise NotImplementedError("corrections not implemented for raw CSR data set")

    def validate_tiling_scheme(self, tiling_scheme: TilingScheme):
        if len(tiling_scheme) != 1:
            raise ValueError("Cannot slice CSR data in sig dimensions")

    def get_locations(self):
        # Allow using any worker by default
        return None

    def get_tiles(
        self,
        tiling_scheme: TilingScheme,
        dest_dtype="float32",
        roi=None,
        array_backend: typing.Optional[ArrayBackend] = None,
    ):
        assert array_backend == SCIPY_CSR or array_backend is None
        tiling_scheme = tiling_scheme.adjust_for_partition(self)
        self.validate_tiling_scheme(tiling_scheme)
        triple = get_triple(self._descriptor)
        if self._corrections is not None and self._corrections.have_corrections():
            raise NotImplementedError(
                "corrections are not yet supported for raw CSR"
            )
        if roi is None:
            yield from read_tiles_straight(
                triple, self.slice, self.meta.sync_offset, tiling_scheme, dest_dtype
            )
        else:
            yield from read_tiles_with_roi(
                triple, self.slice, self.meta.sync_offset, tiling_scheme, roi, dest_dtype
            )


def sliced_indptr(triple: CSRTriple, partition_slice: Slice, sync_offset: int):
    assert len(partition_slice.shape.nav) == 1
    skip = min(0, partition_slice.origin[0] + sync_offset)
    indptr_start = max(0, partition_slice.origin[0] + sync_offset)
    indptr_stop = max(
        0, partition_slice.origin[0] + partition_slice.shape.nav[0] + 1 + sync_offset
    )
    return skip, triple.indptr[indptr_start:indptr_stop]


def get_triple(descriptor: CSRDescriptor) -> CSRTriple:
    data: np.ndarray = np.memmap(
        descriptor.data_file, dtype=descriptor.data_dtype, mode='r'
    )
    indices: np.ndarray = np.memmap(
        descriptor.indices_file, dtype=descriptor.indices_dtype, mode='r'
    )
    indptr: np.ndarray = np.memmap(
        descriptor.indptr_file, dtype=descriptor.indptr_dtype, mode='r'
    )
    return CSRTriple(
        indptr=indptr,
        indices=indices,
        data=data,
    )


def check(descriptor: CSRDescriptor, nav_shape, sig_shape, debug=False):
    triple = get_triple(descriptor)
    if triple.indices.shape != triple.data.shape:
        raise RuntimeError('Shape mismatch between data and indices.')
    if debug:
        assert np.min(triple.indices) >= 0
        assert np.max(triple.indices) < prod(sig_shape)
        assert np.min(triple.indptr) >= 0
        assert np.max(triple.indptr) == len(triple.indices)


def get_descriptor(path: str) -> CSRDescriptor:
    """
    Get a CSRDescriptor from the path to a TOML sidecar file
    """
    conf = load_toml(path)
    assert conf is not None
    if conf['params']['filetype'].lower() != 'raw_csr':
        raise ValueError(f"Filetype is not CSR, found {conf['params']['filetype']}")
    base_path = os.path.dirname(path)
    # make sure the key is not case sensitive to follow the convention of
    # the Context.load() function.
    csr_key = conf['params']['filetype']
    csr_conf = conf[csr_key]
    return CSRDescriptor(
        indptr_file=os.path.join(base_path, csr_conf['indptr_file']),
        indptr_dtype=csr_conf['indptr_dtype'],
        indices_file=os.path.join(base_path, csr_conf['indices_file']),
        indices_dtype=csr_conf['indices_dtype'],
        data_file=os.path.join(base_path, csr_conf['data_file']),
        data_dtype=csr_conf['data_dtype'],
    )


def get_nav_size(descriptor: CSRDescriptor) -> int:
    '''
    To run efficiently on a remote worker for dataset initialization
    '''
    indptr = np.memmap(
        descriptor.indptr_file,
        dtype=descriptor.indptr_dtype,
        mode='r',
    )
    return len(indptr) - 1


def read_tiles_straight(
    triple: CSRTriple,
    partition_slice: Slice,
    sync_offset: int,
    tiling_scheme: TilingScheme,
    dest_dtype: np.dtype,
):
    assert len(tiling_scheme) == 1
    skip, indptr = sliced_indptr(
        triple, partition_slice=partition_slice, sync_offset=sync_offset
    )
    sig_shape = tuple(partition_slice.shape.sig)
    sig_size = partition_slice.shape.sig.size
    sig_dims = len(sig_shape)

    # Technically, one could use the slicing implementation of csr_matrix here.
    # However, it is slower, presumably because it takes a copy.
    # Furthermore, doing it by hand provides a template to use an actual I/O backend
    # here instead of memory mapping.
    for indptr_start in range(0, len(indptr) - 1, tiling_scheme.depth):
        tile_start = indptr_start - skip  # skip is a negative value or 0
        indptr_stop = min(indptr_start + tiling_scheme.depth, len(indptr) - 1)
        if indptr_stop - indptr_start <= 0:
            continue
        indptr_slice = indptr[indptr_start:indptr_stop + 1]
        start = indptr[indptr_start]
        stop = indptr[indptr_stop]
        data = triple.data[start:stop]
        if dest_dtype != data.dtype:
            data = data.astype(dest_dtype)
        indices = triple.indices[start:stop]
        indptr_slice = indptr_slice - indptr_slice[0]
        arr = scipy.sparse.csr_matrix(
            (data, indices, indptr_slice),
            shape=(indptr_stop - indptr_start, sig_size)
        )
        tile_slice = Slice(
            origin=(partition_slice.origin[0] + tile_start, ) + (0, ) * sig_dims,
            shape=Shape((arr.shape[0], ) + sig_shape, sig_dims=sig_dims),
        )
        yield DataTile(
            data=arr,
            tile_slice=tile_slice,
            scheme_idx=0,
        )


def populate_tile(
    indptr_tile_start: "np.ndarray",
    indptr_tile_stop: "np.ndarray",
    orig_data: "np.ndarray",
    orig_indices: "np.ndarray",
    data_out: "np.ndarray",
    indices_out: "np.ndarray",
    indptr_out: "np.ndarray",
):
    offset = 0
    indptr_out[0] = 0
    for i, (start, stop) in enumerate(zip(indptr_tile_start, indptr_tile_stop)):
        chunk_size = stop - start
        data_out[offset:offset + chunk_size] = orig_data[start:stop]
        indices_out[offset:offset + chunk_size] = orig_indices[start:stop]
        offset += chunk_size
        indptr_out[i + 1] = offset


populate_tile_numba = numba.njit(populate_tile)


def can_use_numba(triple: CSRTriple) -> bool:
    return all(
        d in numba_dtypes
        for d in (triple.data.dtype, triple.indices.dtype, triple.indptr.dtype)
    )


def read_tiles_with_roi(
    triple: CSRTriple,
    partition_slice: Slice,
    sync_offset: int,
    tiling_scheme: TilingScheme,
    roi: np.ndarray,
    dest_dtype: np.dtype,
):
    assert len(tiling_scheme) == 1
    roi = roi.reshape((-1, ))
    part_start = max(0, partition_slice.origin[0])
    tile_offset = count_nonzero(roi[:part_start])
    part_roi = partition_slice.get(roi, nav_only=True)
    skip, indptr = sliced_indptr(
        triple, partition_slice=partition_slice, sync_offset=sync_offset
    )
    if skip < 0:
        skipped_part_roi = part_roi[-skip:]
    else:
        skipped_part_roi = part_roi
    roi_overhang = max(0, len(skipped_part_roi) - len(indptr) + 1)
    if roi_overhang:
        real_part_roi = skipped_part_roi[:-roi_overhang]
    else:
        real_part_roi = skipped_part_roi
    real_part_roi = for_backend(real_part_roi, NUMPY)
    sig_shape = tuple(partition_slice.shape.sig)
    sig_size = partition_slice.shape.sig.size
    sig_dims = len(sig_shape)
    start_values = indptr[:-1][real_part_roi]
    stop_values = indptr[1:][real_part_roi]

    # Implementing this "by hand" instead of fancy indexing to provide a template to use an
    # actual I/O backend here instead of memory mapping.
    # The native scipy.sparse.csr_matrix implementation of fancy indexing
    # with a boolean mask for nav is very fast.
    if can_use_numba(triple):
        my_populate_tile = populate_tile_numba
    else:
        my_populate_tile = populate_tile

    for indptr_start in range(0, len(part_roi), tiling_scheme.depth):
        indptr_stop = min(indptr_start + tiling_scheme.depth, len(start_values))
        indptr_start = min(indptr_start, indptr_stop)
        # Don't read empty slices
        if indptr_stop - indptr_start <= 0:
            continue
        # Cast to int64 to avoid later upcasting to float64 in case of uint64.
        # We can safely assume that files have less than 2**63 entries so that casting
        # from uint64 to int64 should be safe.
        indptr_tile_start = start_values[indptr_start:indptr_stop].astype(np.int64)
        indptr_tile_stop = stop_values[indptr_start:indptr_stop].astype(np.int64)
        size = sum(indptr_tile_stop - indptr_tile_start)
        data = np.zeros(dtype=dest_dtype, shape=size)
        indices = np.zeros(dtype=triple.indices.dtype, shape=size)
        indptr_slice = np.zeros(
            dtype=indptr.dtype,
            shape=indptr_stop - indptr_start + 1
        )
        my_populate_tile(
            indptr_tile_start=indptr_tile_start,
            indptr_tile_stop=indptr_tile_stop,
            orig_data=triple.data,
            orig_indices=triple.indices,
            data_out=data,
            indices_out=indices,
            indptr_out=indptr_slice,
        )
        arr = scipy.sparse.csr_matrix(
            (data, indices, indptr_slice),
            shape=(indptr_stop - indptr_start, sig_size)
        )
        tile_slice = Slice(
            origin=(tile_offset + indptr_start, ) + (0, ) * sig_dims,
            shape=Shape((indptr_stop - indptr_start, ) + sig_shape, sig_dims=sig_dims),
        )
        yield DataTile(
            data=arr,
            tile_slice=tile_slice,
            scheme_idx=0,
        )
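

# Example (not part of this module): the csr_matrix-based alternative mentioned in the
# comments of read_tiles_straight() and read_tiles_with_roi(). A minimal sketch, assuming
# sync_offset == 0 and ignoring tiling and partitioning: the memory-mapped triple is
# wrapped in a single scipy.sparse.csr_matrix and frames are selected either by row
# slicing or by boolean-mask fancy indexing. The functions above implement the same
# selection "by hand", which avoids the copy taken by scipy and keeps a template for a
# real I/O backend. The function name and parameters here are illustrative only.
def _example_slice_with_scipy(path: str, roi: typing.Optional[np.ndarray] = None):
    conf = load_toml(path)
    descriptor = get_descriptor(path)
    triple = get_triple(descriptor)
    # One row per flattened nav position, one column per flattened sig position
    nav_size = len(triple.indptr) - 1
    sig_size = prod(conf['params']['sig_shape'])
    full = scipy.sparse.csr_matrix(
        (triple.data, triple.indices, triple.indptr),
        shape=(nav_size, sig_size),
    )
    if roi is None:
        # Row slicing corresponds to what read_tiles_straight() does per tile
        return full[:16]
    # Boolean-mask row indexing corresponds to the selection in read_tiles_with_roi()
    return full[roi.reshape((-1, ))]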