import math
from typing import Any, Optional, overload
from collections.abc import Generator, Sequence
import numpy as np
from libertem.common.math import prod, count_nonzero
from libertem.common.shape import Shape, ShapeLike
class SliceUsageError(ValueError):
    """
    Raised when a Slice is incorrectly instantiated or used
    """
class Slice:
    """
    A n-dimensional slice, defined by origin and shape

    Parameters
    ----------
    origin : tuple of int
        global "top-left" coordinates of this slice

    shape : Shape instance
        the size of this slice
    """
    __slots__ = ["origin", "shape"]

    def __init__(self, origin: Sequence[int], shape: Shape):
        self.origin = tuple(origin)
        self.shape = shape
        # Check the type first: the dimensionality comparison below calls
        # ``len(self.shape)``, which would otherwise surface as a bare
        # TypeError (or a misleading mismatch message) for non-Shape input.
        if not isinstance(shape, Shape):
            raise SliceUsageError("please use libertem.common.Shape instance as shape parameter")
        if len(self.origin) != len(self.shape):
            raise SliceUsageError(
                ("cannot build slice with dimensionality of shape/origin mismatch (%d vs %d); "
                 "origin=%r, shape=%r") % (
                    len(self.origin), len(self.shape), self.origin, self.shape,
                )
            )

    def __repr__(self) -> str:
        return f"<Slice origin={self.origin!r} shape={self.shape!r}>"

    def __hash__(self) -> int:
        # enables using a Slice as a key in dict, an item in sets etc.
        # in this case important for use as cache key for our mask container
        return hash((self.origin, tuple(self.shape)))

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Slice) and (
            self.shape == other.shape and self.origin == other.origin
        )

    @classmethod
    def from_shape(cls, shape: Sequence[int], sig_dims: int) -> "Slice":
        """
        Construct a `Slice` at zero-origin from `shape` and `sig_dims`.
        """
        return Slice(
            origin=(0,) * len(shape),
            shape=Shape(shape, sig_dims=sig_dims),
        )

    def intersection_with(self, other: "Slice") -> "Slice":
        """
        Calculate the intersection between this slice and `other`. May result in
        dimensions that are zero, which means that there is no intersection.

        Returns
        -------
        slice : Slice
            the intersection between this and the other slice

        Raises
        ------
        SliceUsageError
            if the slices differ in total or in signal dimensionality
        """
        if len(self.origin) != len(other.origin):
            raise SliceUsageError(
                ("cannot intersect slices with different dimensionality (%s vs %s); "
                 "self.origin=%r, other.origin=%r") % (
                    len(self.origin), len(other.origin), self.origin, other.origin,
                )
            )
        if self.shape.sig.dims != other.shape.sig.dims:
            raise SliceUsageError(
                "cannot intersect slices with different signal dimensionality ({} vs {})".format(
                    self.shape.sig.dims, other.shape.sig.dims
                )
            )
        # the intersection starts at the larger of the two origins...
        new_origin = tuple(
            max(o1, o2)
            for (o1, o2) in zip(self.origin, other.origin)
        )
        # ...and ends at the smaller of the two end coordinates
        new_shape = [
            min(
                (o1 + s1) - no,
                (o2 + s2) - no,
            )
            for (o1, o2, no, s1, s2) in zip(
                self.origin, other.origin, new_origin, self.shape, other.shape
            )
        ]
        # non-overlapping dimensions would come out negative; clamp to zero
        new_shape = [max(0, s) for s in new_shape]
        result = Slice(
            origin=new_origin,
            shape=Shape(new_shape, sig_dims=self.shape.sig.dims),
        )
        return result

    def is_null(self) -> bool:
        """
        If any part of our shape is zero, this slice doesn't span any data and is null / empty.
        """
        return any(s == 0 for s in self.shape)

    def shift(self, other: "Slice") -> "Slice":
        """
        make a new ``Slice`` with origin relative to ``other.origin``
        and the same shape as this ``Slice``

        useful for translating to the local coordinate system of ``other``
        """
        if len(self.origin) != len(other.origin):
            raise SliceUsageError(
                "cannot shift slices with different "
                f"dimensionality ({self.origin} vs {other.origin})"
            )
        return Slice(origin=tuple(our_coord - their_coord
                                  for (our_coord, their_coord) in zip(self.origin, other.origin)),
                     shape=self.shape)

    def shift_by(self, offset: Sequence[int]) -> "Slice":
        """
        Return a new slice with the origin moved by the supplied offset
        and the same shape
        """
        if len(self.origin) != len(offset):
            raise SliceUsageError(
                "cannot shift slices with different "
                f"dimensionality ({self.origin} vs {offset})"
            )
        return Slice(
            origin=tuple(
                our_coord + off
                for (our_coord, off)
                in zip(self.origin, offset)
            ),
            shape=self.shape,
        )

    @overload
    def get(
        self,
        arr: None = None,
        sig_only: bool = False,
        nav_only: bool = False
    ) -> tuple[slice, ...]: ...

    @overload
    def get(
        self,
        arr: np.ndarray,
        sig_only: bool = False,
        nav_only: bool = False
    ) -> np.ndarray: ...

    def get(
        self,
        arr: Optional[np.ndarray] = None,
        sig_only: bool = False,
        nav_only: bool = False
    ):
        """
        Get a standard python tuple-of-slice-object which can be used
        to slice any compatible numpy.ndarray

        Parameters
        ----------
        arr
            something implementing the slice interface. if given, returns arr[slice]

        sig_only : bool
            get a signal-only slice for frames/masks

        nav_only : bool
            get a nav-only slice, for example for indexing something that is shaped like
            the navigation dimensions of this Slice. Ignored if ``sig_only`` is also set.

        Returns
        -------
        tuple of slice objects
            returns standard python slices computed from
            our origin+shape model or arr indexed with this slicing
            if arr is given

        Examples
        --------
        >>> import numpy as np
        >>> from libertem.common import Slice, Shape
        >>> s = Slice(shape=Shape((16, 16, 4, 4), sig_dims=2), origin=(0, 0, 12, 12))
        >>> data = np.ones((16, 16))
        >>> data[s.get(sig_only=True)]
        array([[1., 1., 1., 1.],
               [1., 1., 1., 1.],
               [1., 1., 1., 1.],
               [1., 1., 1., 1.]])
        """
        if sig_only:
            o, s = self.origin, self.shape
            slice_ = tuple(
                slice(o[i], (o[i] + s[i]))
                for i in range(s.nav.dims, s.sig.dims + s.nav.dims)
            )
        elif nav_only:
            o, s = self.origin, self.shape
            slice_ = tuple(
                slice(o[i], (o[i] + s[i]))
                for i in range(s.nav.dims)
            )
        else:
            slice_ = self._get()
        if arr is not None:
            if sig_only:
                # Skip the supposed nav dimensions of the data
                return arr[(Ellipsis, ) + slice_]
            else:
                # for nav_only, we return the full remaining dimensions anyway
                # if arr has more dimensions than the slice
                return arr[slice_]
        else:
            return slice_

    def _get(self):
        """
        Direct conversion from Slice to tuple(slice, ...) without options
        """
        return tuple(slice(o, (o + s)) for (o, s) in zip(self.origin, self.shape))

    def discard_nav(self) -> "Slice":
        """
        returns a copy with the origin/shape zeroed in the nav dimensions

        this is used to create uniform cache keys
        """
        o, s, sig_dims = self._discard_nav_key()
        new_shape = Shape(s, sig_dims=sig_dims)
        return Slice(origin=o, shape=new_shape)

    def _discard_nav_key(self) -> tuple[tuple[int, ...], tuple[int, ...], int]:
        """
        Construct a hashable ``(origin, shape, sig_dims)`` tuple of this Slice
        with origin and shape set to zero in the nav dimensions.

        Functions as discard_nav but avoids Shape and Slice constructor overheads
        """
        o, s = self.origin, self.shape
        nav_dims = s.nav_dims
        zero_nav = (0,) * nav_dims
        return (zero_nav + o[nav_dims:], zero_nav + s._sig_shape, s.sig_dims)

    def subslices(self, shape: ShapeLike) -> Generator["Slice", None, None]:
        """
        Generator for all subslices of this slice with dimensions
        specified by ``shape``.

        Parameters
        ----------
        shape : tuple of int or Shape
            the shape of each sub-slice

        Yields
        ------
        Slice
            all subslices, in fast-access order

        Raises
        ------
        SliceUsageError
            if ``shape`` has a different dimensionality than this slice
        """
        # example: self.shape=(3, 1, 1, 1), subslice shape=(2, 1, 1, 1)
        # math.ceil(3/2) = math.ceil(1.5) = 2 -> we need two subslices across the y dimension
        shape = Shape(shape, sig_dims=self.shape.sig.dims)
        if self.shape.dims != shape.dims:
            raise SliceUsageError(
                ("cannot create subslices with different dimensionality (%d vs %d); "
                 "self.shape=%r, shape=%r") % (
                    self.shape.dims, shape.dims, self.shape, shape,
                )
            )
        ni = tuple(math.ceil(s1 / s)
                   for (s1, s) in zip(self.shape, shape))

        def _make_slice(origin: tuple[int, ...], new_shape: Shape) -> Slice:
            sig_dims = new_shape.sig.dims
            # this makes sure that the border tiles have the correct shape set
            new_shape_tuple = tuple(
                min(ns, so + s - o)
                for (ns, so, s, o) in zip(new_shape, self.origin, self.shape, origin)
            )
            new_shape = Shape(new_shape_tuple, sig_dims=sig_dims)
            for x in new_shape_tuple:
                assert x > 0, \
                    "invalid shape: {!r} while subslicing {!r} with {!r} (origin={!r})".format(
                        new_shape, self.shape, shape, origin
                    )
            return Slice(
                origin=origin,
                shape=new_shape,
            )

        return (
            _make_slice(origin=tuple(
                o + i * s
                for (o, i, s) in zip(self.origin, indexes, shape)
            ), new_shape=Shape(tuple(shape), sig_dims=self.shape.sig.dims))
            for indexes in np.ndindex(ni)
        )

    @property
    def nav(self) -> "Slice":
        """
        Returns a new Slice, with sig_dims=0, limited to the nav part
        """
        return Slice(
            origin=self.origin[:self.shape.nav.dims],
            shape=self.shape.nav,
        )

    @property
    def sig(self) -> "Slice":
        """
        Returns a new Slice, limited to the sig part
        """
        return Slice(
            origin=self.origin[self.shape.nav.dims:],
            shape=self.shape.sig,
        )

    def flatten_nav(self, containing_shape: ShapeLike) -> "Slice":
        """
        Return a copy of this slice with its nav dimensions flattened into one.

        The nav origin is converted to a linear index into the nav part of
        ``containing_shape`` (via ``np.ravel_multi_index``), and the nav shape
        is replaced by the product of its entries. The sig part of origin and
        shape is kept as-is.

        Parameters
        ----------
        containing_shape : tuple of int or Shape
            shape that this slice lives in; only its first ``nav_dims`` entries
            (the nav part) are used

        Returns
        -------
        Slice
            a slice with a single nav dimension

        Raises
        ------
        AssertionError
            if the nav shape/origin cannot be expressed as a single
            origin + length range of the flattened nav dimension
            (see the preconditions below)
        """
        sig_dims = self.shape.sig.dims
        nav_dims = self.shape.dims - sig_dims
        containing_shape = tuple(containing_shape)[:nav_dims]
        origin = self.origin[:nav_dims]

        # validation for the nav_shape:
        # what are the preconditions that allow flattening?
        #
        # - nav part of the shape: must be in the form of:
        #
        #   (1, 1, ..., N, M, M, ...)
        #
        #   where N<=M and M is the corresponding part of
        #   the shape of the dataset.
        #
        # - the origin must match the shape in the following way:
        #
        #   (o1, o2, ..., oi, 0, 0, ...)
        #
        #   where all oj are arbitrary (but in bounds)
        #
        # state 0: still in the leading run of size-1 dimensions
        # state 1: past the first non-1 dimension; the rest must be full and start at 0
        state = 0
        for cs, s, o in zip(containing_shape, self.shape.nav, origin):
            if state == 0:
                if s != 1:
                    state = 1
                    assert s <= cs, "invalid nav_shape #1"
            elif state == 1:
                assert s == cs, "invalid nav_shape #2"
                assert o == 0, "invalid origin"
        # convert the numpy integer returned by ravel_multi_index to a plain
        # Python int so the resulting origin is a tuple of builtin ints
        nav_origin = int(np.ravel_multi_index(
            origin,
            containing_shape
        ))
        nav_shape = prod(self.shape.nav)
        return Slice(
            origin=(nav_origin,) + self.origin[nav_dims:],
            shape=Shape((nav_shape,) + tuple(self.shape.sig), sig_dims=sig_dims)
        )

    def adjust_for_roi(self, roi: Optional[np.ndarray]) -> "Slice":
        """
        Make a new slice that has origin and shape modified according to `roi`.

        This slice must have exactly one (flattened) nav dimension. The new nav
        origin is the number of selected (non-zero) `roi` entries before this
        slice, and the new nav shape is the number of selected entries inside
        it. The sig part is unchanged.

        Parameters
        ----------
        roi : numpy.ndarray or None
            region of interest; flattened before use. If None, `self` is
            returned unchanged.

        Returns
        -------
        Slice
            the slice in the coordinate system of the ROI-compressed nav axis
        """
        if roi is None:
            return self
        roi = roi.reshape(-1)
        assert self.shape.nav.dims == 1
        s_o = self.origin[0]
        s_s = self.shape[0]
        # We need to find how many 1s there are for all previous partitions, to know
        # the origin; then we count how many 1s there are in our partition
        # to find our shape.
        origin = count_nonzero(roi[:s_o])
        shape = count_nonzero(roi[s_o:s_o + s_s])
        sig_dims = self.shape.sig.dims
        return Slice(
            # Drop the single nav entry explicitly instead of using
            # ``self.origin[-sig_dims:]``: for ``sig_dims == 0`` that is
            # ``origin[-0:]``, i.e. the *whole* tuple, which would yield an
            # origin one entry longer than the shape.
            origin=(origin,) + self.origin[1:],
            shape=Shape((shape,) + tuple(self.shape.sig), sig_dims=sig_dims),
        )

    def clip_to(self, shape: Shape) -> "Slice":
        """
        Clip this slice to the bounds of `shape`.

        This is the intersection with a zero-origin slice spanning `shape`;
        parts outside of `shape` are cut off, and dimensions without any
        overlap end up with size zero.

        Parameters
        ----------
        shape : Shape
            the bounds to clip to; must have the same dimensionality and
            signal dimensionality as this slice

        Returns
        -------
        Slice
            the clipped slice
        """
        other_slice = Slice((0,) * shape.dims, shape)
        return self.intersection_with(other_slice)

    # Explicit pickling support: instances use ``__slots__`` and therefore
    # carry no per-instance ``__dict__``; state is the slot values by name.
    def __getstate__(self) -> dict[str, Any]:
        return {
            k: getattr(self, k)
            for k in self.__slots__
        }

    def __setstate__(self, state: dict[str, Any]) -> None:
        for k, v in state.items():
            setattr(self, k, v)