# Source code for libertem.executor.pipelined

import itertools
import os
import sys
import heapq
import logging
import functools
import contextlib
import multiprocessing as mp
from typing import (
    TYPE_CHECKING, Any, Callable, NamedTuple, Optional, TypeVar, Union,
)
from collections.abc import Iterable, Generator
from typing_extensions import TypedDict, Literal
import uuid
import warnings
import time

from tblib import pickling_support
import cloudpickle
from opentelemetry import trace
from libertem.common.backend import set_use_cpu, set_use_cuda, set_file_limit

from libertem.common.executor import (
    Environment, TaskProtocol, WorkerContext, WorkerQueue,
    WorkerQueueEmpty, TaskCommHandler, SimpleMPWorkerQueue,
    JobCancelledError, ResourceDef,
)
from libertem.common.scheduler import Worker, WorkerSet, Scheduler
from libertem.common.tracing import add_partition_to_span, attach_to_parent, maybe_setup_tracing

from .utils import assign_cudas
from .base import BaseJobExecutor, ResourceError

try:
    import prctl
except ImportError:
    prctl = None

if TYPE_CHECKING:
    from opentelemetry.trace import SpanContext

# module-level logger and OpenTelemetry tracer, both named after this module
logger = logging.getLogger(__name__)
tracer = trace.get_tracer(__name__)

# generic result type of functions run via `run_function` and friends
T = TypeVar("T")

class WorkerQueues(NamedTuple):
    """
    The three queues one worker process communicates through.

    In :class:`WorkerPool`, each worker gets its own ``request`` queue, while
    ``response`` and ``message`` are shared between all workers of the pool.
    """
    # main process -> worker: tasks, scatter/delete and control messages
    request: "WorkerQueue"
    # worker -> main process: results, errors and startup notifications
    response: "WorkerQueue"
    # worker -> main process: signals sent via `PipelinedWorkerContext.signal`
    message: "WorkerQueue"
class WorkerSpec(TypedDict):
    """
    Plan for a single worker process, as generated by :func:`_make_spec`.
    """
    # e.g. "cpu-3" or "cuda-0-1"
    name: str
    # CPU core id (used for pinning) or CUDA device id
    device_id: int
    device_kind: Union[Literal['CPU'], Literal['CUDA']]
    # position of this worker within the pool
    worker_idx: int
    # whether cupy can be used on this worker
    has_cupy: bool
class PoolWorkerInfo(NamedTuple):
    """
    Book-keeping record for one started worker of a :class:`WorkerPool`.
    """
    # the queues used to talk to this worker
    queues: WorkerQueues
    # handle of the worker process
    process: mp.Process
    # the spec this worker was started with
    spec: WorkerSpec
class PoolStateError(Exception):
    """
    Raised when the worker pool is not in the state needed for the requested
    operation, for example when its workers were already killed.
    """
def _resources_for_spec(spec: WorkerSpec) -> ResourceDef:
    """
    Translate a worker spec into the resources it offers to the scheduler.

    CPU workers provide ``compute``, ``CPU`` and ``ndarray``. CUDA workers
    provide ``compute`` and ``CUDA``, plus ``ndarray`` when cupy is
    available. Any other device kind yields an empty mapping.
    """
    kind = spec["device_kind"]
    resources: "ResourceDef" = {}
    if kind == "CPU":
        resources.update(compute=1, CPU=1, ndarray=1)
    elif kind == "CUDA":
        resources.update(compute=1, CUDA=1)
        if spec["has_cupy"]:
            resources["ndarray"] = 1
    return resources
