# Source code for libertem.io.dataset.hdf5

import os
import contextlib
import typing
from typing import Optional
import warnings
import logging
import time

import numpy as np
import h5py
from sparseconverter import CUDA, NUMPY, ArrayBackend

from libertem.common.math import prod, flat_nonzero
from libertem.common import Slice, Shape
from libertem.common.buffers import zeros_aligned
from libertem.io.corrections import CorrectionSet
from libertem.common.messageconverter import MessageConverter
from .base import (
    DataSet, Partition, DataTile, DataSetException, DataSetMeta,
    TilingScheme,
)


# alias for mocking: `_get_datasets` reads the clock through this
# module-level name, so a replacement clock can be substituted here
current_time = time.time

# module-level logger, named after this module
logger = logging.getLogger(__name__)


class HDF5DatasetParams(MessageConverter):
    """
    Message converter for HDF5 data set parameters.

    `SCHEMA` is the JSON schema of the incoming message;
    :meth:`convert_to_python` turns a raw message into a dict of
    parameters, with shapes converted to tuples.
    """
    SCHEMA = {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "$id": "http://libertem.org/HDF5DatasetParams.schema.json",
        "title": "HDF5DatasetParams",
        "type": "object",
        "properties": {
            "type": {"const": "HDF5"},
            "path": {"type": "string"},
            "ds_path": {"type": "string"},
            "nav_shape": {
                "type": "array",
                "items": {"type": "number", "minimum": 1},
                "minItems": 2,
                "maxItems": 2
            },
            "sig_shape": {
                "type": "array",
                "items": {"type": "number", "minimum": 1},
                "minItems": 2,
                "maxItems": 2
            },
            "sync_offset": {"type": "number"},
        },
        "required": ["type", "path", "ds_path"]
    }

    def convert_to_python(self, raw_data):
        """
        Build the parameter dict from `raw_data`.

        `path` and `ds_path` are always copied (a missing key raises
        `KeyError`); `nav_shape`, `sig_shape` and `sync_offset` are only
        included if present, the shapes being converted to tuples.
        """
        converted = {
            "path": raw_data["path"],
            "ds_path": raw_data["ds_path"],
        }
        # shapes arrive as JSON arrays, the Python side wants tuples:
        for shape_key in ("nav_shape", "sig_shape"):
            if shape_key in raw_data:
                converted[shape_key] = tuple(raw_data[shape_key])
        if "sync_offset" in raw_data:
            converted["sync_offset"] = raw_data["sync_offset"]
        return converted


def _ensure_2d_nav(nav_shape: tuple[int, ...]) -> tuple[int, int]:
    # For any iterable shape, reduce or pad it to a 2-tuple
    # with the same prod(shape). Reduction from left to right
    # (final dimension preserved). Special case for empty
    # nav_shape which is converted to (1, 1)
    nav_shape = tuple(nav_shape)
    if len(nav_shape) == 1:
        nav_shape = (1,) + nav_shape
    elif len(nav_shape) >= 2:
        nav_shape = (prod(nav_shape[:-1]),) + nav_shape[-1:]
    elif len(nav_shape) == 0:
        return (1, 1)
    else:
        raise ValueError(f'Incompatible nav_shape {nav_shape}')
    return nav_shape


class HDF5ArrayDescriptor(typing.NamedTuple):
    """
    Description of a single array-like object found in a HDF5 file,
    as collected by `_get_datasets`.
    """
    # name of the object, as passed to the `visititems` callback
    name: str
    # shape of the array
    shape: tuple[int, ...]
    # data type of the array
    dtype: np.dtype
    # value of the `compression` attribute of the object; `None` is
    # treated as "not compressed" elsewhere in this module
    compression: Optional[str]
    # value of the `chunks` attribute of the object; other code in this
    # module handles `None` here as "not chunked"
    chunks: Optional[tuple[int, ...]]


def _get_datasets(path):
    """
    Collect descriptors for the array-like objects in the HDF5 file at `path`.

    Only objects that have both a `size` and a `shape` attribute and at
    least three dimensions are included. The traversal is aborted with a
    `TimeoutError` once it runs longer than the timeout in seconds, which
    is 3 by default and can be overridden using the environment variable
    `LIBERTEM_IO_HDF5_TIMEOUT_DEBUG`.

    Returns
    -------
    list of HDF5ArrayDescriptor
    """
    try:
        max_seconds = int(os.environ.get('LIBERTEM_IO_HDF5_TIMEOUT_DEBUG', 3))
    except ValueError:
        # not parseable as integer, fall back to the default
        max_seconds = 3

    found: list[HDF5ArrayDescriptor] = []
    started_at = current_time()

    def _collect(name, obj):
        # must always return None, otherwise `visititems` stops early
        if current_time() - started_at > max_seconds:
            raise TimeoutError
        looks_like_array = hasattr(obj, 'size') and hasattr(obj, 'shape')
        if not looks_like_array or obj.ndim < 3:
            # not an array, or can't process this dataset: skip
            return
        descriptor = HDF5ArrayDescriptor(
            name, obj.shape, obj.dtype, obj.compression, obj.chunks,
        )
        found.append(descriptor)

    with h5py.File(path, 'r') as h5_file:
        h5_file.visititems(_collect)
    return found


