Source code for libertem.io.dataset.dask

import logging
import itertools
import numpy as np
import dask.array as da

from libertem.common import Shape, Slice
from libertem.io.dataset.base import (
    DataSet, DataSetMeta, BasePartition, File, FileSet, DataSetException
)
from libertem.io.dataset.base.backend_mmap import MMapFile, MMapBackend, MMapBackendImpl
from libertem.common.messageconverter import MessageConverter

log = logging.getLogger(__name__)


class DaskDatasetParams(MessageConverter):
    SCHEMA = {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "$id": "http://libertem.org/DaskDatasetParams.schema.json",
        "title": "DaskDatasetParams",
        "type": "object",
        "properties": {
            "type": {"const": "DASK"},
            "sig_dims": {"type": "number", "minimum": 1},
            "preserve_dimensions": {"type": "boolean"},
            "min_size": {"type": "number", "minimum": 1},
        },
        "required": ["type"],
    }

    def convert_to_python(self, raw_data):
        data = {
            k: raw_data[k]
            for k in ["sig_dims", "preserve_dimensions", "min_size"]
            if k in raw_data
        }
        return data
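

# Illustrative sketch (not part of the upstream module): mirrors the key
# filtering performed by DaskDatasetParams.convert_to_python above on a
# hypothetical GUI message. Keys outside the handled list (here 'path') are
# dropped, and the required 'type' key is not forwarded to the dataset.
def _example_dask_params_filtering():
    raw_data = {"type": "DASK", "sig_dims": 2, "path": "/tmp/ignored"}
    data = {
        k: raw_data[k]
        for k in ["sig_dims", "preserve_dimensions", "min_size"]
        if k in raw_data
    }
    assert data == {"sig_dims": 2}
    return data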


class FakeDaskMMapFile(MMapFile):
    """
    Implementing the same interface as MMapFile, without filesystem backing
    """
    def open(self):
        # scheduler='threads' ensures that upstream computation for this array
        # chunk happens completely on this worker and not elsewhere
        self._arr = self.desc._array.compute(scheduler='threads')
        # need to be aware that Dask can create Fortran-ordered arrays
        # when .compute is called, which can lead to downstream issues when
        # np.frombuffer is called on self._mmap in the backend. Currently it seems
        # like np.frombuffer cannot handle Fortran ordering and throws a ValueError
        self._mmap = self._arr
        return self

    def close(self):
        del self._arr
        del self._mmap
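

# Illustrative sketch (not part of the upstream module): demonstrates the
# memory-layout concern noted in FakeDaskMMapFile.open above. Computing a
# transposed dask array chunk can yield a Fortran-ordered numpy array, which
# consumers expecting C-contiguous buffers cannot use directly;
# np.ascontiguousarray is one way to normalise the layout. The function name
# _example_fortran_order_pitfall is hypothetical and only added here for
# illustration.
def _example_fortran_order_pitfall():
    chunk = da.ones((16, 32), chunks=(-1, -1)).T.compute(scheduler='threads')
    if not chunk.flags['C_CONTIGUOUS']:
        # copy the data into C order before handing it to buffer-based readers
        chunk = np.ascontiguousarray(chunk)
    return chunk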


class DaskBackend(MMapBackend):
    def get_impl(self):
        return DaskBackendImpl()


class DaskBackendImpl(MMapBackendImpl):
    FILE_CLS = FakeDaskMMapFile


class DaskDataSet(DataSet):
    """
    .. versionadded:: 0.9.0

    Wraps a :code:`dask.array.Array` such that it can be processed by LiberTEM.
    Partitions are created to be aligned with the array chunking. When the
    array chunking is not compatible with LiberTEM, the wrapper merges chunks
    until compatibility is achieved.

    The best-case scenario is for the original array to be chunked in the
    leftmost navigation dimension. If instead another navigation dimension is
    chunked, the user can set `preserve_dimensions=False` to re-order the
    navigation shape and achieve better chunking for LiberTEM. If more than one
    navigation dimension is chunked, the class will do its best to merge chunks
    without creating partitions which are too large.

    LiberTEM requires that a partition contains only whole signal frames, so
    any signal dimension chunking is immediately merged by this class.

    This wrapper is most useful when the Dask array was created using lazy I/O
    via `dask.delayed`, or via `dask.array` operations. The major assumption is
    that each chunk in the array can be evaluated individually, without having
    to read or compute more data than the chunk itself contains. If this is not
    the case then this class could perform very poorly due to read
    amplification, or even crash the Dask workers.

    As the class performs rechunking using a merge-only strategy, it will never
    split chunks which were present in the original array. If the array is
    originally very lightly chunked, then the corresponding LiberTEM partitions
    will be very large. In addition, overly-chunked arrays (for example one
    chunk per frame) can incur excessive Dask task graph overheads and should
    be avoided where possible.

    Parameters
    ----------

    dask_array: dask.array.Array
        A Dask array

    sig_dims: int
        Number of dimensions in dask_array.shape counting from the right
        to treat as signal dimensions

    preserve_dimensions: bool, optional
        If False, allow optimization of the dask_array chunking by re-ordering
        the nav_shape to put the most chunked dimensions first. This can help
        when more than one nav dimension is chunked.

    min_size: float, optional
        The minimum partition size in bytes if the array chunking allows an
        order-preserving merge strategy. The default min_size is 128 MiB.

    io_backend: optional
        For compatibility only; this dataset supports only the default I/O
        backend, and passing anything other than None raises an exception.

    Example
    -------

    >>> import dask.array as da
    >>>
    >>> d_arr = da.ones((4, 4, 64, 64), chunks=(2, -1, -1, -1))
    >>> ds = ctx.load('dask', dask_array=d_arr, sig_dims=2)

    This creates a dataset whose partitions follow the chunking of the array
    along the zeroth dimension, subject to merging of chunks smaller than
    min_size.
    """
    # TODO add mechanism to re-order the dimensions of results automatically
    # if preserve_dimensions is set to False
    def __init__(self, dask_array, *, sig_dims, preserve_dimensions=True,
                 min_size=None, io_backend=None):
        super().__init__(io_backend=io_backend)
        if io_backend is not None:
            raise DataSetException(
                "DaskDataSet currently doesn't support alternative I/O backends"
            )
        self._check_array(dask_array, sig_dims)
        self._array = dask_array
        self._sig_dims = sig_dims
        self._sig_shape = self._array.shape[-self._sig_dims:]
        self._dtype = self._array.dtype
        self._preserve_dimension = preserve_dimensions
        self._min_size = min_size
        if self._min_size is None:
            # TODO add a method to determine a sensible partition byte-size
            self._min_size = self._default_min_size

    @property
    def array(self):
        return self._array

    def get_io_backend(self):
        return DaskBackend()

    def initialize(self, executor):
        self._array = self._adapt_chunking(self._array, self._sig_dims)
        self._nav_shape = self._array.shape[:-self._sig_dims]
        self._nav_shape_product = int(np.prod(self._nav_shape))
        self._image_count = self._nav_shape_product
        shape = Shape(self._nav_shape + self._sig_shape, sig_dims=self._sig_dims)
        self._meta = DataSetMeta(
            shape=shape,
            raw_dtype=np.dtype(self._dtype),
            sync_offset=0,
            image_count=self._nav_shape_product,
        )
        return self

    @property
    def dtype(self):
        return self._meta.raw_dtype

    @property
    def shape(self):
        return self._meta.shape

    @classmethod
    def get_msg_converter(cls):
        return DaskDatasetParams

    @property
    def _default_min_size(self):
        """
        Default minimum partition size in bytes if not supplied at init
        """
        return 128 * (2**20)  # 128 MiB

    def _chunk_slices(self, array):
        chunks = array.chunks
        boundaries = tuple(
            tuple(self.chunks_to_slices(chunk_lengths))
            for chunk_lengths in chunks
        )
        return tuple(itertools.product(*boundaries))

    def _adapt_chunking(self, array, sig_dims):
        n_dimension = array.ndim
        # Handle chunked signal dimensions by merging just in case
        sig_dim_idxs = [*range(n_dimension)[-sig_dims:]]
        if any(len(array.chunks[c]) > 1 for c in sig_dim_idxs):
            original_n_chunks = [len(c) for c in array.chunks]
            array = array.rechunk({idx: -1 for idx in sig_dim_idxs})
            log.warning('Merging sig dim chunks as LiberTEM does not '
                        'support partitioning along the sig axes. '
                        f'Original n_blocks: {original_n_chunks}. '
                        f'New n_blocks: {[len(c) for c in array.chunks]}.')
        # Warn if there is no nav_dim chunking
        n_nav_chunks = [len(dim_chunking) for dim_chunking in array.chunks[:-sig_dims]]
        if set(n_nav_chunks) == {1}:
            log.warning('Dask array is not chunked in navigation dimensions, '
                        'cannot split into nav-partitions without loading the '
                        'whole dataset on each worker. '
                        f'Array shape: {array.shape}. '
                        f'Chunking: {array.chunks}. '
                        f'Array size: {array.nbytes / 1e6:.1f} MB.')
            # If we are here there is nothing else to do.
            return array
        # Orient the nav dimensions so that the zeroth dimension is
        # the most chunked; this obviously changes the dataset nav_shape!
        if not self._preserve_dimension:
            n_nav_chunks = [len(dim_chunking) for dim_chunking in array.chunks[:-sig_dims]]
            nav_sort_order = np.argsort(n_nav_chunks)[::-1].tolist()
            sort_order = nav_sort_order + sig_dim_idxs
            if not np.equal(sort_order, np.arange(n_dimension)).all():
                original_shape = array.shape
                original_n_chunks = [len(c) for c in array.chunks]
                array = da.transpose(array, axes=sort_order)
                log.warning('Re-ordered nav_dimensions to improve partitioning; '
                            'create the dataset with preserve_dimensions=True '
                            'to suppress this behaviour. '
                            f'Original shape: {original_shape} with '
                            f'n_blocks: {original_n_chunks}. '
                            f'New shape: {array.shape} with '
                            f'n_blocks: {[len(c) for c in array.chunks]}.')
        # Handle chunked nav_dimensions.
        # We can allow nav_dimensions to be fully chunked (one chunk per element)
        # up-to-but-not-including the first non-fully-chunked dimension. After
        # this point we must merge/rechunk all subsequent nav dimensions to
        # ensure continuity of frame indexes in a flattened nav dimension.
        # This should be removed if we ever allow non-contiguous flat_idx
        # Partitions.
        nav_rechunk_dict = {}
        for dim_idx, dim_chunking in enumerate(array.chunks[:-sig_dims]):
            if set(dim_chunking) == {1}:
                continue
            else:
                merge_dimensions = [*range(dim_idx + 1, n_dimension - sig_dims)]
                for merge_i in merge_dimensions:
                    if len(array.chunks[merge_i]) > 1:
                        nav_rechunk_dict[merge_i] = -1
        if nav_rechunk_dict:
            original_n_chunks = [len(c) for c in array.chunks]
            array = array.rechunk(nav_rechunk_dict)
            log.warning('Merging nav dimension chunks according to scheme '
                        f'{nav_rechunk_dict} as we cannot maintain continuity '
                        'of frame indexing in the flattened navigation dimension. '
                        f'Original n_blocks: {original_n_chunks}. '
                        f'New n_blocks: {[len(c) for c in array.chunks]}.')
        # Merge remaining chunks, maintaining C-ordering, until we reach a
        # target chunk size or a minimum number of partitions corresponding
        # to the number of workers
        new_chunking, min_size, max_size = merge_until_target(array, self._min_size)
        if new_chunking != array.chunks:
            original_n_chunks = [len(c) for c in array.chunks]
            chunksizes = get_chunksizes(array)
            orig_min, orig_max = chunksizes.min(), chunksizes.max()
            array = array.rechunk(new_chunking)
            log.warning('Applying re-chunking to increase minimum partition size. '
                        f'n_blocks: {original_n_chunks} => {[len(c) for c in array.chunks]}. '
                        f'Min chunk size {orig_min / 1e6:.1f} => {min_size / 1e6:.1f} MB, '
                        f'Max chunk size {orig_max / 1e6:.1f} => {max_size / 1e6:.1f} MB.')
        return array

    def _check_array(self, array, sig_dims):
        if not isinstance(array, da.Array):
            raise DataSetException('Expected a Dask array as input, received '
                                   f'{type(array)}.')
        if not isinstance(sig_dims, int) or sig_dims < 0:
            raise DataSetException('Expected non-negative integer sig_dims, '
                                   f'received {sig_dims}.')
        if any(np.isnan(c).any() for c in array.shape) \
                or any(np.isnan(c).any() for c in array.chunks):
            raise DataSetException('Dask array has an unknown shape or chunk sizes '
                                   'so cannot be interpreted as LiberTEM partitions. '
                                   'Run array.compute_chunk_sizes() '
                                   'before passing to DaskDataSet, though this '
                                   'may be performance-intensive. Chunking: '
                                   f'{array.chunks}, Shape: {array.shape}')
        if sig_dims >= array.ndim:
            raise DataSetException(f'Number of sig_dims {sig_dims} not compatible '
                                   f'with number of array dims {array.ndim}; '
                                   'must be able to create partitions along nav '
                                   'dimensions.')
        return True

    def check_valid(self):
        return self._check_array(self._array, self._sig_dims)

    def get_num_partitions(self):
        return len([*itertools.product(*self._array.chunks)])

    @staticmethod
    def chunks_to_slices(chunk_lengths):
        prior = 0
        for c in chunk_lengths:
            newc = c + prior
            yield slice(prior, newc)
            prior = newc

    @staticmethod
    def slices_to_shape(slices):
        return tuple(s.stop - s.start for s in slices)

    @staticmethod
    def slices_to_origin(slices):
        return tuple(s.start for s in slices)

    @staticmethod
    def flatten_nav(slices, nav_shape, sig_dims):
        """
        Because LiberTEM partitions are set up with a flat nav dimension
        we must flatten the Dask array slices. This is ensured to be possible
        by earlier calls to _adapt_chunking, but should be removed if
        partitions are ever able to have >1D navigation axes.
        """
        nav_slices = slices[:-sig_dims]
        sig_slices = slices[-sig_dims:]
        start_frame = np.ravel_multi_index([s.start for s in nav_slices], nav_shape)
        end_frame = 1 + np.ravel_multi_index([s.stop - 1 for s in nav_slices], nav_shape)
        nav_slice = slice(start_frame, end_frame)
        return (nav_slice,) + sig_slices, start_frame, end_frame

    def get_slices(self):
        """
        Generates the LiberTEM slices which correspond to the chunks
        in the Dask array backing the dataset.

        Generates both the flat_nav slice for creating the LiberTEM partition
        and also the full_slices used to index into the dask array.
        """
        chunk_slices = self._chunk_slices(self._array)
        for full_slices in chunk_slices:
            flat_slices, start_frame, end_frame = self.flatten_nav(
                full_slices, self._nav_shape, self._sig_dims
            )
            flat_slice = Slice(
                origin=self.slices_to_origin(flat_slices),
                shape=Shape(self.slices_to_shape(flat_slices), sig_dims=self._sig_dims),
            )
            yield full_slices, flat_slice, start_frame, end_frame

    def _get_fileset(self):
        """
        The fileset is set up to have one 'file' per partition,
        which corresponds to one 'file' per Dask chunk.
        """
        partitions = []
        for full_slices, _, start, stop in self.get_slices():
            partitions.append(DaskFile(
                array_chunk=self._array[full_slices],
                path=None,
                start_idx=start,
                end_idx=stop,
                native_dtype=self._dtype,
                sig_shape=self.shape.sig,
            ))
        return DaskFileSet(partitions)

    def get_partitions(self):
        """
        Partitions contain a reference to the whole array and the whole
        fileset, but the part_slice and start_frame/num_frames provided mean
        that the subsequent call to get_read_ranges() reads/.compute()s only
        one 'file', which corresponds to the partition *exactly*.
        """
        fileset = self._get_fileset()
        for _, part_slice, start, stop in self.get_slices():
            yield DaskPartition(
                self._array,
                meta=self._meta,
                fileset=fileset,
                partition_slice=part_slice,
                start_frame=start,
                num_frames=stop - start,
                io_backend=self.get_io_backend(),
                decoder=self.get_decoder(),
            )

    def __repr__(self):
        return (f"<DaskDataSet of {self.dtype} shape={self.shape}, "
                f"n_blocks={[len(c) for c in self._array.chunks]}>")
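

# Illustrative sketch (not part of the upstream module): shows how the static
# helpers of DaskDataSet map each dask chunk to a flat frame range, assuming a
# (4, 4, 16, 16) array chunked as (2, 2) along the first navigation axis and
# sig_dims=2. The function name _example_chunk_to_flat_range is hypothetical.
def _example_chunk_to_flat_range():
    arr = da.ones((4, 4, 16, 16), chunks=(2, -1, -1, -1))
    nav_shape = arr.shape[:-2]
    boundaries = tuple(
        tuple(DaskDataSet.chunks_to_slices(chunk_lengths))
        for chunk_lengths in arr.chunks
    )
    frame_ranges = []
    for full_slices in itertools.product(*boundaries):
        _, start, stop = DaskDataSet.flatten_nav(full_slices, nav_shape, sig_dims=2)
        frame_ranges.append((int(start), int(stop)))
    # the two chunks along axis 0 map to flat frame ranges [0, 8) and [8, 16)
    assert frame_ranges == [(0, 8), (8, 16)]
    return frame_ranges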


class DaskFile(File):
    def __init__(self, *args, array_chunk=None, **kwargs):
        """
        Upon creation, the dask array has been sliced to give only
        one chunk corresponding to a LiberTEM partition
        """
        self._array = array_chunk
        super().__init__(*args, **kwargs)


class DaskFileSet(FileSet):
    pass


class DaskPartition(BasePartition):
    def __init__(self, dask_array, *args, **kwargs):
        self._array = dask_array
        super().__init__(*args, **kwargs)


def array_mult(*arrays, dtype=np.float64):
    num_arrays = len(arrays)
    if num_arrays == 1:
        return np.asarray(arrays[0]).astype(dtype)
    elif num_arrays == 2:
        return np.multiply.outer(*arrays).astype(dtype)
    elif num_arrays > 2:
        return np.multiply.outer(arrays[0], array_mult(*arrays[1:]))
    else:
        raise RuntimeError('Unexpected number of arrays')


def get_last_chunked_dim(chunking):
    n_chunks = [len(c) for c in chunking]
    chunked_dims = [idx for idx, el in enumerate(n_chunks) if el > 1]
    try:
        return chunked_dims[-1]
    except IndexError:
        return -1


def get_chunksizes(array, chunking=None):
    if chunking is None:
        chunking = array.chunks
    shape = array.shape
    el_bytes = array.dtype.itemsize
    last_chunked = get_last_chunked_dim(chunking)
    if last_chunked < 0:
        return np.asarray(array.nbytes)
    static_size = np.prod(shape[last_chunked + 1:], dtype=np.float64) * el_bytes
    chunksizes = array_mult(*chunking[:last_chunked + 1]) * static_size
    return chunksizes


def modify_chunking(chunking, dim, merge_idxs):
    chunk_dim = chunking[dim]
    merge_idxs = tuple(sorted(merge_idxs))
    before = chunk_dim[:merge_idxs[0]]
    after = chunk_dim[merge_idxs[1] + 1:]
    merged_dim = (sum(chunk_dim[merge_idxs[0]:merge_idxs[1] + 1]),)
    new_chunk_dim = tuple(before) + merged_dim + tuple(after)
    chunking = chunking[:dim] + (new_chunk_dim,) + chunking[dim + 1:]
    return chunking


def findall(sequence, val):
    return [idx for idx, e in enumerate(sequence) if e == val]


def neighbour_idxs(sequence, idx):
    max_idx = len(sequence) - 1
    if idx > 0 and idx < max_idx:
        return (idx - 1, idx + 1)
    elif idx == 0:
        return (None, idx + 1)
    elif idx == max_idx:
        return (idx - 1, None)
    else:
        raise IndexError(f'Index {idx} out of range for sequence of length {len(sequence)}')


def min_neighbour(sequence, idx):
    left, right = neighbour_idxs(sequence, idx)
    if left is None:
        return right
    elif right is None:
        return left
    else:
        return min([left, right], key=lambda x: sequence[x])


def min_with_min_neighbor(sequence):
    min_val = min(sequence)
    occurrences = findall(sequence, min_val)
    min_idx_pairs = [(idx, min_neighbour(sequence, idx)) for idx in occurrences]
    pair_sums = [sum(get_values(sequence, idxs)) for idxs in min_idx_pairs]
    min_pair = min(pair_sums)
    min_pair_occurrences = findall(pair_sums, min_pair)
    return min_idx_pairs[min_pair_occurrences[-1]]  # breaking ties from the right


def get_values(sequence, idxs):
    return [sequence[idx] for idx in idxs]


def merge_until_target(array, target, min_chunks=0):
    chunking = array.chunks
    if array.nbytes < target:
        # A really small dataset, better to treat as one partition
        chunking = tuple((s,) for s in array.shape)
    chunksizes = get_chunksizes(array)
    while chunksizes.size > min_chunks and chunksizes.min() < target:
        if (chunksizes < 0).any():
            log.warning('Overflow in chunksize calculation, will be clipped!')
            chunksizes = np.clip(chunksizes, 0., np.inf)
        last_chunked_dim = get_last_chunked_dim(chunking)
        if last_chunked_dim < 0:
            # No chunking, by definition complete
            break
        last_chunking = chunking[last_chunked_dim]
        to_merge = min_with_min_neighbor(last_chunking)
        chunking = modify_chunking(chunking, last_chunked_dim, to_merge)
        chunksizes = get_chunksizes(array, chunking=chunking)
    return chunking, chunksizes.min(), chunksizes.max()
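

# Illustrative sketch (not part of the upstream module): demonstrates the
# merge-only rechunking strategy implemented by merge_until_target. A
# (100, 16, 16) float64 array chunked one frame per chunk is merged along its
# only chunked dimension until every chunk reaches a (hypothetical) 100 kB
# target. The function name _example_merge_until_target is hypothetical.
def _example_merge_until_target():
    arr = da.zeros((100, 16, 16), dtype=np.float64, chunks=(1, -1, -1))
    target = 100_000  # bytes, chosen for illustration only
    new_chunking, min_size, max_size = merge_until_target(arr, target)
    # the whole array is larger than the target, so it is not collapsed into a
    # single chunk, but every merged chunk now meets the minimum size
    assert min_size >= target
    return new_chunking, min_size, max_size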