from functools import partial
from typing import Any, Optional
from collections.abc import Iterable
import contextlib
from collections import defaultdict, OrderedDict
import numpy as np
import dask
from dask import delayed
import dask.array as da
from libertem.io.corrections import CorrectionSet
from libertem.io.dataset.base import DataSet
from libertem.utils.devices import detect
from .base import BaseJobExecutor
from libertem.common.executor import Environment, TaskCommHandler, TaskProtocol
from libertem.common.scheduler import Worker, WorkerSet
from ..common.buffers import BufferWrapper
from ..common.math import prod
from ..udf.base import (
UDFMergeAllMixin, UDFRunner, UDF, UDFData, MergeAttrMapping,
get_resources_for_backends, UDFResults, BackendSpec
)
from .utils.dask_buffer import DaskBufferWrapper, DaskPreallocBufferWrapper, DaskResultBufferWrapper
from .utils import delayed_unpack
class DelayedUDFRunner(UDFRunner):
    """
    :class:`~libertem.udf.base.UDFRunner` variant that builds partial results
    as :code:`dask.delayed` objects and merges them lazily into dask
    array-backed result buffers.
    """
def __init__(self, udfs: list[UDF], debug: bool = False, progress_reporter: Any = None):
self._part_results = defaultdict(dict)
        super().__init__(udfs, debug=debug, progress_reporter=progress_reporter)
@staticmethod
def _make_udf_result(udfs: Iterable[UDF], damage: BufferWrapper) -> "UDFResults":
udf_results = UDFRunner._make_udf_result(udfs, damage)
buffers = udf_results.buffers
damage = udf_results.damage
new_buffers = tuple(
{
k: DaskResultBufferWrapper.from_buffer_wrapper(v)
for k, v in bufs.items()
}
for bufs in buffers
)
return UDFResults(
buffers=new_buffers,
damage=damage,
)
def _apply_part_result(self, udfs: Iterable[UDF], damage, part_results, task):
for part_results_udf, udf in zip(part_results, udfs):
            # Allow the user to define an alternative merge strategy
            # using dask-compatible functions. In the Delayed case we
            # won't be getting partial results with damage anyway.
            # Currently there is no interface to provide all of the results
            # to udf.merge at once and in the correct order, so we accumulate
            # results in self._part_results[udf] = {partition_slice_roi: part_results, ...}
if (isinstance(udf, UDFMergeAllMixin)
or (type(udf).merge is UDF.merge and not udf.requires_custom_merge_all)):
if self._accumulate_part_results(udf, part_results_udf, task):
try:
parts = {
key: value.get_proxy()
for key, value in self._part_results[udf].items()
}
udf._do_merge_all(parts)
finally:
self._part_results.pop(udf, None)
continue
            # In principle we can require that sig merges are done sequentially
            # by using dask.graph_manipulation.bind, which could improve thread
            # safety on the sig buffer. In the current implementation it's not
            # necessary as we are calling delayed_apply_part_result sequentially
            # on results and replacing all buffers with their partially merged
            # versions on each call, so we are implicitly sequential when updating
# try:
# bind_to = udf.prior_merge
# except AttributeError:
# bind_to = None
structure = structure_from_task([udf], task)[0]
flat_structure = delayed_unpack.flatten_nested(structure)
flat_mapping = delayed_unpack.build_mapping(structure)
src = part_results_udf.get_proxy()
src_dict = {k: b for k, b in src._dict.items()}
dest_dict = {}
for k, b_decl in udf.get_result_buffers().items():
# We skip result_only buffers since they are not expected
# in the dest_dict by the default merge function
if b_decl.use == 'result_only':
continue
b = udf.results.get_buffer(k)
view = b.get_view_for_partition(task.partition)
# Handle result-only buffers
if view is not None:
try:
dest_dict[k] = view.unwrap_sliced()
except AttributeError:
# Handle kind='single' buffers
# Could handle sig buffers this way too if we assert no chunking
dest_dict[k] = view
            # Run the udf.merge function in a wrapper which flattens
            # the results to a flat list which can be unpacked from the Delayed.
            # This step makes a copy of the buffer view inside the delayed
            # call because the default Dask behaviour is to give a read-only
            # view of the arguments to the delayed function
merged_del = delayed(merge_wrap, nout=len(flat_mapping))
            # need to copy the damage, as it will be overwritten in the next loop iteration,
            # which is independent of actually evaluating the delayed object:
merged_del_result = merged_del(udf, dest_dict, src_dict, damage.raw_data.copy())
# The following is needed when binding sequential updates on sig buffers
# if bind_to is not None:
# merged_del = delayed_bind([*merged_del_result],
# [*bind_to])
# if any([b.kind != 'nav' for _, b in udf.results.items()]):
# # have sig result, gotta do a bind merge next time round
# udf.prior_merge = merged_del_result
# We now build the result back into dask arrays and assign them
# into the appropriate slices of the full result buffers
wrapped_res = delayed_to_buffer_wrappers(merged_del_result, flat_structure,
task.partition, as_buffer=False)
renested = delayed_unpack.rebuild_nested(wrapped_res, flat_mapping)
# Assign into the result buffers with the partially merged version
# This is OK because we are calling this part merge function sequentially
# so each new call gets the most recent version of the buffer in the
# dask task graph
udf.set_views_for_partition(task.partition)
dest = udf.results.get_proxy()
merged = MergeAttrMapping(renested)
for k in dest:
getattr(dest, k)[:] = getattr(merged, k)
v = damage.get_view_for_partition(task.partition)
v[:] = True
def _accumulate_part_results(self, udf, part_results, task):
"""
        If the udf has a merge_all method, this function is used
        to accumulate dask array-backed partial results on the executor
        so that merge_all can be called on all of them at once.
        Ensures that the results are correctly ordered and complete
        before allowing merge_all to be called.
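
        The accumulated structure is roughly (illustration only, names
        are placeholders)::

            self._part_results[udf] = {
                slice_for_partition_0: {'buf_name': partial_results_0, ...},
                slice_for_partition_1: {'buf_name': partial_results_1, ...},
                ...
            }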
"""
buf = next(iter(part_results.values())) # get the first buffer
slice_with_roi = buf._slice_for_partition(task.partition)
self._part_results[udf][slice_with_roi] = part_results
# number of frames in dataset
if udf.meta.roi is not None:
target_coverage = np.count_nonzero(udf.meta.roi)
else:
target_coverage = prod(task.partition.meta.shape.nav)
# number of frames we have results for
current_coverage = sum(prod(k.shape.nav) for k in self._part_results[udf].keys())
if target_coverage == current_coverage:
ordered_results = sorted(self._part_results[udf].items(),
key=lambda kv: kv[0].origin[0])
self._part_results[udf] = OrderedDict(ordered_results)
return True
elif current_coverage > target_coverage:
raise RuntimeError('More frames accumulated than ROI specifies - '
f'target {target_coverage} - processed {current_coverage}')
return False
def results_for_dataset_sync(self, dataset: DataSet, executor: 'DelayedJobExecutor',
roi: Optional[np.ndarray] = None, progress: bool = False,
corrections: Optional[CorrectionSet] = None, backends: Optional[BackendSpec] = None,
dry: bool = False) -> Iterable[tuple]:
executor.register_master_udfs(self._udfs)
return super().results_for_dataset_sync(
dataset, executor, roi=roi, progress=progress,
corrections=corrections, backends=backends, dry=dry
)
class DelayedJobExecutor(BaseJobExecutor):
"""
    :class:`~libertem.common.executor.JobExecutor` that uses dask.delayed to execute tasks.

    .. versionadded:: 0.9.0

    Highly experimental at this time!
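
    A minimal usage sketch, assuming a dataset ``ds`` and a UDF instance
    ``udf`` have already been created (the buffer name ``'intensity'`` is
    only an example)::

        from libertem.api import Context
        from libertem.executor.delayed import DelayedJobExecutor

        executor = DelayedJobExecutor()
        ctx = Context(executor=executor)
        res = ctx.run_udf(dataset=ds, udf=udf)
        # the result buffers are dask-backed; computation is deferred
        # until the data is actually needed
        intensity = res['intensity'].data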
"""
def __init__(self):
        # Only import if actually instantiated, i.e. will likely be used
        import libertem.preload  # noqa: F401
self._udfs = None
@contextlib.contextmanager
def scatter(self, obj):
yield delayed(obj)
def cancel(self, cancel_id: Any):
pass
def run_tasks(
self,
tasks: Iterable[TaskProtocol],
params_handle: Any,
cancel_id: Any,
task_comm_handler: TaskCommHandler,
):
"""
        Wraps the call to task() such that it returns a flat list
        of results, then unpacks the Delayed return value into the normal
        :code:`tuple(udf.results for udf in self._udfs)`,
        where the buffers inside udf.results are dask arrays instead
        of normal np.ndarrays.

        Needs a reference to the udfs on the master node so
        that the results structure can be inferred. This reference is
        found in self._udfs, which is set with the method
        :code:`executor.register_master_udfs(udfs)`,
        called from :meth:`DelayedUDFRunner.results_for_dataset_sync`.
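
        The flatten/rebuild helpers from :code:`delayed_unpack` are used
        roughly as follows (toy values instead of real UDF buffers,
        for illustration only)::

            structure = ({'buf': np.zeros((4, 4))},)
            flat = delayed_unpack.flatten_nested(structure)
            mapping = delayed_unpack.build_mapping(structure)
            # ... compute a replacement for each flat element ...
            rebuilt = delayed_unpack.rebuild_nested(flat, mapping)
            # rebuilt has the same nesting as structure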
"""
env = Environment(threads_per_worker=1, threaded_executor=True)
for task in tasks:
structure = structure_from_task(self._udfs, task)
flat_structure = delayed_unpack.flatten_nested(structure)
flat_mapping = delayed_unpack.build_mapping(structure)
flat_result_task = partial(task_wrap, task)
result = delayed(flat_result_task, nout=len(flat_structure))(env=env,
params=params_handle)
wrapped_res = delayed_to_buffer_wrappers(result, flat_structure, task.partition,
roi=self._udfs[0].meta.roi)
renested = delayed_unpack.rebuild_nested(wrapped_res, flat_mapping)
result = tuple(UDFData(data=res) for res in renested)
yield result, task
def run_function(self, fn, *args, **kwargs):
result = fn(*args, **kwargs)
return result
    def run_delayed(self, fn, *args, _delayed_kwargs=None, **kwargs):
        """
        Run fn(*args, **kwargs) wrapped in dask.delayed, passing
        _delayed_kwargs to the :code:`delayed` wrapper, and return
        the resulting Delayed object without computing it.
        """
if _delayed_kwargs is None:
_delayed_kwargs = {}
result = delayed(fn, **_delayed_kwargs)(*args, **kwargs)
return result
def map(self, fn, iterable):
return [fn(item)
for item in iterable]
def run_each_host(self, fn, *args, **kwargs):
return {"localhost": fn(*args, **kwargs)}
def run_each_worker(self, fn, *args, **kwargs):
return {"delayed": fn(*args, **kwargs)}
def get_available_workers(self):
resources = {"compute": 1, "CPU": 1}
# We don't know at this time,
# but assume one worker per CPU
devices = detect()
return WorkerSet([
Worker(
name='delayed', host='localhost',
resources=resources,
nthreads=len(devices['cpus']),
)
])
def modify_buffer_type(self, buf):
"""
        Convert existing buffers from BufferWrapper to DaskBufferWrapper.
        A refactoring of the UDF backend would remove the need for this method.

        :meta private:
"""
return DaskBufferWrapper.from_buffer(buf)
def register_master_udfs(self, udfs):
"""
        Give the executor a reference to the udfs instantiated
        on the main node, for introspection purposes.

        :meta private:
"""
self._udfs = udfs
def _compute(self, *args, udfs=None, user_backends=None, traverse=True, **kwargs):
"""
        Acts as dask.compute(*args, **kwargs) but with knowledge
        of LiberTEM data structures and compute resources.
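
        A rough sketch of the intended call pattern (names are
        illustrative; ``res`` is a result obtained via :code:`ctx.run_udf`
        with this executor and ``udf`` the corresponding UDF instance)::

            computed = executor._compute(res, udfs=[udf])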
"""
if 'resources' in kwargs:
if udfs is not None:
raise ValueError('Cannot specify both udfs for resources and resources to use')
resources = kwargs.get('resources')
elif udfs is not None:
            resources = self.get_resources_from_udfs(udfs, user_backends=user_backends)
else:
resources = None
kwargs['resources'] = resources
        unwrapped_args = tuple(self.unwrap_results(a) for a in args)
results = dask.compute(*unwrapped_args, traverse=traverse, **kwargs)
if len(args) == 1:
if len(results) > 1:
raise RuntimeWarning(f'Unexpected number of results {len(results)} '
'from dask.compute, dropping extras')
results = results[0]
return results
@staticmethod
def get_resources_from_udfs(udfs, user_backends=None):
"""
        Returns the resources required by the udfs passed as
        argument, excluding any backends not in the tuple user_backends.
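
        For example, to pass matching resource constraints to
        :code:`dask.compute` manually (a sketch; assumes ``res`` holds
        dask-backed results from :code:`ctx.run_udf` with this executor
        and ``udf`` is the corresponding UDF instance)::

            resources = DelayedJobExecutor.get_resources_from_udfs(udf)
            dask.compute(
                DelayedJobExecutor.unwrap_results(res),
                resources=resources,
            )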
"""
if user_backends is None:
user_backends = tuple()
if isinstance(udfs, UDF):
udfs = [udfs]
backends = [udf.get_backends() for udf in udfs]
return get_resources_for_backends(backends, user_backends)
    @staticmethod
    def unwrap_results(results):
        """
        Unwrap UDFData and Dask-backed buffer wrappers in results into
        their underlying dask arrays, keeping the surrounding nesting,
        so that dask.compute can traverse and compute them.
        """
unpackable = {**delayed_unpack.default_unpackable(),
UDFData: lambda x: x._data.items(),
DaskPreallocBufferWrapper: lambda x: [(0, x.data)],
DaskBufferWrapper: lambda x: [(0, x.data)],
}
res_unpack = delayed_unpack.flatten_nested(results, unpackable_types=unpackable)
flat_mapping = delayed_unpack.build_mapping(results, unpackable_types=unpackable)
flat_mapping_reduced = [el[:-1] if issubclass(el[-1][0], BufferWrapper) else el
for el in flat_mapping]
return delayed_unpack.rebuild_nested(res_unpack, flat_mapping_reduced)
def get_udf_runner(self) -> type['UDFRunner']:
return DelayedUDFRunner
def make_copy(array_dict):
    """
    Copy any read-only arrays in array_dict (in place) so that the
    merge function can write into them.

    :meta private:
    """
for k, v in array_dict.items():
if not v.flags['WRITEABLE']:
array_dict[k] = v.copy()
return array_dict
def merge_wrap(udf, dest_dict, src_dict, raw_damage):
"""
The function called as delayed, acting as a wrapper
to return a flat list of results rather than a structure
of UDFData or MergeAttrMapping
:meta private:
"""
# Have to make a copy of dest buffers because Dask brings
# data into the delayed function as read-only np arrays
# I experimented with setting WRITEABLE to True but this
# resulted in errors in the final array
dest_dict = make_copy(dest_dict)
dest = MergeAttrMapping(dest_dict)
src = MergeAttrMapping(src_dict)
# In place merge into the copy of dest
udf.meta.set_valid_nav_mask(raw_damage)
udf.merge(dest=dest, src=src)
# Return flat list of results so they can be unpacked later
return delayed_unpack.flatten_nested(dest._dict)
def task_wrap(task, *args, **kwargs):
"""
Flatten the structure tuple(udf.results for udf in self._udfs)
where udf.results is an instance of UDFData(data={'name':BufferWrapper,...})
into a simple list [np.ndarray, np.ndarray, ...]
:meta private:
"""
res = task(*args, **kwargs)
res = tuple(r._data for r in res)
flat_res = delayed_unpack.flatten_nested(res)
return [r._data for r in flat_res]
def structure_from_task(udfs, task):
"""
Based on the instantiated whole dataset UDFs and the task
information, build a description of the expected UDF results
for the task's partition like:
:code:`({'buffer_name': StructDescriptor(shape, dtype, extra_shape, buffer_kind), ...}, ...)`
:meta private:
"""
structure = []
for udf in udfs:
res_data = {}
for buffer_name, buffer in udf.results.items():
part_buf_extra_shape = buffer.extra_shape
buffer.set_shape_partition(task.partition, roi=buffer._roi)
part_buf_shape = buffer.shape
part_buf_dtype = buffer.dtype
res_data[buffer_name] = \
delayed_unpack.StructDescriptor(np.ndarray,
shape=part_buf_shape,
dtype=part_buf_dtype,
kind=buffer.kind,
extra_shape=part_buf_extra_shape)
results_container = res_data
structure.append(results_container)
return tuple(structure)
def delayed_to_buffer_wrappers(flat_delayed, flat_structure, partition, roi=None, as_buffer=True):
"""
Take the iterable Delayed results object, and re-wrap each Delayed object
back into a BufferWrapper wrapping a dask.array of the correct shape and dtype
:meta private:
"""
wrapped_res = []
for el, descriptor in zip(flat_delayed, flat_structure):
buffer_kind = descriptor.kwargs.pop('kind')
extra_shape = descriptor.kwargs.pop('extra_shape')
buffer_dask = da.from_delayed(el, *descriptor.args, **descriptor.kwargs)
if as_buffer:
buffer = DaskBufferWrapper(buffer_kind,
extra_shape=extra_shape,
dtype=descriptor.kwargs['dtype'])
# Need to test whether roi=None here is a problem
buffer.set_shape_partition(partition, roi=roi)
buffer.replace_array(buffer_dask)
wrapped_res.append(buffer)
else:
wrapped_res.append(buffer_dask)
return wrapped_res