def _have_contig_chunks(chunks, ds_shape):
    """
    Check whether `chunks` are contiguous in the navigation axes of `ds_shape`.

    This is the case if there is a navigation axis `i` such that all
    chunk sizes to the left of it multiply to 1, and all chunk sizes to
    the right of it span the full extent of the data set. Only the
    navigation part of `chunks` and `ds_shape` is taken into account.

    Examples
    --------

    >>> ds_shape = Shape((64, 64, 128, 128), sig_dims=2)
    >>> _have_contig_chunks((1, 4, 32, 32), ds_shape)
    True
    >>> _have_contig_chunks((2, 4, 32, 32), ds_shape)
    False
    >>> _have_contig_chunks((2, 64, 32, 32), ds_shape)
    True
    >>> _have_contig_chunks((64, 1, 32, 32), ds_shape)
    False
    >>> ds_shape_5d = Shape((16, 64, 64, 128, 128), sig_dims=2)
    >>> _have_contig_chunks((1, 1, 2, 32, 32), ds_shape_5d)
    True
    >>> _have_contig_chunks((1, 2, 1, 32, 32), ds_shape_5d)
    False
    >>> _have_contig_chunks((2, 1, 1, 32, 32), ds_shape_5d)
    False
    """
    full_nav = tuple(ds_shape.nav)
    num_nav_axes = len(full_nav)
    nav_chunks = chunks[:num_nav_axes]

    return any(
        prod(nav_chunks[:axis]) == 1 and nav_chunks[axis + 1:] == full_nav[axis + 1:]
        for axis in range(num_nav_axes)
    )


def _partition_shape_for_chunking(chunks, ds_shape):
    """
    Get the minimum partition shape for that allows us to prevent read amplification
    with chunked HDF5 files.

    Examples
    --------

    >>> ds_shape = Shape((64, 64, 128, 128), sig_dims=2)
    >>> _partition_shape_for_chunking((1, 4, 32, 32), ds_shape)
    (1, 4, 128, 128)
    >>> _partition_shape_for_chunking((2, 4, 32, 32), ds_shape)
    (2, 64, 128, 128)
    >>> _partition_shape_for_chunking((2, 64, 32, 32), ds_shape)
    (2, 64, 128, 128)
    >>> _partition_shape_for_chunking((64, 1, 32, 32), ds_shape)
    (64, 64, 128, 128)
    >>> ds_shape_5d = Shape((16, 64, 64, 128, 128), sig_dims=2)
    >>> _partition_shape_for_chunking((1, 1, 2, 32, 32), ds_shape_5d)
    (1, 1, 2, 128, 128)
    >>> _partition_shape_for_chunking((1, 2, 1, 32, 32), ds_shape_5d)
    (1, 2, 64, 128, 128)
    >>> _partition_shape_for_chunking((2, 1, 1, 32, 32), ds_shape_5d)
    (2, 64, 64, 128, 128)
    """
    first_non_one = [x == 1 for x in chunks].index(False)
    shape_left = chunks[:first_non_one + 1]
    return shape_left + ds_shape[first_non_one + 1:]


def _tileshape_for_chunking(chunks, ds_shape):
    """
    Calculate a tileshape for tiled reading from chunked
    data sets.

    Examples
    --------
    >>> ds_shape = Shape((64, 64, 128, 128), sig_dims=2)
    >>> _tileshape_for_chunking((1, 4, 32, 32), ds_shape)
    (4, 32, 32)
    """
    return chunks[-ds_shape.sig.dims - 1:]


def _get_tileshape_nd(partition_slice, tiling_scheme):
    """
    Expand the tileshape of `tiling_scheme` to the number of navigation
    dimensions of `partition_slice`.

    The extra leading navigation dimensions get size 1. The remaining
    navigation dimension is the depth of the tiling scheme, limited to
    the extent of the last navigation axis of the partition.
    """
    part_nav = partition_slice.shape.nav
    scheme_shape = tiling_scheme.shape
    leading_ones = (1,) * (part_nav.dims - scheme_shape.nav.dims)
    # keep shape of the rightmost dimension:
    depth = min(scheme_shape[0], part_nav[-1])
    return leading_ones + (depth,) + tuple(scheme_shape.sig)


class H5Reader:
    """
    Bundles the path of a HDF5 file with the path of a data set inside
    that file, and opens that data set on demand.
    """

    def __init__(self, path, ds_path):
        self._path = path
        self._ds_path = ds_path

    @contextlib.contextmanager
    def get_h5ds(self, cache_size=1024 * 1024):
        """
        Context manager that opens the file read-only and yields the
        data set at `ds_path`. The file is closed on exit.

        Parameters
        ----------
        cache_size
            Passed to `h5py.File` as `rdcc_nbytes`; given in bytes, as
            the debug log converts it to MiB by dividing by 1024 twice
        """
        cache_size_mib = cache_size / 1024 / 1024
        logger.debug("H5Reader.get_h5ds: cache_size=%dMiB", cache_size_mib)
        open_kwargs = dict(rdcc_nbytes=cache_size, rdcc_nslots=19997)
        with h5py.File(self._path, 'r', **open_kwargs) as h5_file:
            yield h5_file[self._ds_path]