class WorkerPool:
    """
    Combination of worker processes and matching request queues, and a single
    response queue. Processes are started using the spawn method, so they need
    to use primitives that are compatible with spawn.

    Take care to properly close queues and join workers at shutdown.

    Note
    ----
    We are not using the vanilla :class:`multiprocessing.pool.Pool` here,
    because we need to coordinate the sending and receiving side for each
    task, and we need to keep state on the worker, pin processes to cores etc.
    """
    def __init__(self, worker_fn: Callable, spec: list[WorkerSpec]):
        self._worker_q_cls = SimpleMPWorkerQueue
        # set to `None` by `kill()`:
        self._workers: Optional[list[PoolWorkerInfo]] = []
        self._worker_fn = worker_fn
        # shared by all workers; set to `None` by `kill()`:
        self._response_q: Optional[SimpleMPWorkerQueue] = self._worker_q_cls()
        self._message_q = self._worker_q_cls()
        self._mp_ctx = mp.get_context("spawn")
        self._spec = spec
        self._start_workers()

    @property
    def response_queue(self) -> "WorkerQueue":
        """
        The response queue shared by all workers.

        Raises
        ------
        PoolStateError
            If the pool was already killed.
        """
        if self._response_q is None:
            raise PoolStateError("Response queue not available")
        return self._response_q

    @property
    def message_queue(self) -> "WorkerQueue":
        """The message queue shared by all workers."""
        return self._message_q

    @property
    def size(self) -> int:
        """Number of workers in this pool, according to the spec."""
        return len(self._spec)

    def _start_workers(self):
        """
        Spawn one process per spec item, each with its own request queue.

        If starting a process fails, the workers that were already started
        are killed before the exception is re-raised, so they are not left
        running without an owner.
        """
        with tracer.start_as_current_span("WorkerPool._start_workers") as span:
            span_context = span.get_span_context()
            for spec_item in self._spec:
                queues = self._make_worker_queues()
                p = self._mp_ctx.Process(target=self._worker_fn, kwargs={
                    "queues": queues,
                    "spec": spec_item,
                    "span_context": span_context,
                })
                try:
                    p.start()
                except BaseException:
                    # the pool object is lost when `__init__` raises, so
                    # nobody could shut these workers down later:
                    for started in self._workers:
                        self.kill_worker(started)
                    raise
                self._workers.append(
                    PoolWorkerInfo(
                        queues=queues,
                        process=p,
                        spec=spec_item,
                    )
                )

    def kill_worker(self, worker_info: PoolWorkerInfo, timeout: float = 5.0):
        """
        Forcefully stop a single worker.

        Closes its request queue without draining, terminates the process,
        and escalates to :meth:`multiprocessing.Process.kill` if it has not
        exited after ``timeout`` seconds.
        """
        worker_info.queues.request.close(drain=False, force=True)
        worker_info.process.terminate()
        worker_info.process.join(timeout)
        if worker_info.process.exitcode is None:
            worker_info.process.kill()
            # reap the dead process:
            worker_info.process.join(30)

    def kill(self, timeout: float = 5):
        """
        Kill all workers and close the shared response queue.

        Raises
        ------
        PoolStateError
            If the pool was already killed, or if at least one worker could
            not be stopped.
        """
        if self._workers is None:
            raise PoolStateError("Cannot kill workers, not running")
        for worker in self._workers:
            self.kill_worker(worker, timeout=timeout)
        exitcodes = [
            worker.process.exitcode
            for worker in self._workers
        ]
        self._workers = None
        if self._response_q is None:
            raise PoolStateError("Response queue not available")
        self._response_q.close(drain=False, force=True)
        self._response_q = None
        if any(e is None for e in exitcodes):
            raise PoolStateError("At least one worker failed to be killed")

    @property
    def workers(self) -> list[PoolWorkerInfo]:
        """
        Information about all workers of this pool.

        Raises
        ------
        PoolStateError
            If the pool was already killed.
        """
        if self._workers is None:
            raise PoolStateError("No workers are running")
        return self._workers

    def all_alive(self) -> bool:
        """
        Return whether every worker process is still alive.

        Raises
        ------
        PoolStateError
            If the pool was already killed.
        """
        if self._workers is None:
            raise PoolStateError("No workers are running")
        return all(qp.process.is_alive() for qp in self._workers)

    def assert_all_alive(self):
        """
        Check if any workers are dead, if so, kill the whole pool and raise a
        `RuntimeError`.
        """
        if not self.all_alive():
            self.kill()
            raise RuntimeError(
                "One or more workers failed to start"
            )

    def close_resp_queue(self):
        """
        Close the shared response queue.

        This does nothing if :meth:`kill` already closed and released the
        queue. Calling ``close()`` on ``None`` would otherwise raise
        ``AttributeError``.
        """
        if self._response_q is None:
            return
        self._response_q.close()

    def close_mesg_queue(self):
        """Close the shared message queue."""
        self._message_q.close()

    def get_worker_queues(self, worker_idx: int) -> WorkerQueues:
        """
        Return the queues of the worker at index ``worker_idx``.

        Raises
        ------
        PoolStateError
            If the pool was already killed.
        """
        if self._workers is None:
            raise PoolStateError("No workers are running")
        return self._workers[worker_idx].queues

    def _make_worker_queues(self):
        # a fresh request queue per worker; response and message queues are shared
        return WorkerQueues(
            request=self._worker_q_cls(),
            response=self._response_q,
            message=self._message_q,
        )
# A worker selection strategy: given the eligible workers, the pool, the task
# and its index, return the worker that should run the task.
WorkerSelector = Callable[
    [list[PoolWorkerInfo], WorkerPool, TaskProtocol, int],
    PoolWorkerInfo
]


def _select_by_queue_size(
    eligible: list[PoolWorkerInfo],
    pool: WorkerPool,
    task: TaskProtocol,
    task_idx: int,
) -> PoolWorkerInfo:
    """
    Choose the eligible worker with the smallest request-queue backlog. On a
    tie, the worker listed first wins.
    """
    # FIXME: I think this introduces some latency currently,
    # as the workers `get` from their queue as soon as possible,
    # so the following could happen:
    # - task 0 becomes available
    # - all queues have size 0, so task gets assigned to worker 0
    # - worker 0 picks the task from its queue, so its queue size is 0 again
    # - task 1 becomes available
    # - all queues have size 0 (again), so task gets assigned to worker 0
    # => this should only happen in the offline case, as in the live
    # processing case, we push more than one message per task into the
    # queue, so there should be more than one message in the backlog
    # when the `TaskCommHandler` is finished pushing messages for one task.
    def backlog(info: PoolWorkerInfo) -> int:
        return info.queues.request.size()

    return min(eligible, key=backlog)


def _select_by_round_robin(
    eligible: list[PoolWorkerInfo],
    pool: WorkerPool,
    task: TaskProtocol,
    task_idx: int,
) -> PoolWorkerInfo:
    """Choose workers cyclically, based on the task index."""
    slot = task_idx % len(eligible)
    return eligible[slot]
def schedule_task(
    task_idx: int,
    task: TaskProtocol,
    pool: WorkerPool,
    scheduler: Scheduler,
    pool_info_for_worker: dict[Worker, PoolWorkerInfo],
    selector: WorkerSelector,
) -> tuple[int, WorkerQueues]:
    """
    Decide which worker runs the given task.

    The scheduler first narrows the workers down to those offering the
    resources the task requests. The ``selector`` then picks one of them.

    Returns
    -------
    The index of the chosen worker in ``pool.workers``, and its queues.

    Raises
    ------
    ResourceError
        If no worker offers the requested resources.
    """
    candidates = [
        pool_info_for_worker[w]
        for w in scheduler.workers_for_task(task)
    ]
    if not candidates:
        raise ResourceError(f"No worker available with resources {task.get_resources()}")
    chosen = selector(candidates, pool, task, task_idx)
    return pool.workers.index(chosen), chosen.queues
