import math
import logging
import warnings
from typing import TYPE_CHECKING, Optional, Union
from collections.abc import Sequence
from typing_extensions import Literal
import numpy as np
from libertem.common.exceptions import UDFException
from libertem.io.corrections import CorrectionSet
from libertem.common import Shape, Slice
from libertem.common.math import prod
from libertem.common.udf import UDFProtocol, UDFMethod
if TYPE_CHECKING:
    # Only needed for annotations; importing them at runtime is avoided
    from numpy import typing as nt
    from libertem.io.dataset.base import DataSet, Partition

# Module-level logger, named after this module
log = logging.getLogger(__name__)

# The processing granularity a tiling scheme is generated for
TilingIntent = Union[Literal["partition"], Literal["frame"], Literal["tile"]]
class TilingScheme:
    """
    A list of signal-only slices together with the nominal tile shape,
    the shape of the data set and the processing intent they belong to.

    Instances behave like a read-only sequence of slices: they support
    ``len()`` and integer indexing.
    """

    def __init__(
        self, slices: "list[Slice]",
        tileshape: "Shape", dataset_shape: "Shape",
        intent: "Optional[TilingIntent]" = None, debug=None
    ):
        """
        Parameters
        ----------

        slices
            Signal-only slices, one per tile position

        tileshape
            Nominal shape of a tile. Must have a flat navigation axis,
            i.e. `tileshape.nav.dims` must not be larger than one.

        dataset_shape
            Shape of the whole data set

        intent
            The intent of this scheme (whole partitions, frames or tiles)

        debug
            Arbitrary debugging information, stored as-is

        Raises
        ------

        ValueError
            If `tileshape` has more than one navigation dimension
        """
        if tileshape.nav.dims > 1:
            raise ValueError("tileshape should have flat navigation dimensions")
        self._slices = slices
        self._tileshape = tileshape
        self._dataset_shape = dataset_shape
        self._debug = debug
        self._intent = intent

    def adjust_for_partition(self, partition: "Partition") -> "TilingScheme":
        """
        If the intent is per-partition processing, the tiling scheme must match the
        partition shape exactly. If there is a mismatch, this method returns a
        new scheme that matches the partition.

        Parameters
        ----------

        partition
            The Partition we want to adjust the tiling scheme to.

        Returns
        -------

        TilingScheme
            The adjusted tiling scheme, or this one, if it matches exactly
        """
        partition_size = partition.slice.shape.nav.size
        if self.intent != "partition" or partition_size == self.depth:
            # nothing to adjust: either not processing by partition,
            # or the depth already matches
            return self
        # adjust depth to match partition size exactly:
        new_shape = Shape(
            (partition_size,) + tuple(self._tileshape.sig),
            sig_dims=self._tileshape.sig.dims
        )
        return TilingScheme(
            slices=self._slices,
            tileshape=new_shape,
            dataset_shape=self._dataset_shape,
            intent=self._intent,
            debug=self._debug,
        )

    @classmethod
    def make_for_shape(
        cls,
        tileshape: "Shape",
        dataset_shape: "Shape",
        intent: "Optional[TilingIntent]" = None,
        debug=None,
    ) -> "TilingScheme":
        """
        Make a TilingScheme from `tileshape` and `dataset_shape`.

        Note that both in signal and navigation direction there are border
        effects, i.e. if the depth doesn't evenly divide the number of
        frames in the partition (simplified, ROI also applies...), or if
        the signal dimensions of `tileshape` doesn't evenly divide the signal
        dimensions of the `dataset_shape`.

        Parameters
        ----------

        tileshape
            Uniform shape of all tiles. Should have flat navigation axis
            (meaning tileshape.nav.dims == 1) and be contiguous in signal
            dimensions.

        dataset_shape
            Shape of the whole data set. Only the signal part is used.

        intent
            The intent of this scheme (whole partitions, frames or tiles)
            Needs to be set for correct per-partition tiling!

        debug
            Arbitrary debugging information, passed on to the constructor

        Returns
        -------

        TilingScheme
            A scheme whose slices are the sub-slices of the full signal
            extent of `dataset_shape`, cut according to `tileshape.sig`
        """
        # FIXME: validate navigation part of the tileshape to be contiguous
        # (i.e. a shape like (1, 1, ..., 1, X1, ..., XN))
        # where X1 is <= the dataset shape at that index, and X2, ..., XN are
        # equal to the dataset shape at that index

        # slice covering the whole signal extent, starting at the origin
        sig_slice = Slice(
            origin=(0,) * dataset_shape.sig.dims,
            shape=dataset_shape.sig
        )
        subslices = list(sig_slice.subslices(tileshape.sig))
        return cls(
            slices=subslices,
            tileshape=tileshape,
            dataset_shape=dataset_shape,
            debug=debug,
            intent=intent,
        )

    def __getitem__(self, idx: int) -> "Slice":
        return self._slices[idx]

    def __len__(self) -> int:
        return len(self._slices)

    def __repr__(self):
        # border tiles can be smaller, so only list each distinct shape once
        unique_shapes = list({tuple(slice_.shape) for slice_ in self._slices})
        return "<TilingScheme (depth=%d) shapes=%r len=%d>" % (
            self.depth, unique_shapes, len(self._slices),
        )

    @property
    def intent(self) -> "Optional[TilingIntent]":
        """
        the intent passed at construction time, or None
        """
        return self._intent

    @property
    def slices(self):
        """
        signal-only slices for all possible positions,
        as a list of `(index, slice)` tuples
        """
        return list(enumerate(self._slices))

    @property
    def slices_array(self):
        """
        Returns the slices from the schema as a numpy ndarray
        `a` of shape `(n, 2, sig_dims)` with:

        `a[i, 0]` are origins for slice `i`
        `a[i, 1]` are shapes for slice `i`
        """
        sig_dims = self._tileshape.sig.dims
        slices = np.zeros((len(self), 2, sig_dims), dtype=np.int64)
        for idx, slice_ in enumerate(self._slices):
            slices[idx] = (tuple(slice_.origin), tuple(slice_.shape))
        return slices

    @property
    def shape(self):
        """
        tileshape. note that some border tiles can be smaller!
        """
        return self._tileshape

    @property
    def dataset_shape(self):
        """
        the data set shape passed at construction time
        """
        return self._dataset_shape

    @property
    def depth(self):
        """
        extent of the tile shape along its (flat) navigation axis
        """
        return self._tileshape.nav[0]