class H5DataSet(DataSet):
    """
    Read data from a HDF5 data set.

    Examples
    --------

    >>> ds = ctx.load("hdf5", path=path_to_hdf5, ds_path="/data")

    Parameters
    ----------
    path: str
        Path to the file

    ds_path: str
        Path to the HDF5 data set inside the file

    nav_shape: tuple of int, optional
        A n-tuple that specifies the shape of the navigation / scan grid.
        By default this is inferred from the HDF5 dataset.

    sig_shape: tuple of int, optional
        A n-tuple that specifies the shape of the signal / frame grid.
        This parameter is currently unsupported and will raise an error
        if provided and not matching the underlying data sig shape.
        By default the sig_shape is inferred from the HDF5 dataset
        via the :code:`sig_dims` parameter.

    sig_dims: int
        Number of dimensions that should be considered part of the signal (for example
        2 when dealing with 2D image data)

    sync_offset: int, optional, by default 0
        If positive, number of frames to skip from start
        If negative, number of blank frames to insert at start

    target_size: int
        Target partition size, in bytes. Usually doesn't need to be changed.

    min_num_partitions: int
        Minimum number of partitions, set to number of cores if not specified. Usually
        doesn't need to be specified.

    Note
    ----
    If the HDF5 file to be loaded contains compressed data
    using a custom compression filter (other than GZIP, LZF or SZIP),
    the associated HDF5 filter library must be imported on the workers
    before accessing the file. See the `h5py documentation
    on filter pipelines
    <https://docs.h5py.org/en/stable/high/dataset.html#filter-pipeline>`_
    for more information.

    The library `hdf5plugin <https://github.com/silx-kit/hdf5plugin>`_
    is preloaded automatically if it is installed. Other filter libraries
    may have to be specified for preloading by the user.

    Preloads for a local :class:`~libertem.executor.dask.DaskJobExecutor` can be
    specified through the :code:`preload` argument of either
    :meth:`~libertem.executor.dask.DaskJobExecutor.make_local` or
    :func:`libertem.executor.dask.cluster_spec`. For the
    :class:`libertem.executor.inline.InlineJobExecutor`, the plugins
    can simply be imported in the main script.

    For the web GUI or for running LiberTEM in a cluster with existing workers
    (e.g. by running :code:`libertem-worker` or :code:`dask-worker` on nodes),
    necessary imports can be specified as :code:`--preload` arguments to
    the launch command, for example with :code:`libertem-server --preload hdf5plugin`
    resp. :code:`libertem-worker --preload hdf5plugin tcp://scheduler_ip:port`.
    :code:`--preload` can be specified multiple times.
    """
    def __init__(self, path, ds_path=None, tileshape=None, nav_shape=None, sig_shape=None,
                 target_size=None, min_num_partitions=None, sig_dims=2, io_backend=None,
                 sync_offset: int = 0):
        super().__init__(io_backend=io_backend)
        if io_backend is not None:
            raise ValueError("H5DataSet currently doesn't support alternative I/O backends")
        self.path = path
        self.ds_path = ds_path
        self.target_size = target_size
        self.sig_dims = sig_dims
        # handle backwards-compatibility:
        if tileshape is not None:
            warnings.warn(
                "tileshape argument is ignored and will be removed after 0.6.0",
                FutureWarning
            )
        # self.min_num_partitions appears to be never used
        self.min_num_partitions = min_num_partitions
        # the following are filled in by `_do_initialize`:
        self._dtype = None
        self._shape = None
        self._nav_shape = nav_shape
        self._sig_shape = sig_shape
        self._sync_offset = sync_offset
        self._chunks = None
        self._compression = None

    def get_reader(self):
        """Return a new :class:`H5Reader` for the configured file and data set."""
        return H5Reader(
            path=self.path,
            ds_path=self.ds_path
        )

    def _do_initialize(self):
        """
        Open the file and read dtype, shape, chunking and compression.

        If no `ds_path` was given, the data set with the largest number
        of elements, as found by `_get_datasets`, is selected.

        Raises
        ------
        DataSetException
            If no data set can be inferred, if the data set has only
            signal dimensions, or if the given `sig_shape` doesn't match
            the signal shape of the data.
        """
        if self.ds_path is None:
            try:
                datasets = _get_datasets(self.path)
                # `max` raises ValueError for an empty list:
                largest_ds = max(datasets, key=lambda x: prod(x.shape))
            # FIXME: excepting `SystemError` temporarily
            # more info: https://github.com/h5py/h5py/issues/1740
            except (ValueError, TimeoutError, SystemError) as err:
                raise DataSetException(
                    f'Unable to infer dataset from file {self.path}'
                ) from err
            self.ds_path = largest_ds.name
        with self.get_reader().get_h5ds() as h5ds:
            self._dtype = h5ds.dtype
            shape = h5ds.shape
            if len(shape) == self.sig_dims:
                # shape = (1,) + shape -> this leads to indexing errors down the line
                # so we currently don't support opening 2D HDF5 files
                raise DataSetException("2D HDF5 files are currently not supported")
            ds_shape = Shape(shape, sig_dims=self.sig_dims)
            if self._sig_shape is not None and self._sig_shape != ds_shape.sig.to_tuple():
                raise DataSetException("sig reshaping currently not supported with HDF5 files")
            self._image_count = ds_shape.nav.size
            nav_shape = ds_shape.nav.to_tuple() if self._nav_shape is None else self._nav_shape
            self._shape = nav_shape + ds_shape.sig
            self._meta = DataSetMeta(
                shape=self.shape,
                raw_dtype=self._dtype,
                sync_offset=self._sync_offset,
                image_count=self._image_count,
                # the shape as stored in the file, before any nav reshaping:
                metadata={'ds_raw_shape': ds_shape}
            )
            self._chunks = h5ds.chunks
            self._compression = h5ds.compression
            if self._compression is not None:
                warnings.warn(
                    "Loading compressed HDF5, performance can be worse than with other formats",
                    RuntimeWarning
                )
        self._nav_shape_product = self._shape.nav.size
        self._sync_offset_info = self.get_sync_offset_info()
        return self

    def initialize(self, executor):
        """Run `_do_initialize` using `executor` and return its result."""
        return executor.run_function(self._do_initialize)

    @classmethod
    def get_msg_converter(cls):
        return HDF5DatasetParams

    @classmethod
    def get_supported_extensions(cls):
        return {"h5", "hdf5", "hspy", "nxs"}

    @classmethod
    def get_supported_io_backends(cls):
        # alternative I/O backends are rejected in `__init__`
        return []

    @classmethod
    def _do_detect(cls, path):
        """Try to open `path` as HDF5, raise `DataSetException` on `OSError`."""
        try:
            with h5py.File(path, 'r'):
                pass
        except OSError as e:
            raise DataSetException(repr(e)) from e

    @classmethod
    def detect_params(cls, path, executor):
        """
        Detect parameters for the file at `path`.

        Returns
        -------
        bool or dict
            `False` if the file can't be opened as HDF5. Otherwise a dict
            with the keys `parameters` and `info`; if no compatible data
            set was found, `parameters` only contains the `path` and the
            list of data sets in `info` is empty.
        """
        try:
            executor.run_function(cls._do_detect, path)
        except (OSError, KeyError, ValueError, TypeError, DataSetException):
            # not a h5py file or can't open for some reason:
            return False

        # Read the dataset info from the file
        try:
            datasets = executor.run_function(_get_datasets, path)
            if not datasets:
                raise RuntimeError(f'Found no compatible datasets in the file {path}')
        # FIXME: excepting `SystemError` temporarily
        # more info: https://github.com/h5py/h5py/issues/1740
        except (RuntimeError, TimeoutError, SystemError):
            return {
                "parameters": {
                    "path": path,
                },
                "info": {
                    "datasets": [],
                }
            }

        # datasets contains at least one HDF5ArrayDescriptor
        # sig_dims is implicitly two here (for web GUI)
        sig_dims = 2
        full_info = [
            {
                "path": ds_item.name,
                "shape": ds_item.shape,
                "compression": ds_item.compression,
                "chunks": ds_item.chunks,
                "raw_nav_shape": ds_item.shape[:-sig_dims],
                "nav_shape": _ensure_2d_nav(ds_item.shape[:-sig_dims]),
                "sig_shape": ds_item.shape[-sig_dims:],
                "image_count": prod(ds_item.shape[:-sig_dims]),
            }
            for ds_item in datasets
        ]

        # use the largest size array as initial hdf5 dataset path
        # need to get info dict to access unpacked nav/sig shape
        # next line implements argmax on ds_descriptor.size
        ds_idx, _ = max(enumerate(datasets), key=lambda idx_x: prod(idx_x[1].shape))
        largest_ds = full_info[ds_idx]

        return {
            "parameters": {
                "path": path,
                "ds_path": largest_ds['path'],
                "nav_shape": largest_ds['nav_shape'],
                "sig_shape": largest_ds['sig_shape'],
            },
            "info": {
                "datasets": full_info
            }
        }

    @property
    def dtype(self):
        if self._dtype is None:
            raise RuntimeError("please call initialize")
        return self._dtype

    @property
    def shape(self):
        if self._shape is None:
            raise RuntimeError("please call initialize")
        return self._shape

    def check_valid(self):
        """
        Check that the data set can be opened.

        Returns `True` on success, raises `DataSetException` otherwise.
        """
        try:
            with self.get_reader().get_h5ds() as h5ds:
                h5ds.shape
            return True
        except (OSError, KeyError, ValueError) as e:
            raise DataSetException("invalid dataset: %s" % e) from e

    def get_cache_key(self):
        return {
            "path": self.path,
            "ds_path": self.ds_path,
        }

    def get_diagnostics(self):
        """
        Return diagnostics as a list of name/value dicts: dtype, chunks
        and compression of the selected data set, plus all data sets
        found in the file, sorted by number of elements, largest first.
        """
        with self.get_reader().get_h5ds() as ds:
            try:
                datasets = _get_datasets(self.path)
            # FIXME: excepting `SystemError` temporarily
            # more info: https://github.com/h5py/h5py/issues/1740
            except (TimeoutError, SystemError):
                datasets = []
            datasets = [
                {"name": descriptor.name, "value": [
                    {"name": "Size", "value": str(prod(descriptor.shape))},
                    {"name": "Shape", "value": str(descriptor.shape)},
                    {"name": "Datatype", "value": str(descriptor.dtype)},
                ]}
                for descriptor in sorted(
                    datasets, key=lambda i: prod(i.shape), reverse=True
                )
            ]
            return [
                {"name": "dtype", "value": str(ds.dtype)},
                {"name": "chunks", "value": str(ds.chunks)},
                {"name": "compression", "value": str(ds.compression)},
                {"name": "datasets", "value": datasets},
            ]

    def get_min_sig_size(self):
        if self._chunks is not None:
            return 1024  # allow for tiled processing w/ small-ish chunks
        # un-chunked HDF5 seems to prefer larger signal slices, so we aim for 32 4k blocks:
        return 32 * 4096 // np.dtype(self.meta.raw_dtype).itemsize

    def get_max_io_size(self) -> Optional[int]:
        if self._chunks is not None:
            # this may result in larger tile depth than necessary, but
            # it needs to be so big to pass the validation of the Negotiator. The tiles
            # won't ever be as large as the scheme dictates, anyway.
            # We limit it here to 256e6 elements, to also keep the chunk cache
            # usage reasonable:
            return int(256e6)
        return None  # use default value from Negotiator

    def get_base_shape(self, roi):
        """
        Return the base shape for tiling, as depth followed by the
        signal dimensions: full frames with a `roi`, the signal chunking
        for chunked data, otherwise a single row of the last signal axis.
        """
        if roi is not None:
            return (1,) + self.shape.sig
        if self._chunks is not None:
            sig_chunks = self._chunks[-self.shape.sig.dims:]
            return (1,) + sig_chunks
        # depth 1 and size 1 in all signal axes but the last one; this
        # is (1, 1, N) for 2D signals and also matches other `sig_dims`:
        return (1,) * self.shape.sig.dims + (self.shape[-1],)

    def adjust_tileshape(self, tileshape, roi):
        """
        Adjust the negotiated `tileshape` to the chunking of the file.

        With a `roi`, single full frames are used. If the chunks are
        not contiguous in the navigation axes, either full frames or a
        tileshape derived from the chunks is returned. In all other
        cases, `tileshape` is returned unchanged.
        """
        chunks = self._chunks
        sig_shape = self.shape.sig
        if roi is not None:
            return (1,) + sig_shape
        if chunks is not None and not _have_contig_chunks(chunks, self.shape):
            sig_chunks = chunks[-sig_shape.dims:]
            sig_ts = tileshape[-sig_shape.dims:]
            # if larger signal chunking is requested in the negotiation,
            # switch to full frames:
            if any(t > c for t, c in zip(sig_ts, sig_chunks)):
                # try to keep total tileshape size:
                tileshape_size = prod(tileshape)
                depth = max(1, tileshape_size // sig_shape.size)
                return (depth,) + sig_shape
            else:
                # depth needs to be limited to prod(chunks.nav)
                return _tileshape_for_chunking(chunks, self.shape)
        return tileshape

    def need_decode(self, roi, read_dtype, corrections):
        return True

    def get_partitions(self):
        """
        Generate the :class:`H5Partition` objects of this data set.

        The raw shape of the data set is cut into partitions, and
        `sync_offset` is applied to the flattened slice of each of
        them. Partitions that end up empty are skipped.
        """
        # ds_shape = Shape(self.shape, sig_dims=self.sig_dims)
        ds_shape: Shape = self.meta['ds_raw_shape']
        ds_slice = Slice(origin=[0] * len(ds_shape), shape=ds_shape)
        target_size = self.target_size
        if target_size is None:
            if self._compression is None:
                target_size = 512 * 1024 * 1024
            else:
                target_size = 256 * 1024 * 1024
        partition_shape = self.partition_shape(
            target_size=target_size,
            dtype=self.dtype,
            containing_shape=ds_shape,
        ) + tuple(ds_shape.sig)

        # if the data is chunked in the navigation axes, choose a compatible
        # partition size (even important for non-compressed data!)
        chunks = self._chunks
        if chunks is not None and not _have_contig_chunks(chunks, ds_shape):
            partition_shape = _partition_shape_for_chunking(chunks, ds_shape)

        # -ve sync offset insert blank at beginning (skips at end)
        # +ve sync offset skips frames at beginning (blank at end)
        sync_offset = self._sync_offset
        ds_flat_shape = self.shape.flatten_nav()
        for slice_nd in ds_slice.subslices(partition_shape):
            raw_frames_slice = slice_nd.flatten_nav(ds_shape)
            raw_origin = raw_frames_slice.origin
            if self._sync_offset <= 0:
                # negative or zero s-o, shift right, clip length at end
                raw_frames_slice.origin = (raw_origin[0] + abs(sync_offset),) + raw_origin[1:]
            else:
                # positive s-o, shift left, clip part length at beginning
                corrected_nav_origin = raw_origin[0] - sync_offset
                if corrected_nav_origin < 0:
                    corrected_nav_size = raw_frames_slice.shape[0] + corrected_nav_origin
                    if corrected_nav_size <= 0:
                        # Empty partition, skip
                        continue
                    raw_frames_slice.shape = (corrected_nav_size,) + raw_frames_slice.shape.sig
                raw_frames_slice.origin = (max(0, corrected_nav_origin),) + raw_origin[1:]
            # All raw_frames_slice should have non-zero dims here
            partition_slice = raw_frames_slice.clip_to(ds_flat_shape)
            if any(v <= 0 for v in partition_slice.shape):
                # Empty partition after clip to desired shape, skip
                continue
            yield H5Partition(
                meta=self._meta,
                reader=self.get_reader(),
                partition_slice=partition_slice,
                slice_nd=slice_nd,
                io_backend=self.get_io_backend(),
                chunks=self._chunks,
                decoder=None,
                sync_offset=self._sync_offset,
            )

    def __repr__(self):
        return f"<H5DataSet of {self._dtype} shape={self._shape}>"
class H5Partition(Partition):
    """
    Partition of a :class:`H5DataSet`.

    Parameters
    ----------
    reader : H5Reader
        Used to open the HDF5 data set for reading
    slice_nd : Slice
        Slice of this partition in the raw, n-dimensional shape of the
        HDF5 data set, without `sync_offset` applied
    chunks
        Chunking of the HDF5 data set, or `None`
    sync_offset : int
        The `sync_offset` of the data set
    """
    def __init__(self, reader: H5Reader, slice_nd: Slice, chunks, sync_offset=0, *args, **kwargs):
        self.reader = reader
        self.slice_nd = slice_nd
        self._corrections = None
        self._chunks = chunks
        self._sync_offset = sync_offset
        super().__init__(*args, **kwargs)

    def _have_compatible_chunking(self):
        """
        Return `True` if we are free to choose the access pattern, i.e.
        the data is not chunked, or all chunk sizes in the navigation
        dimensions are 1.
        """
        chunks = self._chunks
        if chunks is None:
            return True
        # all-1 in nav dims works:
        nav_dims = self.slice_nd.shape.nav.dims
        chunks_nav = self._chunks[:nav_dims]
        if all(c == 1 for c in chunks_nav):
            return True
        # everything else is problematic and needs special case:
        return False

    def _get_subslices_chunked_full_frame(self, scheme_lookup, nav_dims, tileshape_nd):
        """
        chunked full-frame reading. outer loop goes over the
        navigation coords of each chunk, inner loop is pushed into
        hdf5 by reading full frames need to order slices in a way that
        efficiently uses the chunk cache.
        """
        chunks_nav = self._chunks[:nav_dims]
        chunk_full_frame = chunks_nav + self.slice_nd.shape.sig
        chunk_slices = self.slice_nd.subslices(shape=chunk_full_frame)
        logger.debug(
            "_get_subslices_chunked_full_frame: chunking first by %r, then %r",
            chunk_full_frame, tileshape_nd,
        )
        for chunk_slice in chunk_slices:
            subslices = chunk_slice.subslices(shape=tileshape_nd)
            for subslice in subslices:
                # the lookup key is the signal part of the slice:
                idx = scheme_lookup[(subslice.origin[nav_dims:], subslice.shape[nav_dims:])]
                yield idx, subslice

    def _get_subslices_chunked_tiled(self, tiling_scheme, scheme_lookup, nav_dims, tileshape_nd):
        """
        general tiled reading w/ chunking outer loop is a chunk in
        signal dimensions, inner loop is over "rows in nav"
        """
        slice_nd_sig = self.slice_nd.sig
        slice_nd_nav = self.slice_nd.nav
        chunks_nav = self._chunks[:nav_dims]
        sig_slices = slice_nd_sig.subslices(tiling_scheme.shape.sig)
        logger.debug(
            "_get_subslices_chunked_tiled: chunking first by sig %r, then nav %r, finally %r",
            tiling_scheme.shape.sig, chunks_nav, tileshape_nd
        )
        for sig_slice in sig_slices:
            chunk_slices = slice_nd_nav.subslices(shape=chunks_nav)
            for chunk_slice_nav in chunk_slices:
                chunk_slice = Slice(
                    origin=chunk_slice_nav.origin + sig_slice.origin,
                    shape=chunk_slice_nav.shape + tuple(sig_slice.shape),
                )
                subslices = chunk_slice.subslices(shape=tileshape_nd)
                for subslice in subslices:
                    scheme_key = (subslice.origin[nav_dims:], subslice.shape[nav_dims:])
                    idx = scheme_lookup[scheme_key]
                    yield idx, subslice

    def _get_subslices(self, tiling_scheme):
        """
        Generate partition subslices for the given tiling scheme for the different cases.

        Yields tuples of the index into the tiling scheme and the
        subslice in the raw n-dimensional shape.
        """
        if tiling_scheme.intent == "partition":
            tileshape_nd = self.slice_nd.shape
        else:
            tileshape_nd = _get_tileshape_nd(self.slice_nd, tiling_scheme)

        assert all(ts <= ps for (ts, ps) in zip(tileshape_nd, self.slice_nd.shape))

        nav_dims = self.slice_nd.shape.nav.dims

        # Three cases need to be handled:
        if self._have_compatible_chunking():
            # 1) no chunking, or compatible chunking. we are free to use
            #    whatever access pattern we deem efficient:
            logger.debug("using simple tileshape_nd slicing")
            subslices = self.slice_nd.subslices(shape=tileshape_nd)
            scheme_len = len(tiling_scheme)
            for idx, subslice in enumerate(subslices):
                scheme_idx = idx % scheme_len
                yield scheme_idx, subslice
        else:
            # map the signal part of each scheme slice to its index:
            scheme_lookup = {
                (s.discard_nav().origin, tuple(s.discard_nav().shape)): idx
                for idx, s in tiling_scheme.slices
            }
            if len(tiling_scheme) == 1:
                logger.debug("using full-frame subslicing")
                yield from self._get_subslices_chunked_full_frame(
                    scheme_lookup, nav_dims, tileshape_nd
                )
            else:
                logger.debug("using chunk-adaptive subslicing")
                yield from self._get_subslices_chunked_tiled(
                    tiling_scheme, scheme_lookup, nav_dims, tileshape_nd
                )

    def _preprocess(self, tile_data, tile_slice):
        """Apply the corrections, if any were set, to `tile_data`."""
        if self._corrections is None:
            return
        self._corrections.apply(tile_data, tile_slice)

    def _get_read_cache_size(self) -> float:
        """
        Return the size of the chunk cache to use for reading, which is
        passed on to `H5Reader.get_h5ds`.
        """
        chunks = self._chunks
        if chunks is None:
            return 1 * 1024 * 1024
        else:
            # heuristic on maximum chunk cache size based on number of cores
            # of the node this worker is running on, available memory, ...
            import psutil
            mem = psutil.virtual_memory()
            num_cores = psutil.cpu_count(logical=False)
            available: int = mem.available
            if num_cores is None:
                num_cores = 2
            cache_size: float = max(256 * 1024 * 1024, available * 0.8 / num_cores)
            return cache_size

    def _get_h5ds(self):
        cache_size = self._get_read_cache_size()
        return self.reader.get_h5ds(cache_size=cache_size)

    def _get_tiles_normal(self, tiling_scheme: TilingScheme, dest_dtype):
        """
        Read the tiles of this partition without a roi.

        Tiles are read in the raw coordinate system of the HDF5 data
        set and are then translated and clipped according to
        `sync_offset`. Tiles without any remaining frames are skipped.
        """
        with self._get_h5ds() as dataset:
            # because the dtype conversion done by HDF5 itself can be quite slow,
            # we need to use a buffer for reading in hdf5 native dtype:
            data_flat = np.zeros(tiling_scheme.shape, dtype=dataset.dtype).reshape((-1,))

            # ... and additionally a result buffer, for re-using the array used in the DataTile:
            data_flat_res = np.zeros(tiling_scheme.shape, dtype=dest_dtype).reshape((-1,))

            subslices = self._get_subslices(
                tiling_scheme=tiling_scheme,
            )

            sync_offset = self._sync_offset
            ds_num_frames = self.meta.shape.nav.size

            for scheme_idx, tile_slice in subslices:
                tile_slice_flat: Slice = tile_slice.flatten_nav(self.meta['ds_raw_shape'])
                raw_origin = tile_slice_flat.origin
                raw_shape = tile_slice_flat.shape

                # The following block translates from tile_slice in the raw array
                # to the partition coordinate system with sync_offset applied
                # By doing this before reading we can avoid reading some tiles
                # at the beginning/end of the dataset
                # We will still read tiles which partially overlap the nav space
                # and afterwards we drop the unnecessary frames
                corrected_nav_origin = raw_origin[0] - sync_offset

                if corrected_nav_origin < 0:
                    # positive sync_offset, drop frames at beginning of DS
                    if abs(corrected_nav_origin) >= tile_slice_flat.shape[0]:
                        # tile is completely before the first partition, can skip it.
                        # `>=`, as a tile that ends exactly at the first frame
                        # doesn't contain any frames to process either
                        continue
                    # Clip at the beginning so adjust the tile shape
                    new_nav_size = tile_slice_flat.shape[0] + corrected_nav_origin
                    tile_slice_flat.shape = (new_nav_size,) + tile_slice_flat.shape.sig

                # Apply max(0, corrected_nav_origin) so we never provide negative nav coord
                tile_slice_flat.origin = (max(0, corrected_nav_origin),) + raw_origin[1:]

                # Now check for clipping at the end of the dataset
                final_frame_idx = tile_slice_flat.origin[0] + tile_slice_flat.shape[0]
                frames_beyond_end = final_frame_idx - ds_num_frames
                # We want to skip any tiles which are completely past the end of the dataset
                # and clip those which are only partially overlapping the final partition
                if frames_beyond_end >= tile_slice_flat.shape[0]:
                    # Empty tile after clip, skip
                    continue
                elif frames_beyond_end > 0:
                    # tile partially overlaps end of dataset, adjust the shape
                    new_nav_size = tile_slice_flat.shape[0] - frames_beyond_end
                    tile_slice_flat.shape = (new_nav_size,) + tile_slice_flat.shape.sig

                # Read the data in this block
                # cut buffer into the right size
                buf_size = tile_slice.shape.size
                buf = data_flat[:buf_size].reshape(tile_slice.shape)
                buf_res = data_flat_res[:buf_size].reshape(tile_slice.shape)
                dataset.read_direct(buf, source_sel=tile_slice.get())
                buf_res[:] = buf  # extra copy for faster dtype/endianess conversion
                tile_data = buf_res.reshape(raw_shape)

                # If the true tile origin is before the start of the dataset, must drop frames
                # This corresponds to the first raw tile which overlaps the first partition
                # and can only occur when sync_offset > 0
                if corrected_nav_origin < 0:
                    tile_data = tile_data[abs(corrected_nav_origin):, ...]

                # The final tiles in the dataset can partially overlap the final partition
                # Drop frames at end to match the partition size
                # we already verified if any frames will remain
                # this can occur for both +ve and -ve sync_offset
                if frames_beyond_end > 0:
                    tile_data = tile_data[:-frames_beyond_end, ...]

                # NOTE could the above two blocks ever apply simultaneously?
                # would the two operations conflict ?

                self._preprocess(tile_data, tile_slice_flat)
                yield DataTile(
                    tile_data,
                    tile_slice=tile_slice_flat,
                    scheme_idx=scheme_idx,
                )

    def _get_tiles_with_roi(self, roi, dest_dtype, tiling_scheme):
        """
        Read the frames of this partition that are selected by `roi`,
        one full frame per tile. The nav origin of each tile slice is
        the index of the frame among all frames selected by `roi`.
        """
        # we currently don't chop up the frames when reading with a roi, so
        # the tiling scheme also must not contain more than one slice:
        # NOTE Why is this ??
        assert len(tiling_scheme) == 1, "incompatible tiling scheme! (%r)" % (tiling_scheme,)

        flat_roi_nonzero = flat_nonzero(roi)
        start_at_frame = self.slice.origin[0]
        stop_at_frame = start_at_frame + self.slice.shape[0]
        part_mask = np.logical_and(flat_roi_nonzero >= start_at_frame,
                                   flat_roi_nonzero < stop_at_frame)
        # Must yield tiles with tile_slice in compressed nav dimension for roi
        frames_in_c_nav = np.arange(flat_roi_nonzero.size)[part_mask]
        frames_in_part = flat_roi_nonzero[part_mask]
        # -ve sync offset insert blank at beginning (skips at end)
        # +ve sync offset skips frames at beginning (blank at end)
        frames_in_raw = frames_in_part + self._sync_offset
        raw_shape = self.meta['ds_raw_shape'].nav.to_tuple()

        result_shape = Shape((1,) + tuple(self.meta.shape.sig),
                             sig_dims=self.meta.shape.sig.dims)
        sig_origin = (0,) * self.meta.shape.sig.dims
        # NOTE: the same result buffer is re-used for all yielded tiles
        tile_data = np.zeros(result_shape, dtype=dest_dtype)

        with self._get_h5ds() as h5ds:
            tile_data_raw = np.zeros(result_shape, dtype=h5ds.dtype)
            for c_nav_idx, raw_idx in zip(frames_in_c_nav, frames_in_raw):
                tile_slice = Slice(
                    origin=(c_nav_idx,) + sig_origin,
                    shape=result_shape,
                )
                nav_coord = np.unravel_index(raw_idx, raw_shape)
                h5ds.read_direct(tile_data_raw, source_sel=nav_coord)
                tile_data[:] = tile_data_raw  # extra copy for dtype/endianess conversion
                self._preprocess(tile_data, tile_slice)
                yield DataTile(
                    tile_data,
                    tile_slice=tile_slice,
                    # there is only a single slice in the tiling scheme, so our
                    # scheme_idx is constant 0
                    scheme_idx=0,
                )

    def set_corrections(self, corrections: CorrectionSet):
        self._corrections = corrections

    def get_tiles(self, tiling_scheme: TilingScheme, dest_dtype="float32", roi=None,
                  array_backend: Optional[ArrayBackend] = None):
        """
        Yield the tiles of this partition for the given `tiling_scheme`,
        converted to `dest_dtype`. If a `roi` is given, only the frames
        selected by it are read.
        """
        if array_backend is None:
            array_backend = self.meta.array_backends[0]
        assert array_backend in (NUMPY, CUDA)
        tiling_scheme = tiling_scheme.adjust_for_partition(self)
        if roi is not None:
            yield from self._get_tiles_with_roi(roi, dest_dtype, tiling_scheme)
        else:
            yield from self._get_tiles_normal(tiling_scheme, dest_dtype)

    def get_locations(self):
        return None

    def get_macrotile(self, dest_dtype="float32", roi=None):
        '''
        Return a single tile for the entire partition.

        This is useful to support process_partition() in UDFs and to construct
        dask arrays from datasets.

        Note
        ----

        This can be inefficient if the dataset is compressed and chunked
        in the navigation axis, because you can either have forced-large
        macrotiles, or you can have read amplification effects, where a
        much larger amount of data is read from the HDF5 file than necessary.

        For example, if your chunking is :code:`(32, 32, 32, 32)`, in a
        dataset that is :code:`(128, 128, 256, 256)`, the partition must
        cover the whole of :code:`(32, 128, 256, 256)` - this is because
        partitions are contiguous in the navigation axis.

        The other possibility is to keep the partition smaller, for example
        only :code:`(3, 128, 256, 256)`. That would mean when reading a chunk
        from HDF5, we can only use 3*32 frames of the total 32*32 frames,
        a whopping ~10x read amplification.
        '''

        sig_dims = self.slice.shape.sig.dims

        tileshape = self.shape
        if self._chunks is not None:
            tileshape = self._chunks

        tiling_scheme = TilingScheme.make_for_shape(
            tileshape=Shape(tileshape, sig_dims=sig_dims).flatten_nav(),
            dataset_shape=self.meta.shape,
        )

        data = zeros_aligned(self.slice.adjust_for_roi(roi).shape, dtype=dest_dtype)

        for tile in self.get_tiles(
            tiling_scheme=tiling_scheme,
            dest_dtype=dest_dtype,
            roi=roi,
        ):
            # position of the tile relative to the partition:
            rel_slice = tile.tile_slice.shift(self.slice)
            data[rel_slice.get()] = tile.data
        # use the actual number of signal dimensions instead of assuming 2D frames:
        tile_slice = Slice(
            origin=(self.slice.origin[0],) + (0,) * sig_dims,
            shape=Shape(data.shape, sig_dims=sig_dims),
        )
        return DataTile(
            data,
            tile_slice=tile_slice,
            scheme_idx=0,
        )