def set_thread_name(name: str):
    """
    Set a thread name via ``prctl``; mostly useful for using system tools
    for profiling. Does nothing if ``prctl`` could not be imported.

    Parameters
    ----------
    name : str
        The thread name
    """
    if prctl is not None:
        prctl.set_name(name)
def worker_run_task(
    header: dict,
    work_mem: dict,
    queues: WorkerQueues,
    worker_idx: int,
    env: Environment,
):
    """
    Called from the worker main loop when a RUN_TASK message is received.

    The pickled task is loaded and run with the scattered parameters it
    references. A RESULT message, or an ERROR message on failure, is then
    put on the response queue.

    Parameters
    ----------
    header
        The header of the message that was received
    work_mem
        The worker's working memory, for accessing scattered data
    queues
        request and response queues
    worker_idx
        This worker's index
    env
        The Environment for preparing thread counts etc. for the UDF run
    """
    with tracer.start_as_current_span("RUN_TASK") as span:
        try:
            # attribute keys must be non-empty strings per the OpenTelemetry spec
            span.set_attributes({
                "libertem.task_size_pickled": len(header["task"]),
                "libertem.os.pid": os.getpid(),
            })
            task: TaskProtocol = cloudpickle.loads(header["task"])
            params = work_mem[header["params_handle"]]
            add_partition_to_span(task.get_partition())
            result = task(params, env)
            queues.response.put({
                "type": "RESULT",
                "result": result,
                "task_id": header["task_id"],
                "uuid": header["uuid"],
                "worker_id": worker_idx,
            })
        except Exception as e:
            logger.exception("failure in RUN_TASK")
            # make the traceback picklable so it survives the queue:
            pickling_support.install(e)
            queues.response.put({
                "type": "ERROR",
                "error": e,
                "exception": e,
                "worker_id": worker_idx,
                "uuid": header["uuid"],
            })
def worker_run_function(header, queues, worker_idx):
    """
    Handle a RUN_FUNCTION message from the worker main loop.

    The pickled zero-argument callable in ``header["fn"]`` is unpickled and
    called. Its return value is sent back as a RUN_FUNCTION_RESULT message.
    Any exception, including one raised while sending the result, is sent
    back as an ERROR message instead.

    Parameters
    ----------
    header
        The header of the message that was received
    queues
        request and response queues
    worker_idx
        This worker's index
    """
    with tracer.start_as_current_span("RUN_FUNCTION"):
        try:
            fn = cloudpickle.loads(header["fn"])
            reply = {
                "type": "RUN_FUNCTION_RESULT",
                "result": fn(),
                "worker_id": worker_idx,
            }
            queues.response.put(reply)
        except Exception as err:
            logger.exception("failure in RUN_FUNCTION")
            pickling_support.install(err)
            queues.response.put({
                "type": "ERROR",
                "error": err,
                "exception": err,
                "worker_id": worker_idx,
            })
def _drain_after_task(queues: WorkerQueues):
    """
    Discard everything left in the request queue for the current task, up to
    and including the ``END_TASK`` sentinel.

    This covers any detector-specific data that was sent in
    `TaskCommHandler.handle_task` but not consumed by the task (for example
    because the task failed).

    Raises
    ------
    RuntimeError
        If a non-data message (RUN_TASK, SCATTER, ...) arrives before
        ``END_TASK``.
    """
    control_types = (
        "RUN_TASK", "SCATTER", "SCATTER_UPDATE", "SCATTER_UPDATE_PATCH",
        "RUN_FUNCTION", "DELETE", "SHUTDOWN", "WARMUP",
    )
    with tracer.start_as_current_span("drain after task") as span:
        reached_end = False
        while not reached_end:
            with queues.request.get() as (header, _payload):
                kind = header["type"]
                span.add_event("msg", {"type": kind})
                if kind in control_types:
                    raise RuntimeError(
                        f"unexpected message type {kind}"
                    )
                reached_end = kind == "END_TASK"
