import os
import contextlib
from copy import deepcopy
import functools
import logging
import signal
from typing import Any, Optional, Union, Callable
from collections.abc import Iterable
import uuid
from dask import distributed as dd
import dask
from libertem.common.threading import set_num_threads_env
from .base import BaseJobExecutor, AsyncAdapter, ResourceError
from libertem.common.executor import (
JobCancelledError, TaskCommHandler, TaskProtocol, Environment, WorkerContext,
)
from libertem.common.async_utils import sync_to_async
from libertem.common.scheduler import Worker, WorkerSet
from libertem.common.backend import set_use_cpu, set_use_cuda
from libertem.common.async_utils import adjust_event_loop_policy
from .utils import assign_cudas
log = logging.getLogger(__name__)
class DaskWorkerContext(WorkerContext):
def __init__(self, comms_topic: Optional[str]):
# DaskWorkerContext sends all messages via a single unique topic id
# which are later unpacked on the client node; this allows us to handle
# concurrent runs using the same executor with separate comms channels
self._comms_topic = comms_topic
@property
def dask_worker(self):
try:
return self._worker
except AttributeError:
self._worker = dd.get_worker()
return self._worker
def signal(self, ident: str, topic: str, msg_dict: dict[str, Any]):
if self._comms_topic is None:
# The Dask scheduler does not have comms, so don't send
return
msg_dict.update({'ident': ident, 'topic': topic})
try:
self.dask_worker.log_event(self._comms_topic, msg_dict)
except AttributeError:
# No structured logs available in this version of Dask.
# Catch the exception here just in case there is a
# version / API mismatch
pass
@contextlib.contextmanager
def set_worker_log_level(level: Union[str, int], force: bool = False):
"""
Set the dask.distributed log level for any processes spawned
within the context manager. If force is False, don't overwrite
any existing environment variable.
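Examples
--------
Illustrative sketch; the chosen level is just an example value:
>>> import logging
>>> with set_worker_log_level(logging.ERROR):  # doctest: +SKIP
...     ...  # spawn dask.distributed worker processes here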
"""
env_keys = ['DASK_DISTRIBUTED__LOGGING__DISTRIBUTED']
try:
old_env = {
k: os.environ[k]
for k in env_keys
if k in os.environ
}
os.environ.update({
k: str(level)
for k in env_keys
if force or (k not in old_env)
})
yield
finally:
os.environ.update(old_env)
for key in env_keys:
if key not in old_env:
del os.environ[key]
def worker_setup(resource, device):
# Disable handling Ctrl-C on the workers for a local cluster
# since the nanny restarts workers in that case and that gets mixed
# with Ctrl-C handling of the main process, at least on Windows
signal.signal(signal.SIGINT, signal.SIG_IGN)
if resource == "CUDA":
set_use_cuda(device)
elif resource == "CPU":
set_use_cpu(device)
else:
raise ValueError(f"Unknown resource {resource}, use 'CUDA' or 'CPU'")
def cluster_spec(
cpus: Union[int, Iterable[int]],
cudas: Union[int, Iterable[int]],
has_cupy: bool,
name: str = 'default',
num_service: int = 1,
options: Optional[dict] = None,
preload: Optional[tuple[str, ...]] = None
):
'''
Create a worker specification dictionary for a LiberTEM Dask cluster
The return from this function can be passed to :code:`DaskJobExecutor.make_local(spec=spec)`.
This creates a Dask cluster spec with special initializations and resource tags
for CPU + GPU processing in LiberTEM.
See :ref:`cluster spec` for an example.
See https://distributed.dask.org/en/stable/api.html#distributed.SpecCluster
for more info on cluster specs.
Parameters
----------
cpus: int | Iterable[int]
IDs for CPU workers as an iterable, or an integer number of workers to create.
Currently no pinning is used, i.e. this specifies the total
number and identification of workers, not the CPU cores that are used.
cudas: int | Iterable[int]
IDs for CUDA device workers as an iterable, or an integer number of GPU workers to
create. LiberTEM will use the IDs specified or assign round-robin to the available devices.
In the iterable case these have to match CUDA device IDs on the system.
Specify the same ID multiple times to spawn multiple workers on the same CUDA device.
has_cupy: bool
Specify if the cluster should signal that it supports GPU-based array programming using
CuPy
name
Prefix for the worker names
num_service
Number of additional workers that are reserved for service tasks. Computation tasks
will not be scheduled on these workers, which guarantees responsive behavior for file
browsing etc.
options
Options to pass through to every worker. See Dask documentation for details
preload
Items to preload on workers in addition to LiberTEM-internal preloads.
This can be used to load libraries, for example HDF5 filter plugins before h5py is used.
See https://docs.dask.org/en/stable/how-to/customize-initialization.html#preload-scripts
for more information.
See also
--------
:func:`libertem.utils.devices.detect`
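Examples
--------
Illustrative sketch; the worker counts and CUDA device IDs below are
placeholders, adapt them to your system:
>>> spec = cluster_spec(cpus=4, cudas=[0], has_cupy=False)  # doctest: +SKIP
>>> executor = DaskJobExecutor.make_local(spec=spec)  # doctest: +SKIP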
'''
if options is None:
options = {}
if preload is None:
preload = ()
if options.get("nthreads") is None:
options["nthreads"] = 1
if options.get("silence_logs") is None:
options["silence_logs"] = logging.WARN
workers_spec = {}
cpu_options = deepcopy(options)
cpu_options["resources"] = {"CPU": 1, "compute": 1, "ndarray": 1}
cpu_base_spec = {
"cls": dd.Nanny,
"options": cpu_options,
}
# Service workers not for computation
service_options = deepcopy(options)
service_options["resources"] = {}
service_base_spec = {
"cls": dd.Nanny,
"options": service_options
}
cuda_options = deepcopy(options)
cuda_options["resources"] = {"CUDA": 1, "compute": 1}
if has_cupy:
cuda_options["resources"]["ndarray"] = 1
cuda_base_spec = {
"cls": dd.Nanny,
"options": cuda_options
}
def _get_tracing_setup(service_name: str, service_id: str) -> str:
return (
f"from libertem.common.tracing import maybe_setup_tracing; "
f"maybe_setup_tracing(service_name='{service_name}', service_id='{service_id}')"
)
if isinstance(cpus, int):
cpus = tuple(range(cpus))
for cpu in cpus:
worker_name = f'{name}-cpu-{cpu}'
cpu_spec = deepcopy(cpu_base_spec)
cpu_spec['options']['preload'] = preload + (
'from libertem.executor.dask import worker_setup; '
+ f'worker_setup(resource="CPU", device={cpu})',
_get_tracing_setup(worker_name, str(cpu)),
'libertem.preload',
)
workers_spec[worker_name] = cpu_spec
for service in range(num_service):
worker_name = f'{name}-service-{service}'
service_spec = deepcopy(service_base_spec)
service_spec['options']['preload'] = preload + (
_get_tracing_setup(worker_name, str(service)),
'libertem.preload',
)
workers_spec[worker_name] = service_spec
cudas = assign_cudas(cudas)
for cuda in cudas:
worker_name = f'{name}-cuda-{cuda}'
if worker_name in workers_spec:
num_with_name = sum(n.startswith(worker_name) for n in workers_spec)
worker_name = f'{worker_name}-{num_with_name - 1}'
cuda_spec = deepcopy(cuda_base_spec)
cuda_spec['options']['preload'] = preload + (
'from libertem.executor.dask import worker_setup; '
+ f'worker_setup(resource="CUDA", device={cuda})',
_get_tracing_setup(worker_name, str(cuda)),
'libertem.preload',
)
workers_spec[worker_name] = cuda_spec
return workers_spec
def _run_task(task, params, task_id, threaded_executor, comms_topic: Optional[str]):
"""
Very simple wrapper function. As dask internally caches functions that are
submitted to the cluster in various ways, we need to make sure to
consistently use the same function, and not build one on the fly.
Without this function, UDFTask->UDF->UDFData ends up in the
cache, which blows up memory usage over time.
"""
worker_context = DaskWorkerContext(comms_topic)
env = Environment(threads_per_worker=1,
threaded_executor=threaded_executor,
worker_context=worker_context)
task_result = task(env=env, params=params)
return {
"task_result": task_result,
"task_id": task_id,
}
def _simple_run_task(task):
return task()
class CommonDaskMixin:
client: dd.Client
def _task_idx_to_workers(self, workers: WorkerSet, idx: int) -> WorkerSet:
hosts = list(sorted(workers.hosts()))
host_idx = idx % len(hosts)
host = hosts[host_idx]
return workers.filter(lambda w: w.host == host)
def _future_for_location(
self,
task: TaskProtocol,
locations: WorkerSet,
resources,
workers: WorkerSet,
task_args=None,
wrap_fn=_run_task,
):
"""
Submit tasks and return the resulting futures
Parameters
----------
task:
callable
locations:
potential locations to run the task
resources:
required resources for the task
workers : WorkerSet
Available workers in the cluster
"""
submit_kwargs = {}
if task_args is None:
task_args = {}
if locations is not None:
if len(locations) == 0:
raise ValueError("no workers found for task")
location_names = locations.names()
else:
location_names = None
submit_kwargs.update({
'resources': self._validate_resources(workers, resources),
'workers': location_names,
'pure': False,
})
return self.client.submit(
wrap_fn, task, *task_args, **submit_kwargs
)
def _validate_resources(self, workers, resources):
# This is set in the constructor of DaskJobExecutor
if self.lt_resources:
if not self._resources_available(workers, resources):
raise ResourceError(f"Requested resources not available in cluster: {resources}")
result = resources
else:
if 'CUDA' in resources:
raise ResourceError(
"Requesting CUDA resource on a cluster without resource management."
)
result = {}
return result
def _resources_available(self, workers, resources):
def filter_fn(worker):
return all(worker.resources.get(key, 0) >= resources[key] for key in resources.keys())
return len(workers.filter(filter_fn))
def has_libertem_resources(self):
workers = self.get_available_workers()
def has_resources(worker):
r = worker.resources
return 'compute' in r and (('CPU' in r and 'ndarray' in r) or 'CUDA' in r)
return len(workers.filter(has_resources)) > 0
def _get_future(
self,
task: TaskProtocol,
workers: WorkerSet,
idx: int,
params_handle,
threaded_executor,
comms_topic: Optional[str]
):
if len(workers) == 0:
raise RuntimeError("no workers available!")
params_fut = self._scatter_map[params_handle]
return self._future_for_location(
task=task,
locations=task.get_locations() or self._task_idx_to_workers(
workers, idx
),
resources=task.get_resources(),
workers=workers,
task_args=(
params_fut,
idx,
threaded_executor,
comms_topic,
)
)
def get_available_workers(self) -> WorkerSet:
info = self.client.scheduler_info()
return WorkerSet([
Worker(
name=worker['name'],
host=worker['host'],
resources=worker['resources'],
nthreads=worker['nthreads'],
)
for worker in info['workers'].values()
])
def get_resource_details(self):
workers = self.get_available_workers()
details = {}
for worker in workers:
host_name = worker.host
if worker.name.startswith("tcp"):
# for handling `dask-worker`:
# a `dask-worker` name starts with "tcp"
# and such workers only support CPU
resource = 'cpu'
else:
# for handling `libertem-worker`
r = worker.resources
if "CPU" in r:
resource = 'cpu'
elif "CUDA" in r:
resource = 'cuda'
else:
resource = 'service'
if host_name not in details:
details[host_name] = {
'host': host_name,
'cpu': 0,
'cuda': 0,
'service': 0,
}
details[host_name][resource] += 1
details_sorted = []
for host in sorted(details.keys()):
details_sorted.append(details[host])
return details_sorted
def _dispatch_messages(subscribers: dict[str, list[Callable]], dask_message: tuple[float, dict]):
"""
Unpacks the Dask message format and forwards the message
to all subscribed callbacks for that topic (if any)
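For example, a ``dask_message`` of the form
``(timestamp, {'topic': 'some_topic', 'ident': ..., ...})`` results in every
handler subscribed to ``'some_topic'`` being called with the topic and the
remaining message dict.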
"""
timestamp, message = dask_message
true_topic = message.pop('topic')
for handler in subscribers.get(true_topic, []):
handler(true_topic, message)
class DaskJobExecutor(CommonDaskMixin, BaseJobExecutor):
'''
Default LiberTEM executor that uses `Dask futures
<https://docs.dask.org/en/stable/futures.html>`_.
Parameters
----------
client : distributed.Client
is_local : bool
Close the Client and cluster when the executor is closed.
lt_resources : bool
Specify if the cluster has LiberTEM resource tags and environment
variables for GPU processing. Autodetected by default.
'''
def __init__(self, client: dd.Client, is_local: bool = False,
lt_resources: Optional[bool] = None):
self.is_local = is_local
self.client = client
if lt_resources is None:
lt_resources = self.has_libertem_resources()
self.lt_resources = lt_resources
self._futures = {}
self._scatter_map = {}
@contextlib.contextmanager
def scatter(self, obj):
# an additional layer of indirection, because we want to be able to
# redirect keys to new values
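# Illustrative use (sketch): `with executor.scatter(udf_params) as handle:`
# yields a handle that run_tasks() resolves to the scattered future via
# self._scatter_map; scatter_update() can later point the same handle at a
# new value.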
handle = str(uuid.uuid4())
try:
fut = self.client.scatter(obj, broadcast=True, hash=False)
self._scatter_map[handle] = fut
yield handle
finally:
if handle in self._scatter_map:
del self._scatter_map[handle]
def scatter_update(self, handle, obj):
fut = self.client.scatter(obj, broadcast=True, hash=False)
self._scatter_map[handle] = fut
def scatter_update_patch(self, handle, patch):
fut = self._scatter_map[handle]
def _do_patch(obj):
if not hasattr(obj, 'patch'):
raise TypeError(f'object is not patcheable: {obj}')
obj.patch(patch)
# can't `client.run` here, as that doesn't resolve the scatter future
futures = []
for worker in self.get_available_workers().names():
futures.append(self.client.submit(_do_patch, fut, pure=False, workers=[worker]))
for res in dd.as_completed(futures, loop=self.client.loop):
pass
def run_tasks(
self,
tasks: Iterable[TaskProtocol],
params_handle: Any,
cancel_id: Any,
task_comm_handler: TaskCommHandler,
):
tasks = list(tasks)
tasks_w_index = list(enumerate(tasks))
def _id_to_task(task_id):
return tasks[task_id]
workers = self.get_available_workers()
threaded_executor = workers.has_threaded_workers()
self._futures[cancel_id] = []
initial = []
try:
topic_id = f'topic-{cancel_id}'
# Wrap all subscriptions into single unique topic
self.client.subscribe_topic(
topic_id,
functools.partial(_dispatch_messages,
task_comm_handler.subscriptions)
)
except AttributeError:
# Dask version does not support structured logs
# Fall back to partition-level progress updates only
topic_id = None
for _ in range(len(workers)):
if not tasks_w_index:
break
idx, wrapped_task = tasks_w_index.pop(0)
future = self._get_future(wrapped_task, workers, idx, params_handle,
threaded_executor, topic_id)
initial.append(future)
self._futures[cancel_id].append(future)
try:
as_completed = dd.as_completed(initial, with_results=True, loop=self.client.loop)
for future, result_wrap in as_completed:
if future.cancelled():
log.debug(
"future %r is cancelled, stopping",
future,
)
del self._futures[cancel_id]
raise JobCancelledError()
result = result_wrap['task_result']
task = _id_to_task(result_wrap['task_id'])
if tasks_w_index:
idx, wrapped_task = tasks_w_index.pop(0)
future = self._get_future(
wrapped_task, workers, idx, params_handle, threaded_executor, topic_id
)
as_completed.add(future)
self._futures[cancel_id].append(future)
yield result, task
finally:
if cancel_id in self._futures:
del self._futures[cancel_id]
if topic_id is not None:
self.client.unsubscribe_topic(topic_id)
def cancel(self, cancel_id):
log.debug("cancelling with cancel_id=`%s`", cancel_id)
if cancel_id in self._futures:
futures = self._futures[cancel_id]
self.client.cancel(futures)
def run_each_partition(self, partitions, fn, all_nodes=False):
"""
Run `fn` for all partitions. Yields results in order of completion.
Parameters
----------
partitions : List[Partition]
List of relevant partitions.
fn : callable
Function to call, will get the partition as first and only argument.
all_nodes : bool
If all_nodes is True, run the function on all nodes that have this partition,
otherwise run on any node that has the partition. If a partition has no location,
the function will not be run for that partition if `all_nodes` is True, otherwise
it will be run on any node.
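Examples
--------
Illustrative sketch; ``executor`` is a connected :class:`DaskJobExecutor` and
``ds`` stands in for any LiberTEM dataset:
>>> def get_shape(partition):
...     return partition.shape
>>> results = list(executor.run_each_partition(ds.get_partitions(), get_shape))  # doctest: +SKIP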
"""
def _make_items_all():
for p in partitions:
locs = p.get_locations()
if locs is None:
continue
for workers in locs.group_by_host():
yield (lambda: fn(p), workers, {})
if all_nodes:
items = _make_items_all()
else:
# TODO check if we should request a compute resource
items = ((lambda: fn(p), p.get_locations(), {})
for p in partitions)
workers = self.get_available_workers()
futures = [
self._future_for_location(*item, workers, wrap_fn=_simple_run_task)
for item in items
]
# TODO: do we need cancellation and all that good stuff?
for future, result in dd.as_completed(futures, with_results=True, loop=self.client.loop):
if future.cancelled():
raise JobCancelledError()
yield result
def run_function(self, fn, *args, **kwargs):
"""
run a callable :code:`fn` on any worker
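Examples
--------
Illustrative sketch, assuming ``executor`` is a connected :class:`DaskJobExecutor`:
>>> executor.run_function(max, 1, 2)  # doctest: +SKIP
2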
"""
fn_with_args = functools.partial(fn, *args, **kwargs)
future = self.client.submit(fn_with_args, priority=1, pure=False)
return future.result()
def map(self, fn, iterable):
"""
Run a callable :code:`fn` for each element in :code:`iterable`, on arbitrary worker nodes.
Parameters
----------
fn : callable
Function to call. Should accept exactly one parameter.
iterable : Iterable
Which elements to call the function on.
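Examples
--------
Illustrative sketch, assuming ``executor`` is a connected :class:`DaskJobExecutor`:
>>> executor.map(lambda x: x**2, [1, 2, 3])  # doctest: +SKIP
[1, 4, 9]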
"""
return [future.result()
for future in self.client.map(fn, iterable, pure=False)]
def run_each_host(self, fn, *args, **kwargs):
"""
Run a callable :code:`fn` once on each host, gathering all results into
a dict host -> result
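Examples
--------
Illustrative sketch, assuming ``executor`` is a connected :class:`DaskJobExecutor`:
>>> import socket
>>> executor.run_each_host(socket.gethostname)  # doctest: +SKIP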
"""
# TODO: any cancellation/errors to handle?
available_workers = self.get_available_workers()
future_map = {}
for worker_set in available_workers.group_by_host():
future_map[worker_set.example().host] = self.client.submit(
functools.partial(fn, *args, **kwargs),
priority=1,
workers=worker_set.names(),
# NOTE: need pure=False, otherwise the functions will all map to the same
# scheduler key and will only run once
pure=False,
)
result_map = {
host: future.result()
for host, future in future_map.items()
}
return result_map
def run_each_worker(self, fn, *args, **kwargs):
# Client.run() creates issues on Windows and OS X with Python 3.6
# FIXME workaround may not be needed anymore for Python 3.7+
available_workers = self.get_available_workers()
future_map = {}
for worker in available_workers:
future_map[worker.name] = self.client.submit(
functools.partial(fn, *args, **kwargs),
priority=1,
workers=[worker.name],
# NOTE: need pure=False, otherwise the functions will all map to the same
# scheduler key and will only run once
pure=False,
)
result_map = {
name: future.result()
for name, future in future_map.items()
}
return result_map
def close(self):
if self.is_local:
# Client.close won't close the Cluster itself because
# we provided an external dd.SpecCluster
self.client.close()
# Manually close the cluster if not yet torn down
# use getattr just in case cluster is already gone
if getattr(self.client, 'cluster', None) is not None:
self.client.cluster.close(timeout=30)
# NOTE: distributed already registers atexit handlers for
# both clients and clusters, this is here to allow manual closure
# followed by creation of a new Executor without accumulating clusters
@classmethod
def connect(cls, scheduler_uri, *args, client_kwargs: Optional[dict] = None, **kwargs):
"""
Connect to a remote dask scheduler.
Parameters
----------
scheduler_uri: str
Compatible with the :code:`address` parameter of :class:`distributed.Client`.
client_kwargs: dict or None
Passed as kwargs to :class:`distributed.Client`.
:code:`client_kwargs['set_as_default']` is set to :code:`False`
unless specified otherwise to avoid interference with Dask-based workflows.
Pass :code:`client_kwargs={'set_as_default': True}` to set the Client as the
default Dask scheduler and keep it running when the Context closes.
*args, **kwargs: Passed to :class:`DaskJobExecutor`.
Returns
-------
DaskJobExecutor
the connected JobExecutor
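Examples
--------
Illustrative sketch, assuming a Dask scheduler is already listening at the
given address:
>>> executor = DaskJobExecutor.connect('tcp://localhost:8786')  # doctest: +SKIP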
"""
if client_kwargs is None:
client_kwargs = {}
if client_kwargs.get('set_as_default') is None:
client_kwargs['set_as_default'] = False
is_local = not client_kwargs['set_as_default']
client = dd.Client(address=scheduler_uri, **client_kwargs)
return cls(client=client, is_local=is_local, *args, **kwargs)
@classmethod
def make_local(cls, spec: Optional[dict] = None, cluster_kwargs: Optional[dict] = None,
client_kwargs: Optional[dict] = None, preload: Optional[tuple[str]] = None):
"""
Spin up a local dask cluster
Parameters
----------
spec
Dask cluster spec, see
https://distributed.dask.org/en/stable/api.html#distributed.SpecCluster
for more info.
:func:`libertem.utils.devices.detect` allows to detect devices that can be used
with LiberTEM, and :func:`cluster_spec` can be used to create a :code:`spec`
with customized parameters.
cluster_kwargs
Passed to :class:`distributed.SpecCluster`.
client_kwargs
Passed to :class:`distributed.Client`. Pass
:code:`client_kwargs={'set_as_default': True}` to set the Client as the
default Dask scheduler.
preload: Optional[Tuple[str]]
Passed to :func:`cluster_spec` if :code:`spec` is :code:`None`.
Returns
-------
DaskJobExecutor
the connected JobExecutor
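Examples
--------
Illustrative sketch using auto-detected devices:
>>> executor = DaskJobExecutor.make_local()  # doctest: +SKIP
>>> from libertem.api import Context
>>> ctx = Context(executor=executor)  # doctest: +SKIP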
"""
# Distributed doesn't adjust the event loop policy when being run
# from within pytest as of version 2.21.0. For that reason we
# adjust the policy ourselves here.
adjust_event_loop_policy()
if spec is None:
from libertem.utils.devices import detect
spec = cluster_spec(**detect(), preload=preload)
else:
if preload is not None:
raise ValueError(
"Passing both spec and preload is not supported. "
"Instead, include preloading specification in the spec"
)
if client_kwargs is None:
client_kwargs = {}
if client_kwargs.get('set_as_default') is None:
client_kwargs['set_as_default'] = False
if cluster_kwargs is None:
cluster_kwargs = {}
if cluster_kwargs.get('silence_logs') is None:
cluster_kwargs['silence_logs'] = logging.WARN
dist_log_level = dask.config.get('distributed.logging.distributed', default=None)
if dist_log_level is None:
dist_log_level = cluster_kwargs['silence_logs']
with set_num_threads_env(n=1), set_worker_log_level(dist_log_level):
# Mitigation for https://github.com/dask/distributed/issues/6776
with dask.config.set({"distributed.worker.profile.enabled": False}):
cluster = dd.SpecCluster(workers=spec, **(cluster_kwargs or {}))
client = dd.Client(cluster, **(client_kwargs or {}))
client.wait_for_workers(len(spec))
is_local = not client_kwargs['set_as_default']
return cls(client=client, is_local=is_local, lt_resources=True)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
class AsyncDaskJobExecutor(AsyncAdapter):
def __init__(self, wrapped=None, *args, **kwargs):
if wrapped is None:
wrapped = DaskJobExecutor(*args, **kwargs)
super().__init__(wrapped)
@classmethod
async def connect(cls, scheduler_uri, *args, **kwargs):
executor = await sync_to_async(functools.partial(
DaskJobExecutor.connect,
scheduler_uri=scheduler_uri,
*args,
**kwargs,
))
return cls(wrapped=executor)
@classmethod
async def make_local(cls, spec=None, cluster_kwargs=None, client_kwargs=None):
executor = await sync_to_async(functools.partial(
DaskJobExecutor.make_local,
spec=spec,
cluster_kwargs=cluster_kwargs,
client_kwargs=client_kwargs,
))
return cls(wrapped=executor)
def cli_worker(
scheduler,
local_directory,
cpus,
cudas,
has_cupy,
name,
log_level,
preload: tuple[str, ...]
):
import asyncio
options = {
"silence_logs": log_level,
"local_directory": local_directory
}
spec = cluster_spec(
cpus=cpus, cudas=cudas, has_cupy=has_cupy, name=name,
options=options, preload=preload,
)
async def run(spec):
# Mitigation for https://github.com/dask/distributed/issues/6776
with dask.config.set({"distributed.worker.profile.enabled": False}):
workers = []
for worker_name, worker_spec in spec.items():
cls = worker_spec['cls']
workers.append(
cls(scheduler, name=worker_name, **worker_spec['options'])
)
await asyncio.gather(*workers)
for w in workers:
await w.finished()
asyncio.get_event_loop().run_until_complete(run(spec))