class Negotiator:
    """
    Tile shape negotiator. The main functionality is in `get_scheme`,
    which, given a `udf`, `dataset` and `read_dtype` will generate
    a `TilingScheme` that is compatible with both the `UDF` and the
    `DataSet`, possibly even optimal.
    """

    def validate(
        self,
        shape: tuple[int, ...],
        ds_sig_shape: tuple[int, ...],
        size: int,
        io_max_size: int,
        itemsize: int,
        base_shape: tuple[int, ...],
        corrections: "Optional[CorrectionSet]",
    ):
        """
        Check a negotiated tile shape against the size and base shape constraints.

        Parameters
        ----------

        shape
            The tile shape to check, with a single leading navigation dimension

        ds_sig_shape
            Signal shape of the data set

        size
            Negotiated tile size in bytes

        io_max_size
            Maximum I/O size in bytes; the larger of `size` and `io_max_size`
            is used as the limit

        itemsize
            Size of a single item in bytes

        base_shape
            Base shape, including the navigation dimension, that `shape`
            should be a multiple of

        corrections
            Correction set that was considered in the negotiation, or None

        Raises
        ------

        ValueError
            If the signal part of `shape` exceeds `ds_sig_shape`, if the
            shape exceeds the size limit (see below for the exception),
            or if it is not compatible with `base_shape`
        """
        sig_shape = shape[1:]
        if any(s > ps for s, ps in zip(sig_shape, ds_sig_shape)):
            raise ValueError("generated tileshape does not fit the partition")

        # we need some wiggle room with the size, because there may be a harder
        # lower size value for some cases (for example HDF5, which overrides
        # some of the sizing negotiation we are doing here)
        size_px = max(size, io_max_size) // itemsize
        shape_px = prod(shape)
        if shape_px > size_px:
            message = "shape %r (%d) does not fit into size %d" % (
                shape, shape_px, size_px
            )
            # The shape might be exceeded if dead pixel correction didn't find a
            # valid tiling scheme. In that case it falls back to by-frame processing.
            if (
                corrections is not None
                and corrections.get_excluded_pixels() is not None
                and shape[0] == 1
            ):
                warnings.warn(message)
            else:
                raise ValueError(message)

        for dim, base in enumerate(base_shape):
            if shape[dim] % base == 0:
                continue
            # Tile shape always has one nav dim
            # Allow only base shape for nav
            # Allow full frames for sig
            if dim == 0 or shape[dim] != ds_sig_shape[dim - 1]:
                raise ValueError(
                    f"The tileshape {shape} is incompatible with base "
                    f"shape {base_shape} and dataset shape {ds_sig_shape} in dimension {dim}."
                )

    def get_scheme(
        self,
        udfs: "Sequence[UDFProtocol]",
        dataset,
        read_dtype: "nt.DTypeLike",
        approx_partition_shape: "Shape",
        roi: Optional[np.ndarray] = None,
        corrections: "Optional[CorrectionSet]" = None,
    ) -> "TilingScheme":
        """
        Generate a :class:`TilingScheme` instance that is
        compatible with both the given `udf` and the
        :class:`~libertem.io.dataset.base.DataSet`.

        Parameters
        ----------

        udfs : Sequence[UDFProtocol]
            The concrete UDFs to optimize the tiling scheme for.
            Depending on the method (tile, frame, partition)
            and preferred total input size and depth.

        dataset : DataSet
            The DataSet instance we generate the scheme for.

        read_dtype
            The dtype in which the data will be fed into the UDF

        approx_partition_shape
            The approximate partition shape that is likely to be used

        roi : np.ndarray
            Region of interest

        corrections : CorrectionSet
            Correction set to consider in negotiation

        Raises
        ------

        ValueError
            If `udfs` is empty, if the data set changes the signal part of
            the tile shape to something other than full frames while
            excluded pixels are set in `corrections`, or if the final
            tile shape doesn't pass :meth:`validate`

        UDFException
            If a UDF declares an invalid processing method
        """
        # Determine the intent first: this fails early and with a meaningful
        # message for an empty `udfs` or an invalid method, instead of
        # stumbling over max() of an empty list below
        intent = self._get_intent(udfs)

        itemsize = np.dtype(read_dtype).itemsize

        # FIXME: let the UDF define upper bound for signal size (lower bound, too?)
        # (signal buffers should fit into the L2 cache)
        # try not to waste page faults:
        min_sig_size = dataset.get_min_sig_size()
        ds_sig_shape = dataset.shape.sig

        # This already takes corrections into account through a different pathway
        need_decode = dataset.need_decode(roi=roi, read_dtype=read_dtype, corrections=corrections)
        io_max_size = self._get_io_max_size(dataset, approx_partition_shape, itemsize, need_decode)

        depths = [
            self._get_min_depth(udf, approx_partition_shape)
            for udf in udfs
        ]
        depth = max(depths)  # take the largest min-depth
        base_shape = self._get_base_shape(udfs, dataset, approx_partition_shape, roi)

        sizes = [
            self._get_size(
                io_max_size, udf, itemsize, approx_partition_shape, base_shape,
            )
            for udf in udfs
        ]
        if intent == "partition":
            size = max(sizes)  # by partition wants to be big, ...
        else:
            size = min(sizes)
        size_px = size // itemsize

        if corrections is not None and corrections.have_corrections():
            # The correction has to make sure that there are no excluded pixels
            # at tile boundaries
            base_shape = corrections.adjust_tileshape(
                tile_shape=base_shape,
                sig_shape=tuple(ds_sig_shape),
                base_shape=base_shape,
            )

        # first, scale `base_shape` up to contain at least `min_sig_size` items:
        min_factors = self._get_scale_factors(
            base_shape,
            containing_shape=ds_sig_shape,
            size=min_sig_size,
        )
        min_base_shape = self._scale_base_shape(base_shape, min_factors)

        # considering the min size, calculate the max depth:
        max_depth = max(1, size_px // prod(min_base_shape))
        if depth > max_depth:
            depth = max_depth

        full_base_shape = (1,) + tuple(base_shape)
        min_factors = (depth,) + tuple(min_factors)
        containing_shape = approx_partition_shape
        factors = self._get_scale_factors(
            full_base_shape,
            containing_shape=containing_shape,
            size=size_px,
            min_factors=min_factors,
        )
        tileshape = self._scale_base_shape(full_base_shape, factors)
        tileshape_orig = tileshape

        # the dataset has a "veto" on the tileshape:
        # FIXME: this veto may break if the base shape was adjusted
        # above, and we need to be careful not to break corrections with this,
        # and also fulfill requests of per-frame reading
        log.debug("tileshape before adjustment: %r", (tileshape,))
        tileshape = tuple(dataset.adjust_tileshape(tileshape, roi=roi))
        log.debug("tileshape after adjustment: %r", (tileshape,))

        # if the veto generated a tileshape that is smaller than the full base shape,
        # we need to re-adjust the full_base_shape
        if tileshape_orig != tileshape:  # make sure we don't change too eagerly:
            if tileshape[0] < full_base_shape[0]:
                full_base_shape = (tileshape[0], *full_base_shape[1:])
            has_pixel_corr = (
                corrections is not None
                and corrections.get_excluded_pixels() is not None
            )
            # In signal dimensions, the data set may only keep our extent
            # or switch to the full frame extent
            for (orig, new, sig) in zip(tileshape_orig[1:], tileshape[1:], ds_sig_shape):
                if new != orig and new != sig:
                    # Otherwise we may generate incorrect correction results
                    err_str = (
                        "dataset.adjust_tileshape() can only accept tile sig shape or switch to "
                        "full frames if dead pixel patching is active. "
                        f"Got original tile shape {tileshape_orig}, new tileshape {tileshape} "
                        f"and dataset sig shape {ds_sig_shape}"
                    )
                    if has_pixel_corr:
                        raise ValueError(err_str)
                    else:
                        warnings.warn(err_str)

        self.validate(
            tileshape, ds_sig_shape, size, io_max_size, itemsize, full_base_shape, corrections,
        )
        return TilingScheme.make_for_shape(
            tileshape=Shape(tileshape, sig_dims=ds_sig_shape.dims),
            dataset_shape=dataset.shape,
            intent=intent,
            debug={
                "min_factors": min_factors,
                "factors": factors,
                "tileshape": tileshape,
                "size": size,
                "size_px": size_px,
                "full_base_shape": full_base_shape,
                "need_decode": need_decode,
                "depth": depth,
            }
        )

    def _get_io_max_size(self, dataset, approx_partition_shape, itemsize, need_decode):
        """
        Maximum I/O size in bytes: if decoding is needed, what the data set
        reports via `get_max_io_size()` (falling back to 2**20 if that is None),
        otherwise the size of a whole partition of `approx_partition_shape`.
        """
        if need_decode:
            io_max_size = dataset.get_max_io_size()
            if io_max_size is None:
                io_max_size = 2**20
        else:
            io_max_size = itemsize * prod(approx_partition_shape)
        return io_max_size

    def _get_scale_factors(self, shape, containing_shape, size, min_factors=None):
        """
        Generate scaling factors to scale `shape` up to `size` elements,
        while being constrained to `containing_shape`.

        The factors start at `min_factors` (or all ones) and are increased
        one dimension after the other, in index order. Returns them as a list.
        """
        log.debug(
            "_get_scale_factors in: shape=%r, containing_shape=%r, size=%r, min_factors=%r",
            shape, containing_shape, size, min_factors
        )
        assert len(shape) == len(containing_shape)
        if min_factors is None:
            factors = [1] * len(shape)
        else:
            factors = list(min_factors)
        # how often `shape` fits into `containing_shape`, per dimension
        max_factors = tuple(
            cs // s
            for s, cs in zip(shape, containing_shape)
        )
        prelim_shape = self._scale_base_shape(shape, factors)
        # `rest` is how much we can still grow before exceeding `size`
        rest = size / prod(prelim_shape)
        if rest < 1:
            rest = 1
        for idx in range(len(shape)):
            max_factor = max_factors[idx]
            factor = int(math.floor(rest * factors[idx]))
            # clamp: never below the current factor, ...
            if factor < factors[idx]:
                factor = factors[idx]
            # ... and the upper limit wins over the lower limit
            if factor > max_factor:
                factor = max_factor
            factors[idx] = factor
            prelim_shape = self._scale_base_shape(shape, factors)
            rest = max(1, math.floor(size / prod(prelim_shape)))
        log.debug(
            "_get_scale_factors out: %r",
            factors,
        )
        return factors

    def _scale_base_shape(self, base_shape, factors):
        """
        Multiply `base_shape` element-wise with `factors` and return a tuple.
        """
        assert len(factors) == len(base_shape)
        return tuple(
            f * bs
            for f, bs in zip(factors, base_shape)
        )

    def _get_default_size(self):
        """
        Default tile size in bytes
        """
        # FIXME: adjust size to L3 // number of workers per node
        return 1*2**20

    def _get_udf_size_pref(self, udf: "UDFProtocol"):
        """
        The "total_size" tiling preference of `udf`: infinity if not set,
        the default size if set to `UDFProtocol.TILE_SIZE_BEST_FIT`.
        """
        udf_prefs = udf.get_tiling_preferences()
        size = udf_prefs.get("total_size", np.inf)
        if size is UDFProtocol.TILE_SIZE_BEST_FIT:
            size = self._get_default_size()
        return size

    def _get_intent(self, udfs: "Sequence[UDFProtocol]") -> "TilingIntent":
        """
        Determine the common intent for `udfs`: "partition" takes
        precedence over "frame", which takes precedence over "tile".

        Raises
        ------

        UDFException
            If a UDF returns a method that is not a member of `UDFMethod`

        ValueError
            If no method was recognized, for example if `udfs` is empty
        """
        valid_methods = tuple(UDFMethod)
        udf_methods = tuple(udf.get_method() for udf in udfs)
        if any(m not in valid_methods for m in udf_methods):
            raise UDFException('A UDF declared an invalid processing method')
        if UDFMethod.PARTITION in udf_methods:
            return "partition"
        elif UDFMethod.FRAME in udf_methods:
            return "frame"
        elif UDFMethod.TILE in udf_methods:
            return "tile"
        else:
            raise ValueError('No recognized UDF method, empty udfs arg?')

    def _get_size(
            self, io_max_size, udf: "UDFProtocol", itemsize,
            approx_partition_shape: "Shape", base_shape):
        """
        Calculate the maximum tile size in bytes
        """
        udf_method = udf.get_method()
        partition_size = itemsize * prod(tuple(approx_partition_shape))
        partition_size_sig = itemsize * prod(tuple(approx_partition_shape.sig))
        if udf_method == UDFMethod.FRAME:
            # at least one whole frame has to fit
            size = max(self._get_default_size(), partition_size_sig)
        elif udf_method == UDFMethod.PARTITION:
            size = partition_size
        elif udf_method == UDFMethod.TILE:
            # start with the UDF size preference:
            size = self._get_udf_size_pref(udf)

            # constrain to maximum read size
            size = min(size, io_max_size)

            # if the base_shape is larger than the current maximum size,
            # we need to increase the size:
            base_size = itemsize * prod(base_shape)
            size = max(base_size, size)
        else:  # pragma: no cover
            # Should never be reached, this is checked earlier in UDFRunner
            raise UDFException(f'UDF.get_method() returned unrecognized method: {udf_method}')
        return size

    def _get_base_shape(
        self,
        udfs: "Sequence[UDFProtocol]",
        dataset: "DataSet",
        approx_partition_shape: "Shape",
        roi: Optional[np.ndarray],
    ):
        """
        Signal base shape for the negotiation: the full signal shape if any
        UDF processes by frame or by partition, otherwise the signal part of
        the base shape reported by the data set.
        """
        methods = [
            udf.get_method()
            for udf in udfs
        ]
        if any(m in (UDFMethod.FRAME, UDFMethod.PARTITION) for m in methods):
            base_shape = approx_partition_shape.sig
        else:
            # only by tile:
            base_shape = Shape(
                dataset.get_base_shape(roi=roi),
                sig_dims=approx_partition_shape.sig.dims
            ).sig
        return base_shape

    def _get_udf_depth_pref(self, udf: "UDFProtocol", approx_partition_shape: "Shape") -> int:
        """
        The "depth" tiling preference of `udf`, with 32 as the default,
        limited to the first entry of `approx_partition_shape`.
        """
        udf_prefs = udf.get_tiling_preferences()
        depth = udf_prefs.get("depth", UDFProtocol.TILE_DEPTH_DEFAULT)
        if depth is UDFProtocol.TILE_DEPTH_DEFAULT:
            depth = 32
        if depth > approx_partition_shape[0]:
            depth = approx_partition_shape[0]
        return depth

    def _get_min_depth(self, udf: "UDFProtocol", approx_partition_shape: "Shape") -> int:
        """
        Minimum tile depth for `udf`: the first entry of
        `approx_partition_shape` for per-partition processing, the UDF depth
        preference for per-tile processing, and one otherwise.
        """
        udf_method = udf.get_method()
        if udf_method == UDFMethod.PARTITION:
            return approx_partition_shape[0]
        elif udf_method == UDFMethod.TILE:
            return self._get_udf_depth_pref(udf, approx_partition_shape)
        return 1