def worker_loop(
    queues: WorkerQueues, work_mem: dict, worker_idx: int, env: Environment
):
    """
    The worker main loop, called when the worker setup is done. Waits for
    messages on the request queue until a SHUTDOWN message is received.

    Each message is a ``(header, payload)`` pair; ``header["type"]`` selects
    the action. Problems with a message (duplicate/missing scatter keys,
    unknown types) are reported as ERROR messages on the response queue.

    Parameters
    ----------
    work_mem
        The worker's working memory, for accessing scattered data
    queues
        request and response queues
    worker_idx
        This worker's index
    env
        The Environment for preparing thread counts etc. for the UDF run

    Raises
    ------
    RuntimeError
        After reporting a message of unknown type.
    KeyboardInterrupt
        Re-raised after reporting it on the response queue.
    """
    while True:
        try:
            with queues.request.get() as msg:
                header, payload = msg
                header_type = header["type"]
                if header_type == "RUN_TASK":
                    with attach_to_parent(header["span_context"]):
                        worker_run_task(header, work_mem, queues, worker_idx, env)
                        # skip task data left in the queue, up to the END_TASK sentinel:
                        _drain_after_task(queues)
                elif header_type == "SCATTER":
                    # FIXME: array data could be transferred and stored in SHM instead
                    key = header["key"]
                    # SCATTER never overwrites existing data:
                    if key in work_mem:
                        queues.response.put({
                            "type": "ERROR",
                            "error": f"key {key} already stored in worker memory",
                            "worker_id": worker_idx,
                        })
                        continue
                    work_mem[key] = header["value"]
                    continue
                elif header_type == "SCATTER_UPDATE":
                    # FIXME: array data could be transferred and stored in SHM instead
                    key = header["key"]
                    # only keys that were scattered before can be updated:
                    if key not in work_mem:
                        queues.response.put({
                            "type": "ERROR",
                            "error": f"key {key} not stored in worker memory, can't update",
                            "worker_id": worker_idx,
                        })
                        continue
                    work_mem[key] = header["value"]
                    continue
                elif header_type == "SCATTER_UPDATE_PATCH":
                    # FIXME: array data could be transferred and stored in SHM instead
                    key = header["key"]
                    if key not in work_mem:
                        queues.response.put({
                            "type": "ERROR",
                            "error": f"key {key} not stored in worker memory, can't update",
                            "worker_id": worker_idx,
                        })
                        continue
                    if not hasattr(work_mem[key], 'patch'):
                        queues.response.put({
                            "type": "ERROR",
                            "error": f"key {key} is not patcheable",
                            "worker_id": worker_idx,
                        })
                        continue
                    # delegate the update to the stored object's own `patch` method:
                    work_mem[key].patch(header["patch"])
                    continue
                elif header_type == "RUN_FUNCTION":
                    with attach_to_parent(header["span_context"]):
                        worker_run_function(header, queues, worker_idx)
                elif header_type == "DELETE":
                    key = header["key"]
                    # deleting an unknown key is silently ignored:
                    if key in work_mem:
                        del work_mem[key]
                    continue
                elif header_type == "SHUTDOWN":
                    with attach_to_parent(header["span_context"]):
                        with tracer.start_as_current_span("SHUTDOWN"):
                            queues.request.close()
                            queues.response.close(drain=False)
                            queues.message.close(drain=False)
                            break
                elif header_type == "WARMUP":
                    with attach_to_parent(header["span_context"]):
                        with tracer.start_as_current_span("WARMUP"):
                            # presumably pre-imports modules so later tasks don't
                            # pay the import cost - TODO confirm
                            import libertem.udf.base  # NOQA
                            import libertem.api  # NOQA
                            import libertem.preload  # NOQA
                            # enter/exit the environment once, as warmup
                            with env.enter():
                                pass
                else:
                    queues.response.put({
                        "type": "ERROR",
                        "error": f"unknown message {header}",
                        "worker_id": worker_idx,
                    })
                    # probably desynchronized with the main process, so give up:
                    raise RuntimeError(f"unknown message, shutting down worker {worker_idx}")
        except KeyboardInterrupt:
            queues.response.put({
                "type": "ERROR",
                "error": "KeyboardInterrupt",
                "worker_id": worker_idx,
            })
            raise
# (device id, device kind) pair
DeviceT = tuple[int, Union[Literal['CUDA'], Literal['CPU']]]


def _setup_device(spec: WorkerSpec, pin: bool):
    """
    Configure this worker process for its device.

    CPU workers are optionally pinned to their core (only when ``pin`` is
    set and the OS provides :func:`os.sched_setaffinity`). CUDA workers
    select their GPU. The device kind is compared case-insensitively.

    Raises
    ------
    RuntimeError
        For any device kind other than CPU or CUDA.
    """
    kind = spec["device_kind"].lower()
    if kind == "cuda":
        set_use_cuda(spec["device_id"])
        return
    if kind != "cpu":
        raise RuntimeError(f"unknown device kind: {spec['device_kind']}")
    can_pin = hasattr(os, 'sched_setaffinity')
    if can_pin and pin:
        os.sched_setaffinity(0, [spec["device_id"]])
    set_use_cpu(spec["device_id"])
def pipelined_worker(
    queues: WorkerQueues,
    pin: bool,
    spec: WorkerSpec,
    span_context: "SpanContext",
    early_setup: Optional[Callable] = None,
):
    """
    Main pipelined worker function.

    Sets up the worker (stderr, file limit, tracing, device, environment) and
    reports STARTUP_DONE on the response queue. It then runs
    :func:`worker_loop`. If setup or the loop raises an exception, an ERROR
    message is sent, the queues are closed and the process exits with
    status 1.

    Parameters
    ----------
    queues
        request and response queues
    pin
        Whether or not the CPU worker should be pinned to a specific CPU
    spec
        This worker's spec, containing name, worker index, device kind etc.
    span_context
        The tracing span we should attach to for the setup code
    early_setup
        Function that will be called very early in the setup code, allowing to
        inject custom functionality or specific warmup code
    """
    # FIXME: propagate to parent process with a pipe or similar?
    sys.stderr.close()
    sys.stderr = open(os.open(os.devnull, os.O_RDWR), closefd=False)
    # read before `try`, so the error handler can always tell the main process
    # which worker failed (even if `early_setup` or `set_file_limit` raise):
    worker_idx = spec["worker_idx"]
    try:
        if early_setup:
            early_setup()
        set_file_limit()
        # FIXME: explicitly propagate exporter settings to the workers?
        # right now taken care of by environment variables...
        maybe_setup_tracing(service_name="pipelined_worker", service_id=f"worker-{worker_idx}")
        # attach to parent span context for startup:
        with attach_to_parent(span_context), \
                tracer.start_as_current_span("pipelined_worker.startup") as span:
            _setup_device(spec, pin)
            work_mem: dict[str, Any] = {}
            span.set_attributes({
                "libertem.spec.name": spec["name"],
                "libertem.spec.device_id": spec["device_id"],
                "libertem.spec.device_kind": spec["device_kind"],
                "libertem.spec.worker_idx": spec["worker_idx"],
            })
            set_thread_name(f"worker-{worker_idx}")
            worker_context = PipelinedWorkerContext(queues.request, queues.message)
            env = Environment(
                threaded_executor=False,
                threads_per_worker=1,
                worker_context=worker_context,
            )
            queues.response.put({
                "type": "STARTUP_DONE",
                "worker_id": worker_idx,
            })
        return worker_loop(queues, work_mem, worker_idx, env)
    except Exception as e:
        # make the traceback picklable, like the other error handlers do:
        pickling_support.install(e)
        queues.response.put({
            "type": "ERROR",
            "worker_id": worker_idx,
            "error": e,
            "exception": e,
        })
        # drain, close and join queues:
        queues.request.close()
        queues.response.close(drain=False)
        queues.message.close(drain=False)
        sys.exit(1)
