import contextlib
import functools
import logging
import uuid
import concurrent.futures
from typing import Any, Optional
from collections.abc import Iterable

from opentelemetry import trace
from .base import (
    BaseJobExecutor, AsyncAdapter,
)
from libertem.common.executor import (
    JobCancelledError, TaskProtocol, TaskCommHandler, WorkerContext,
    SimpleWorkerQueue, WorkerQueue, Environment
)
from libertem.common.async_utils import sync_to_async
from libertem.utils.devices import detect
from libertem.common.scheduler import Worker, WorkerSet
from libertem.common.backend import get_use_cuda
from libertem.common.tracing import TracedThreadPoolExecutor

log = logging.getLogger(__name__)
tracer = trace.get_tracer(__name__)


class ConcurrentWorkerContext(WorkerContext):
    def __init__(self, msg_queue: WorkerQueue):
        self._msg_queue = msg_queue

    def signal(self, ident: str, topic: str, msg_dict: dict[str, Any]):
        if 'ident' in msg_dict:
            raise ValueError('ident is a reserved name')
        msg_dict.update({'ident': ident})
        self._msg_queue.put((topic, msg_dict))
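
    # Example (sketch): inside a running task, this context is reachable
    # through the task's environment, so a task could report progress with
    # e.g. ``worker_context.signal(my_ident, 'partial', {'progress': 0.5})``;
    # the topic and payload here are illustrative, not a fixed protocol.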


def _run_task(task, params, task_id, threaded_executor, msg_queue, scatter_map):
    """
    Wraps the task to be run in the pool: resolves the scattered parameters
    via :code:`scatter_map` and runs the task in a single-threaded
    :class:`Environment`.
    """
    params = scatter_map[params]
    worker_context = ConcurrentWorkerContext(msg_queue)
    env = Environment(
        threads_per_worker=1,
        threaded_executor=threaded_executor,
        worker_context=worker_context,
    )
    task_result = task(env=env, params=params)
    return {
        "task_result": task_result,
        "task_id": task_id,
    }


class ConcurrentJobExecutor(BaseJobExecutor):
    '''
    :class:`JobExecutor` that uses :mod:`python:concurrent.futures`.

    .. versionadded:: 0.9.0

    Parameters
    ----------
    client : concurrent.futures.Executor
        The wrapped :class:`concurrent.futures.Executor` that tasks
        are submitted to.
    is_local : bool
        Shut the client down when the executor closes.
    '''

    def __init__(self, client: concurrent.futures.Executor, is_local=False):
        # Only import if actually instantiated, i.e. will likely be used
        import libertem.preload  # noqa: F401
        self.is_local = is_local
        self.client = client
        self._futures = {}
        self._scatter_map = {}

    @contextlib.contextmanager
    def scatter(self, obj):
        handle = str(uuid.uuid4())
        self._scatter_map[handle] = obj
        try:
            yield handle
        finally:
            del self._scatter_map[handle]
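
    # Example (sketch): ``scatter`` hands out an opaque string handle
    # instead of the object itself, mirroring executors that ship data to
    # remote workers; here the "transfer" is just a dict lookup in the
    # same process:
    #
    #     with executor.scatter(udf_params) as handle:
    #         ...  # pass ``handle`` as ``params_handle`` to run_tasks
    #
    # ``udf_params`` is a placeholder name for illustration.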

    def scatter_update(self, handle, obj):
        self._scatter_map[handle] = obj

    def scatter_update_patch(self, handle, patch):
        self._scatter_map[handle].patch(patch)

    def _get_future(self, wrapped_task, idx, params_handle, msg_queue):
        return self.client.submit(
            _run_task,
            task=wrapped_task, params=params_handle, task_id=idx,
            threaded_executor=True, msg_queue=msg_queue,
            scatter_map=self._scatter_map,
        )

    def run_tasks(
        self,
        tasks: Iterable[TaskProtocol],
        params_handle: Any,
        cancel_id: Any,
        task_comm_handler: TaskCommHandler,
    ):
        tasks = list(tasks)

        def _id_to_task(task_id):
            return tasks[task_id]

        self._futures[cancel_id] = []
        msg_queue = SimpleWorkerQueue()

        for idx, wrapped_task in enumerate(tasks):
            future = self._get_future(wrapped_task, idx, params_handle, msg_queue)
            self._futures[cancel_id].append(future)

        with task_comm_handler.monitor(msg_queue):
            try:
                as_completed = concurrent.futures.as_completed(self._futures[cancel_id])
                for future in as_completed:
                    # check for cancellation before calling result(): for a
                    # cancelled future, result() would raise
                    # concurrent.futures.CancelledError instead of
                    # JobCancelledError
                    if future.cancelled():
                        del self._futures[cancel_id]
                        raise JobCancelledError()
                    result_wrap = future.result()
                    result = result_wrap['task_result']
                    task = _id_to_task(result_wrap['task_id'])
                    yield result, task
            finally:
                if cancel_id in self._futures:
                    del self._futures[cancel_id]
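
    # Usage sketch: ``run_tasks`` is a generator that yields results as
    # they complete, so they may arrive out of task order; the task yielded
    # alongside each result lets the caller match them up again:
    #
    #     with executor.scatter(params) as handle:
    #         for result, task in executor.run_tasks(
    #                 tasks, handle, cancel_id, comm_handler):
    #             ...  # merge ``result`` for the partition behind ``task``
    #
    # ``params``, ``tasks``, ``cancel_id`` and ``comm_handler`` are
    # placeholder names, not fixed identifiers from this module.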
def cancel(self, cancel_id):
if cancel_id in self._futures:
for future in self._futures[cancel_id]:
future.cancel()

    def run_function(self, fn, *args, **kwargs):
        """
        Run a callable :code:`fn` on any worker
        """
        fn_with_args = functools.partial(fn, *args, **kwargs)
        future = self.client.submit(fn_with_args)
        return future.result()
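
    # Example (sketch): ``run_function`` blocks until the callable
    # finishes, so it behaves like a synchronous call routed through
    # the pool:
    #
    #     nine = executor.run_function(lambda x: x ** 2, 3)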

    def map(self, fn, iterable):
        """
        Run a callable :code:`fn` for each element in :code:`iterable`, on arbitrary worker nodes.

        Parameters
        ----------
        fn : callable
            Function to call. Should accept exactly one parameter.
        iterable : Iterable
            Which elements to call the function on.
        """
        return self.client.map(fn, iterable)
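
    # Example (sketch): the return value is whatever the wrapped
    # concurrent.futures.Executor.map returns, i.e. a lazy iterator:
    #
    #     squares = list(executor.map(lambda x: x * x, range(4)))
    #     # -> [0, 1, 4, 9]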

    def get_available_workers(self):
        resources = {"compute": 1, "CPU": 1}
        if get_use_cuda() is not None:
            resources["CUDA"] = 1
            return WorkerSet([
                Worker(name='concurrent', host='localhost', resources=resources, nthreads=1)
            ])
        else:
            devices = detect()
            return WorkerSet([
                Worker(
                    name='concurrent', host='localhost',
                    resources=resources,
                    nthreads=len(devices['cpus']),
                )
            ])

    def run_each_host(self, fn, *args, **kwargs):
        """
        Run a callable :code:`fn` once on each host, gathering all results into
        a dict host -> result
        """
        # TODO: any cancellation/errors to handle?
        future = self.client.submit(fn, *args, **kwargs)
        return {"localhost": future.result()}

    def run_each_worker(self, fn, *args, **kwargs):
        future = self.client.submit(fn, *args, **kwargs)
        return {"inline": future.result()}

    def close(self):
        if self.is_local:
            self.client.shutdown(wait=False)

    @classmethod
    def make_local(cls, n_threads: Optional[int] = None):
        """
        Create a local ConcurrentJobExecutor backed by
        a :class:`python:concurrent.futures.ThreadPoolExecutor`.

        Parameters
        ----------
        n_threads : Optional[int]
            The number of threads to spawn in the executor. By default
            :code:`None`, in which case one thread per detected CPU core
            is spawned.

        Returns
        -------
        ConcurrentJobExecutor
            the connected JobExecutor
        """
        if n_threads is None:
            devices = detect()
            n_threads = len(devices['cpus'])
        client = TracedThreadPoolExecutor(tracer, max_workers=n_threads)
        return cls(client=client, is_local=True)
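
    # Example (sketch): the typical entry point for this executor; the
    # instance can then be passed to a LiberTEM Context, e.g.
    #
    #     from libertem.api import Context
    #     executor = ConcurrentJobExecutor.make_local(n_threads=4)
    #     ctx = Context(executor=executor)
    #
    # The thread count of 4 is arbitrary here.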

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()


class AsyncConcurrentJobExecutor(AsyncAdapter):
    def __init__(self, wrapped=None, *args, **kwargs):
        if wrapped is None:
            wrapped = ConcurrentJobExecutor(*args, **kwargs)
        super().__init__(wrapped)

    @classmethod
    async def make_local(cls):
        executor = await sync_to_async(functools.partial(
            ConcurrentJobExecutor.make_local,
        ))
        return cls(wrapped=executor)
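

# Example (sketch): the async variant is awaited rather than called
# directly; it wraps the synchronous executor through AsyncAdapter:
#
#     async def main():
#         executor = await AsyncConcurrentJobExecutor.make_local()
#         ...  # use with async consumers of the executor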