class PipelinedWorkerContext(WorkerContext):
    """
    A context object that is made available to the Partition for custom
    communication to its matching DataSet class (currently uni-directional)
    """
    def __init__(self, queue: "WorkerQueue", msg_queue: "WorkerQueue"):
        # this worker's request queue:
        self._queue = queue
        # message queue shared with the main process:
        self._msg_queue = msg_queue

    def get_worker_queue(self) -> WorkerQueue:
        """Return this worker's request queue."""
        return self._queue

    def signal(self, ident: str, topic: str, msg_dict: dict[str, Any]):
        """
        Put ``(topic, message)`` on the message queue.

        ``message`` is a copy of ``msg_dict`` with an added ``'ident'`` key.
        The caller's dict is left unmodified, so it can safely be reused.
        """
        self._msg_queue.put((topic, {**msg_dict, 'ident': ident}))
ResultT = Generator[tuple[Any, TaskProtocol], None, None]
ResultWithID = Generator[tuple[Any, TaskProtocol, int], None, None]


def _order_results(results_in: ResultWithID) -> ResultT:
    """
    Order the `results_in` generator by the result id, yielding ordered results.
    Requires indexes to be without gaps.

    Out-of-order results are buffered in a min-heap keyed by task id. After
    every received result, all buffered results that continue the sequence
    are yielded.

    Raises
    ------
    RuntimeError
        If `results_in` ends while buffered results remain, i.e. some task id
        was never received.
    """
    last_sent_id = -1
    # min-heap of (task_id, seq, result, task); `seq` breaks ties so that
    # results and tasks themselves are never compared:
    pending: list[tuple[int, int, Any, TaskProtocol]] = []
    seq = itertools.count()
    span = trace.get_current_span()
    for result, task, task_id in results_in:
        if task_id == last_sent_id + 1:
            span.add_event("_order_results.yield", {
                'task_id': task_id,
            })
            yield (result, task)
            last_sent_id = task_id
        else:
            span.add_event("_order_results.postpone", {
                "expect": last_sent_id + 1,
                "is": task_id,
            })
            heapq.heappush(pending, (task_id, next(seq), result, task))
        # top of the heap looks good, yield as long as it matches:
        while pending and pending[0][0] == last_sent_id + 1:
            next_id, _, next_result, next_task = heapq.heappop(pending)
            span.add_event("_order_results.yield", {
                'task_id': next_id,
            })
            yield next_result, next_task
            last_sent_id = next_id
    while pending:
        task_id, _, result, task = heapq.heappop(pending)
        if task_id != last_sent_id + 1:
            raise RuntimeError(
                f"missing tasks? end of result but next id on result_stack is "
                f"{task_id}, was expecting {last_sent_id + 1}"
            )
        span.add_event("_order_results.yield")
        yield (result, task)
        last_sent_id = task_id


def _make_spec(
    cpus: Union[int, Iterable[int]],
    cudas: Union[int, Iterable[int]],
    has_cupy: bool = False,
) -> list[WorkerSpec]:
    """
    Takes the output of :func:`libertem.utils.devices.detect` and makes
    a plan for starting workers on them.

    Parameters
    ----------
    cpus: int | Iterable
        Iterable of integer CPU identifiers or an integer number of workers
        to create. If pinning is enabled, each worker process is pinned to
        one of these identifiers, as accepted by
        :func:`python:os.sched_setaffinity`. Pinning is currently only
        supported on platforms that implement
        :func:`python:os.sched_setaffinity`.
    cudas: int | Iterable
        Iterable of CUDA device identifiers for which workers should be
        started or an integer number of GPU workers to create across the
        available devices. Identifiers can be repeated to start multiple
        workers per GPU, which can result in better device utilization.
        Repeated identifiers don't need to be adjacent; workers on the same
        device are numbered ``cuda-<id>-0``, ``cuda-<id>-1``, ...
    has_cupy
        Is cupy available?
    """
    spec: list[WorkerSpec] = []
    worker_idx = 0
    if isinstance(cpus, int):
        cpus = tuple(range(cpus))
    # number of workers already planned per CUDA device. Counting (instead of
    # `itertools.groupby`, which only groups adjacent ids) keeps names unique
    # for interleaved ids such as [0, 1, 0, 1]:
    per_device: dict[int, int] = {}
    for device_id in assign_cudas(cudas):
        local_idx = per_device.get(device_id, 0)
        per_device[device_id] = local_idx + 1
        spec.append(WorkerSpec(
            name=f"cuda-{device_id}-{local_idx}",
            device_id=device_id,
            device_kind="CUDA",
            worker_idx=worker_idx,
            has_cupy=has_cupy,
        ))
        worker_idx += 1
    for device_id in cpus:
        spec.append(WorkerSpec(
            name=f"cpu-{device_id}",
            device_id=device_id,
            device_kind="CPU",
            worker_idx=worker_idx,
            has_cupy=False,
        ))
        worker_idx += 1
    return spec


def _raise_from_msg(msg: dict, err_prefix: str):
    """
    Raise the exception carried by an ERROR message, or a `RuntimeError`
    built from its ``error`` field if it carries no exception object.
    """
    if "exception" in msg:
        raise msg["exception"]
    else:
        raise RuntimeError(f"{err_prefix}: {msg['error']}")


def _inspect_startup(msg, span):
    """
    Validate one startup message from a worker.

    Raises if the worker reported an error or sent anything other than
    STARTUP_DONE. Otherwise, records a span event.
    """
    if msg["type"] == "ERROR":
        _raise_from_msg(msg, "error on startup")
    if msg["type"] != "STARTUP_DONE":
        raise RuntimeError(
            f"unknown message type {msg['type']}, expected STARTUP_DONE"
        )
    span.add_event("worker startup done", {"worker_id": msg["worker_id"]})
class PipelinedExecutor(BaseJobExecutor):
    """
    Multi-process pipelined executor. Useful for live processing using
    `LiberTEM-live <https://libertem.github.io/LiberTEM-live/>`_ if your
    processing function is not able to keep up with the incoming data stream
    in a single process, but also works for offline processing.

    Parameters
    ----------
    spec
        Specification for the worker processes - can be generated by
        :meth:`make_spec`.
    pin_workers
        Pin each CPU worker to a specific CPU, as defined by
        :func:`python:os.sched_setaffinity`. Only works on OSes that
        implement :func:`python:os.sched_setaffinity`.
    startup_timeout
        Startup of the executor is cancelled if it doesn't finish within this
        limit. The limit applies to the startup of the whole pool, i.e. all
        workers together. In seconds.
    cleanup_timeout
        When cleaning up using :meth:`close`, or after an error while running
        tasks, give up after this limit. In seconds.
    early_setup
        Callable that will be run as early as possible on each worker
        process. Useful for custom warmup code or testing.

    Note
    ----
    This executor is not thread-safe - concurrent calls into :meth:`run_tasks`
    or :meth:`run_function` are not supported.
    """
    def __init__(
        self,
        spec: Optional[list[WorkerSpec]] = None,
        pin_workers: bool = True,
        startup_timeout: float = 30.0,
        cleanup_timeout: float = 10.0,
        early_setup: Optional[Callable] = None,
    ) -> None:
        self._pin_workers = pin_workers
        if spec is None:
            spec = self._default_spec()
        self._spec = spec
        self._closed = True
        self._early_setup = early_setup
        # for testing via monkeypatching, for example:
        self._worker_selector: Optional[WorkerSelector] = None
        # timeout for cleanup, either from exception or when joining processes
        self._cleanup_timeout = cleanup_timeout
        # timeout for starting the whole pool (all workers together)
        self._startup_timeout = startup_timeout
        # keep this at the bottom:
        self._pool = self._start_pool()

    def _start_pool(self) -> WorkerPool:
        """
        Start the worker pool and wait until every worker reported
        STARTUP_DONE, then send a WARMUP message to each worker.

        If startup fails in any way (dead worker, ERROR message, timeout,
        interrupt), the pool is killed before the exception propagates, so no
        worker processes are left behind.
        """
        with tracer.start_as_current_span("PipelinedExecutor.start_pool") as span:
            pool = WorkerPool(
                worker_fn=functools.partial(
                    pipelined_worker,
                    pin=self._pin_workers,
                    early_setup=self._early_setup,
                ),
                spec=self._spec,
            )
            try:
                # if any processes are already dead here, raise an exception so we
                # don't have to run into a timeout below:
                pool.assert_all_alive()
                warn_time = 5.0
                check_interval = 0.1
                warn_deadline = time.monotonic() + warn_time
                startup_deadline = time.monotonic() + self._startup_timeout
                warned = False
                num_started = 0
                while num_started < pool.size:
                    pool.assert_all_alive()
                    try:
                        with pool.response_queue.get(timeout=check_interval) as (msg, _):
                            _inspect_startup(msg, span)
                            num_started += 1
                    except WorkerQueueEmpty:
                        if time.monotonic() > startup_deadline:
                            pool.assert_all_alive()
                            pool.kill()
                            # break possibly confusing exception chain using "from None":
                            raise RuntimeError(
                                f"Timeout while starting workers, might need to increase "
                                f"`startup_timeout` (is {self._startup_timeout}s)"
                            ) from None
                        if time.monotonic() > warn_deadline and not warned:
                            warnings.warn('Slow worker startup, please be patient...',
                                          RuntimeWarning)
                            warned = True
            except BaseException:
                # `assert_all_alive` and the timeout path already killed the
                # pool (a second `kill` raises PoolStateError, ignored here);
                # an ERROR message from a worker did not, so kill it now:
                with contextlib.suppress(PoolStateError):
                    pool.kill()
                raise
            for qp in pool.workers:
                qp.queues.request.put({
                    "type": "WARMUP",
                    "span_context": span.get_span_context(),
                })
            # set here, so we don't try to close the pool if it doesn't exist
            self._closed = False
            return pool

    def _restart_pool(self):
        """Kill the current pool and start a fresh one with the same spec."""
        self._pool.kill()
        self._pool = self._start_pool()

    @classmethod
    def _default_spec(cls):
        from libertem.utils.devices import detect
        detected = detect()
        return _make_spec(**detected)

    @classmethod
    def make_local(cls, **kwargs):
        """
        Create a :code:`PipelinedExecutor` with the default spec.
        """
        spec = cls._default_spec()
        return cls(spec=spec, **kwargs)

    @classmethod
    def make_spec(
        cls,
        cpus: Union[int, Iterable[int]],
        cudas: Union[int, Iterable[int]],
        has_cupy: bool = False,
    ):
        return _make_spec(cpus=cpus, cudas=cudas, has_cupy=has_cupy)

    make_spec.__doc__ = _make_spec.__doc__

    def _validate_worker_state(self):
        if not self._pool.all_alive():
            raise RuntimeError("some workers are stopped, cannot continue")

    def _run_tasks_inner(
        self,
        tasks: Iterable[TaskProtocol],
        params_handle: Any,
        cancel_id: Any,
        task_comm_handler: "TaskCommHandler",
    ) -> ResultWithID:
        """
        Send each task to a worker and yield ``(result, task, task_id)``
        tuples in the order the results arrive (not necessarily task order).
        """
        # In theory, `in_flight` could be calculated from `id_to_task`, but in case
        # of exceptions, it becomes a bit harder to keep attribution of messages to
        # tasks, which is why we have a separate counter for now.
        in_flight = 0
        id_to_task: dict[int, TaskProtocol] = {}
        tasks_uuid = str(uuid.uuid4())

        all_workers = self.get_available_workers()
        scheduler = Scheduler(all_workers=all_workers)
        pool_info_for_worker = dict(zip(all_workers, self._pool.workers))

        selector: WorkerSelector
        # feature detect: can we access the queue size? support for this can be
        # missing depending on platform
        if self._worker_selector is None:
            try:
                self._pool.workers[0].queues.request.size()
                selector = _select_by_queue_size
            except NotImplementedError:
                # fall back to simple round robin worker selection
                selector = _select_by_round_robin
        else:
            selector = self._worker_selector

        try:
            self._validate_worker_state()
            task_comm_handler.start()
            span = trace.get_current_span()
            span_context = span.get_span_context()

            def yield_result_if_found(block, timeout):
                nonlocal in_flight
                try:
                    with self._pool.response_queue.get(
                        block=block, timeout=timeout,
                    ) as (result, _):
                        if result.get('uuid') != tasks_uuid:
                            # mismatch, log and ignore:
                            logger.warning(
                                "mismatched result, ignoring: %s != %s",
                                result.get('uuid'), tasks_uuid,
                            )
                            return
                        in_flight -= 1
                        if result["type"] == "ERROR":
                            _raise_from_msg(result, "failed to run tasks")
                        result_task_id = result["task_id"]
                        yield (result["result"], id_to_task[result_task_id], result_task_id)
                        del id_to_task[result_task_id]
                        if len(id_to_task) != in_flight:
                            raise RuntimeError(
                                "state mismatch; `id_to_task` mapping should match `in_flight`"
                            )
                except WorkerQueueEmpty:
                    self._validate_worker_state()

            with tracer.start_as_current_span('schedule_tasks') as span:
                for task_idx, task in enumerate(tasks):
                    # important: call `schedule_task` before incrementing number of
                    # in-flight tasks, otherwise, if we run into an exception, we
                    # will wait for a response that never comes.

                    # select a worker for `task` according to `selector`:
                    _worker_idx, worker_queues = schedule_task(
                        task_idx,
                        task,
                        self._pool,
                        scheduler,
                        pool_info_for_worker,
                        selector=selector,
                    )

                    worker_queues.request.put({
                        "type": "RUN_TASK",
                        "uuid": tasks_uuid,
                        "task": cloudpickle.dumps(task),
                        "task_id": task_idx,
                        "params_handle": params_handle,
                        "span_context": span_context,
                    })
                    span.add_event('RUN_TASK', {
                        'task_id': task_idx,
                    })
                    in_flight += 1
                    id_to_task[task_idx] = task

                    if len(id_to_task) != in_flight:
                        raise RuntimeError(
                            "state mismatch; `id_to_task` mapping should match `in_flight`"
                        )

                    # FIXME: semantics of this - is this enough?
                    # does it matter if this is enough? we can change it in the future if not
                    # could be: the function returns once it has forwarded
                    # all the data necessary for the given task,
                    # (or, in the offline case, immediately)
                    try:
                        task_comm_handler.handle_task(task, worker_queues.request)
                    except JobCancelledError:
                        worker_queues.request.put({
                            "type": "END_TASK",
                            "task_id": task_idx,
                            "params_handle": params_handle,
                            "span_context": span_context,
                        })
                        raise

                    # NOTE: sentinel message; in case of errors, the worker
                    # needs to discard the data from the queue until it receives
                    # this message:
                    worker_queues.request.put({
                        "type": "END_TASK",
                        "task_id": task_idx,
                        "params_handle": params_handle,
                        "span_context": span_context,
                    })

                    yield from yield_result_if_found(block=False, timeout=None)

            # at the end, block to get the remaining results:
            while in_flight > 0:
                yield from yield_result_if_found(block=True, timeout=0.1)
        except JobCancelledError:
            # don't drain here, as the next acquisition could start
            # very soon; instead, we just ignore the mismatched
            # responses in the next run
            raise
        except Exception as e:
            # In case of an exception, we need to drain the response queue,
            # so the next `run_tasks` call isn't polluted by old responses.
            # -> just like in the happy case, we need to wait for responses
            # for all in-flight tasks. In case the error happened between incrementing
            # `in_flight` and actually sending the task to the queue, we should
            # have a timeout here to not wait infinitely long.
            try:
                self._drain_response_queue(in_flight=in_flight)
            except RuntimeError as e2:
                raise e2 from e
            # if from a worker, this is the first exception that got put into the queue
            raise
        finally:
            task_comm_handler.done()

    def _drain_response_queue(self, in_flight: int) -> None:
        """
        Drain response queue and log any errors returned from workers

        Each response must arrive within `cleanup_timeout` seconds of the
        previous one; otherwise the pool is restarted and a `RuntimeError` is
        raised.

        Parameters
        ----------
        in_flight : int
            The number of requests that are still in flight
        """
        # monotonic clock: wall-clock jumps must not shorten or extend the timeout
        t0 = time.monotonic()
        while in_flight > 0:
            try:
                timeout = max(0.010, self._cleanup_timeout - (time.monotonic() - t0))
                with self._pool.response_queue.get(
                    timeout=timeout,
                ) as (result, _):
                    t0 = time.monotonic()
                    in_flight -= 1
                    # we only raise the first exception; log the others here:
                    if result["type"] == "ERROR":
                        logger.error(f"Error response from worker: {result['error']}")
            except WorkerQueueEmpty:
                # kill and restart workers:
                self._restart_pool()
                raise RuntimeError(
                    f'Worker or queue presumably in a bad state, lost {in_flight} in-flight tasks.'
                )

    def run_tasks(
        self,
        tasks: Iterable[TaskProtocol],
        params_handle: Any,
        cancel_id: Any,
        task_comm_handler: "TaskCommHandler",
    ) -> ResultT:
        """
        Run ``tasks`` on the workers, yielding ``(result, task)`` pairs in
        task order.
        """
        with tracer.start_as_current_span("PipelinedExecutor.run_tasks"):
            with task_comm_handler.monitor(self._pool.message_queue):
                yield from _order_results(self._run_tasks_inner(
                    tasks, params_handle, cancel_id, task_comm_handler,
                ))

    def get_available_workers(self) -> WorkerSet:
        """Describe each pool worker as a single-threaded ``localhost`` worker."""
        return WorkerSet([
            Worker(
                name=worker_info.spec["name"],
                host="localhost",
                resources=_resources_for_spec(worker_info.spec),
                nthreads=1,
            )
            for worker_info in self._pool.workers
        ])

    def close(self):
        """
        Shut down all workers and close the pool's queues. Does nothing if
        the executor is already closed.

        Each worker receives a SHUTDOWN message and is joined for up to
        `cleanup_timeout` seconds, and killed if it is still running.
        """
        with tracer.start_as_current_span("PipelinedExecutor.close") as span:
            if self._closed:
                return
            for idx, worker_info in enumerate(self._pool.workers):
                span.add_event("sending SHUTDOWN", {"idx": idx})
                worker_info.queues.request.put({
                    "type": "SHUTDOWN",
                    "span_context": span.get_span_context(),
                })
                span.add_event("SHUTDOWN sent", {"idx": idx})
                while True:
                    try:
                        with worker_info.queues.response.get(block=False) as msg:
                            logger.warning(f"got message on close: {msg[0]}")
                    except WorkerQueueEmpty:
                        break
                worker_info.process.join(timeout=self._cleanup_timeout)
                if worker_info.process.exitcode is None:
                    self._pool.kill_worker(worker_info)
                worker_info.queues.request.close(force=True)
            self._pool.close_resp_queue()
            self._pool.close_mesg_queue()
            self._closed = True

    def __del__(self):
        # `self` may be already partially garbage-collected; only close
        # if "enough" of `self` still exists:
        if hasattr(self, '_closed') and not self._closed:
            self.close()

    def _run_function(self, fn: Callable[..., T], worker_idx, *args, **kwargs) -> T:
        """
        Run ``fn(*args, **kwargs)`` on the worker with index ``worker_idx``
        and return its result, re-raising any error reported by the worker.
        """
        self._validate_worker_state()
        qs = self._pool.get_worker_queues(worker_idx)
        f = functools.partial(fn, *args, **kwargs)
        pickled = cloudpickle.dumps(f)
        qs.request.put({
            "type": "RUN_FUNCTION",
            "fn": pickled,
            "span_context": trace.get_current_span().get_span_context(),
        })
        # FIXME: timeout?
        with qs.response.get() as (response, _):
            if response["type"] == "ERROR":
                _raise_from_msg(response, "failed to run function")
            if not response["type"] == "RUN_FUNCTION_RESULT":
                raise RuntimeError(f"invalid response type: {response['type']}")
            result: T = response["result"]
            return result

    def run_function(self, fn: Callable[..., T], *args, **kwargs) -> T:
        """Run ``fn(*args, **kwargs)`` on the first worker and return the result."""
        # FIXME: this is not concurrency-safe currently! beware!
        with tracer.start_as_current_span("PipelinedExecutor.run_function"):
            return self._run_function(fn, 0, *args, **kwargs)

    def run_each_worker(self, fn, *args, **kwargs):
        """
        Run ``fn(*args, **kwargs)`` once on every worker; return a dict
        mapping worker names to results.
        """
        # FIXME: not as fast as it could be, but also not really perf-sensitive?
        result = {}
        for idx, worker_info in enumerate(self._pool.workers):
            # pass the index positionally: as a keyword after `*args` it would
            # collide with the first positional argument ("multiple values")
            result[worker_info.spec["name"]] = self._run_function(
                fn, idx, *args, **kwargs
            )
        return result

    def run_each_host(self, fn, *args, **kwargs):
        """All workers are local, so ``fn`` runs once, for ``"localhost"``."""
        return {"localhost": self.run_function(fn, *args, **kwargs)}

    @contextlib.contextmanager
    def scatter(self, obj):
        """
        Store ``obj`` in every worker's memory under a fresh key and yield
        that key. The key is deleted on the workers on exit, unless the
        executor was closed in the meantime.
        """
        self._validate_worker_state()
        key = str(uuid.uuid4())
        with tracer.start_as_current_span("scatter_put_to_workers"):
            for worker_info in self._pool.workers:
                worker_info.queues.request.put({
                    "type": "SCATTER",
                    "key": key,
                    "value": obj,
                })
        try:
            yield key
        finally:
            if not self._closed:
                for worker_info in self._pool.workers:
                    worker_info.queues.request.put({
                        "type": "DELETE",
                        "key": key,
                    })

    def scatter_update(self, handle: str, obj):
        """Replace the object stored under ``handle`` on every worker."""
        self._validate_worker_state()
        for worker_info in self._pool.workers:
            worker_info.queues.request.put({
                "type": "SCATTER_UPDATE",
                "key": handle,
                "value": obj,
            })

    def scatter_update_patch(self, handle: str, patch):
        """Apply ``patch`` to the object stored under ``handle`` on every worker."""
        self._validate_worker_state()
        for worker_info in self._pool.workers:
            worker_info.queues.request.put({
                "type": "SCATTER_UPDATE_PATCH",
                "key": handle,
                "patch": patch,
            })

    def map(self, fn, iterable):
        """Run ``fn(item)`` for each item on the first worker; return the results as a list."""
        # FIXME: replace with efficient impl if needed
        with tracer.start_as_current_span("PipelinedExecutor.map"):
            return [
                self._run_function(fn, 0, item)
                for item in iterable
            ]