import queue
from typing import (
    Callable, Optional, Any, TYPE_CHECKING,
    TypeVar,
)
from collections.abc import Generator, Iterable
from contextlib import contextmanager
import multiprocessing as mp

import cloudpickle
from opentelemetry import trace
import numpy as np
from typing_extensions import Protocol, Literal

from libertem.common.scheduler import WorkerSet
from libertem.common.threading import set_num_threads, mitigations
# NOTE(review): module path reconstructed (it was lost in the damaged
# source) - verify against the package layout.
from libertem.io.dataset.base.partition import Partition

if TYPE_CHECKING:
    from libertem.udf.base import UDFParams, UDFRunner
    from opentelemetry.trace import SpanContext

# Resource requirements of a task: maps a resource name to the
# amount of that resource the task needs.
ResourceDef = dict[
    Literal[
        'CPU', 'compute', 'ndarray', 'CUDA',
    ],
    float
]

# Module-level OpenTelemetry tracer, named after this module; used to create
# tracing spans (e.g. in SimpleMPWorkerQueue.put).
tracer = trace.get_tracer(__name__)

class ExecutorError(Exception):
    """
    Generic error raised by executors.
    """
    pass

class JobCancelledError(Exception):
    """
    raised by async executors in run_tasks() or run_each_partition() if the task was cancelled
    """
    pass

class Environment:
    """
    Create the environment to run a task, in particular thread count.

    Parameters
    ----------
    threads_per_worker
        Number of threads a UDF is allowed to use, or :code:`None` for
        no limit.
    threaded_executor
        Whether the executor uses threading for parallelism. If set,
        :meth:`enter` additionally applies threading mitigations.
    worker_context
        Optional :code:`WorkerContext` used for streaming communication
        between the main process and the workers.
    """
    def __init__(
        self,
        threads_per_worker: Optional[int],
        threaded_executor: bool,
        worker_context: Optional["WorkerContext"] = None,
    ):
        self._threads_per_worker = threads_per_worker
        self._threaded_executor = threaded_executor
        self._worker_context = worker_context

    @property
    def threads_per_worker(self) -> Optional[int]:
        """
        int or None : number of threads that a UDF is allowed to use in the `process_*` method.
                      For numba, pyfftw, OMP, MKL, OpenBLAS, this limit is set automatically;
                      this property can be used for other cases, like manually creating
                      thread pools.
                      None means no limit is set, and the UDF can use any number of threads
                      it deems necessary (should be limited to system limits, of course).

        See also: :func:`libertem.common.threading.set_num_threads`

        .. versionadded:: 0.7.0
        """
        return self._threads_per_worker

    @property
    def threaded_executor(self) -> bool:
        """
        Whether or not the executor uses threading for parallelism.
        If this flag is set, mitigations for common threading issues are
        automatically applied.
        """
        return self._threaded_executor

    @property
    def worker_context(self) -> Optional["WorkerContext"]:
        """
        A :code:`WorkerContext` instance, if available, as supplied
        by the `JobExecutor`. This is used to manage streaming communication
        between the main process and the workers. If :code:`None` is
        returned, streaming communication with workers is not available.
        """
        return self._worker_context

    @contextmanager
    def enter(self):
        """
        Context manager that applies the thread limit (and, for threaded
        executors, the threading mitigations) and yields this
        :class:`Environment`.

        Note: we are using the @contextmanager decorator here,
        because with separate `__enter__`, `__exit__` methods,
        we can't easily delegate to `set_num_threads`, or other
        contextmanagers that may come later.
        """
        with set_num_threads(self._threads_per_worker):
            if self.threaded_executor:
                with mitigations():
                    yield self
            else:
                # A @contextmanager generator must yield exactly once.
                yield self

class TaskProtocol(Protocol):
    '''
    Interface for tasks
    '''
    def __call__(self, params: "UDFParams", env: Environment):
        """
        Run the task with the given UDF parameters inside the
        execution environment :code:`env`.
        """
        ...

    def get_tracing_span_context(self) -> "SpanContext":
        """
        Return the tracing span context associated with this task.
        """
        ...

    def get_partition(self) -> Partition:
        """
        Return the :code:`Partition` this task operates on.
        """
        ...

    def get_resources(self) -> ResourceDef:
        """
        Return the resources this task requires.
        """
        ...

# Generic type variables for the executor method signatures below
# (e.g. ``run_function`` returns whatever ``fn`` returns).
T = TypeVar("T")
V = TypeVar("V")

class JobExecutor:
    '''
    Interface to execute functions on workers.
    '''
    def run_function(self, fn: Callable[..., T], *args, **kwargs) -> T:
        """
        run a callable :code:`fn` on any worker
        """
        raise NotImplementedError()

    def scatter_update(self, handle, obj):
        '''
        Update :code:`handle` to point to :code:`obj`

        Must have been scattered before using :meth:`scatter`.

        Parameters
        ----------
        handle
            The handle, as returned from :meth:`scatter`.
        obj
            Some kind of Python object. Must be serializable.
        '''
        raise NotImplementedError()

    def scatter_update_patch(self, handle, patch):
        '''
        Update :code:`handle` by remotely calling :code:`obj.patch(patch)`
        on the underlying object. The referenced object must have a
        :code:`patch` method, and must have been scattered before
        using :meth:`scatter`.

        Parameters
        ----------
        handle
            The handle, as returned from :meth:`scatter`.
        patch
            Some kind of Python object. Must be serializable and must
            match the :code:`obj.patch()` method.
        '''
        raise NotImplementedError()

    @contextmanager
    def scatter(self, obj):
        '''
        Scatter :code:`obj` throughout the cluster

        Parameters
        ----------
        obj
            Some kind of Python object. Must be serializable.

        Returns
        -------
        handle
            Handle for the scattered :code:`obj`
        '''
        raise NotImplementedError()

    def run_tasks(
        self,
        tasks: Iterable[TaskProtocol],
        params_handle: Any,
        cancel_id: Any,
        task_comm_handler: "TaskCommHandler",
    ):
        """
        Run the tasks with the given parameters.

        Raises
        ------
        JobCancelledError
            Either the job was cancelled using
            :meth:`AsyncJobExecutor.cancel`, or the underlying data
            source was interrupted.

        Parameters
        ----------
        tasks
            The tasks to be run
        params_handle
            A handle for the task parameters, as returned from
            :meth:`scatter`
        cancel_id
            An identifier which can be used for cancelling all tasks
            together. The same identifier should be passed to
            :meth:`AsyncJobExecutor.cancel`
        task_comm_handler : TaskCommHandler
            Handler for the streaming communication between the main
            process and the workers for each task.
        """
        raise NotImplementedError()

    def run_each_partition(self, partitions, fn, all_nodes=False):
        """
        Run `fn` for all partitions. Yields results in order of completion.

        Parameters
        ----------
        partitions : List[Partition]
            List of relevant partitions.
        fn : callable
            Function to call, will get the partition as first and only
            argument.
        all_nodes : bool
            If all_nodes is True, run the function on all nodes that have
            this partition, otherwise run on any node that has the
            partition. If a partition has no location, the function will
            not be run for that partition if `all_nodes` is True,
            otherwise it will be run on any node.
        """
        raise NotImplementedError()

    def map(self, fn: Callable[[V], T], iterable: Iterable[V]) -> Iterable[T]:
        """
        Run a callable :code:`fn` for each element in :code:`iterable`,
        on arbitrary worker nodes.

        Parameters
        ----------
        fn : callable
            Function to call. Should accept exactly one parameter.
        iterable : Iterable
            Which elements to call the function on.
        """
        raise NotImplementedError()

    def run_each_host(self, fn, *args, **kwargs):
        """
        Run a callable :code:`fn` once on each host, gathering all results
        into a dict host -> result

        Parameters
        ----------
        fn : callable
            Function to call
        *args
            Arguments for fn
        **kwargs
            Keyword arguments for fn
        """
        raise NotImplementedError()

    def run_each_worker(self, fn, *args, **kwargs):
        """
        Run :code:`fn` on each worker process, and pass :code:`*args`,
        :code:`**kwargs` to it.

        Useful, for example, if you need to prepare the environment of
        each Python interpreter, warm up per-process caches etc.

        Parameters
        ----------
        fn : callable
            Function to call
        \\*args
            Arguments for fn
        \\*\\*kwargs
            Keyword arguments for fn

        Returns
        -------
        dict
            Return values keyed by worker name (executor-specific)
        """
        raise NotImplementedError()

    def close(self):
        """
        cleanup resources used by this executor, if any
        """

    def get_available_workers(self) -> WorkerSet:
        """
        Returns a WorkerSet that contains the available workers

        Each worker should correspond to a "worker process", so if the
        executor is using multiple processes or threads, each
        process/thread should be included in this list.
        """
        raise NotImplementedError()

    def get_resource_details(self) -> list[dict[str, Any]]:
        """
        Returns a list of dicts with cluster details

        key of the dictionary:
            host: ip address or hostname where the worker is running
        """
        raise NotImplementedError()

    def ensure_sync(self):
        """
        Returns a synchronous executor.

        In case of a :class:`~libertem.common.executor.JobExecutor` we
        just return :code:`self`; in case of
        :class:`~libertem.common.executor.AsyncJobExecutor` below more
        work is needed!
        """
        return self

    def ensure_async(self, pool=None):
        """
        Returns an asynchronous executor.

        Concrete executors are expected to implement this, typically by
        wrapping into :class:`~libertem.executor.base.AsyncAdapter`; this
        base implementation raises :code:`NotImplementedError`.
        """
        raise NotImplementedError()

    def modify_buffer_type(self, buf):
        """
        Allow the executor to modify result buffers if necessary

        Currently only called for buffers on the main node
        """
        return buf

    def get_udf_runner(self) -> type['UDFRunner']:
        """
        Return the :code:`UDFRunner` class to use with this executor.
        """
        raise NotImplementedError()
class AsyncJobExecutor:
    '''
    Async version of :class:`JobExecutor`.
    '''
    async def run_tasks(self, tasks, params_handle, cancel_id):
        """
        Run a number of Tasks, yielding (result, task) tuples
        """
        raise NotImplementedError()

    async def run_function(self, fn: Callable[..., T], *args, **kwargs) -> T:
        """
        Run a callable :code:`fn` on any worker
        """
        raise NotImplementedError()

    async def run_each_partition(self, partitions, fn, all_nodes=False):
        """
        Async version of :meth:`JobExecutor.run_each_partition`.
        """
        raise NotImplementedError()

    async def map(self, fn, iterable):
        """
        Run a callable :code:`fn` for each item in :code:`iterable`, on
        arbitrary worker nodes

        Parameters
        ----------
        fn : callable
            Function to call. Should accept exactly one parameter.
        iterable : Iterable
            Which elements to call the function on.
        """
        raise NotImplementedError()

    async def run_each_host(self, fn, *args, **kwargs):
        """
        Async version of :meth:`JobExecutor.run_each_host`.
        """
        raise NotImplementedError()

    async def run_each_worker(self, fn, *args, **kwargs):
        """
        Run :code:`fn` on each worker process, and pass :code:`*args`,
        :code:`**kwargs` to it.

        Useful, for example, if you need to prepare the environment of
        each Python interpreter, warm up per-process caches etc.

        Parameters
        ----------
        fn : callable
            Function to call
        \\*args
            Arguments for fn
        \\*\\*kwargs
            Keyword arguments for fn
        """
        raise NotImplementedError()

    async def close(self):
        """
        Cleanup resources used by this executor, if any.
        """

    async def cancel(self, cancel_id):
        """
        cancel execution identified by `cancel_id`
        """
        pass

    async def get_available_workers(self):
        """
        Async version of :meth:`JobExecutor.get_available_workers`.
        """
        raise NotImplementedError()

    async def get_resource_details(self):
        """
        Async version of :meth:`JobExecutor.get_resource_details`.
        """
        raise NotImplementedError()

    def ensure_sync(self):
        """
        Returns a synchronous executor; must be implemented by subclasses.
        """
        raise NotImplementedError()

    def ensure_async(self, pool=None):
        """
        Returns an asynchronous executor; by default just return `self`.
        """
        return self

    def get_udf_runner(self) -> type['UDFRunner']:
        """
        Return the :code:`UDFRunner` class to use with this executor.
        """
        raise NotImplementedError()
class WorkerQueueEmpty(Exception):
    """
    A non-blocking get was called on an empty queue, or a blocking `get`
    with a non-zero timeout timed out.
    """
    pass


class WorkerQueue:
    '''
    Interface for queues to send input data to workers.
    '''
    @contextmanager
    def get(
        self, block: bool = True, timeout: Optional[float] = None,
    ) -> Generator[tuple[Any, memoryview], None, None]:
        """
        Get a :code:`(header, payload)` item from the queue and yield it
        to the body of the :code:`with` statement.

        Parameters
        ----------
        block
            Whether to wait for an item to become available
        timeout
            How long to wait when blocking; interpretation depends on
            the implementation

        Raises
        ------
        WorkerQueueEmpty
            If no item was available in time.
        """
        raise NotImplementedError()

    def put(self, header: Any, payload: Optional[memoryview] = None):
        '''
        Put header and payload into the queue.
        '''
        raise NotImplementedError()

    @contextmanager
    def put_nocopy(self, header: Any, size: int) -> Generator[memoryview, None, None]:
        """
        Put data into the queue, without an additional copy.

        This will yield a writable :class:`python:memoryview`, instead of
        requiring that the data to be sent is already available as a
        :class:`python:memoryview` and copying into its final destination,
        as is the case in :meth:`put`. This can be useful, for example, if
        you need to perform any kind of operation that writes into a
        buffer. For example:

        >>> q = SomeWorkerQueueImpl()  # doctest: +SKIP
        >>> with q.put_nocopy(header, 1024) as send_buf:  # doctest: +SKIP
        ...     # NOTE: in reality, need to handle n_bytes < 1024
        ...     n_bytes = some_socket.recv_into(send_buf)

        If the data you receive is already in a shared memory segment, you
        can send the handle as part of the :code:`header` and use the
        normal :meth:`put` method instead.

        Note: for releasing memory, it might be necessary to implement
        another queue that synchronizes the release process.
        """
        raise NotImplementedError()

    def close(self, drain: bool = True, force: bool = False):
        """
        Parameters
        ----------
        drain
            If needed by the underlying queue, remove any items from the
            queue before closing
        force
            Don't wait for data to be flushed, forcefully close the queue
        """
        raise NotImplementedError()

    def size(self) -> int:
        """
        Approximate number of items currently in the queue

        Can raise `NotImplementedError` depending on operating system or
        underlying implementation.
        """
        raise NotImplementedError()


class SimpleWorkerQueue(WorkerQueue):
    """
    A :class:`WorkerQueue` that uses a threading :class:`python:queue.Queue`
    under the hood.
    """
    def __init__(self) -> None:
        self.q: queue.Queue = queue.Queue()

    def put(self, header, payload: Optional[memoryview] = None):
        self.q.put((header, payload))

    @contextmanager
    def put_nocopy(self, header: Any, size: int) -> Generator[memoryview, None, None]:
        payload = np.zeros(size, dtype=np.uint8)
        yield memoryview(payload)
        # Only reached if the with-body completed without raising.
        self.q.put((header, payload))

    @contextmanager
    def get(self, block: bool = True, timeout: Optional[float] = None):
        # Only the dequeue operation is guarded: a `queue.Empty` raised by
        # the caller's with-body must propagate unchanged.
        try:
            res = self.q.get(block=block, timeout=timeout)
        except queue.Empty as e:
            raise WorkerQueueEmpty() from e
        yield res

    def close(self, drain: bool = True, force: bool = False):
        pass

    def size(self) -> int:
        return self.q.qsize()


class SimpleMPWorkerQueue(WorkerQueue):
    """
    A :class:`WorkerQueue` that uses a
    :class:`python:multiprocessing.Queue` under the hood.
    """
    def __init__(self) -> None:
        self._mp_ctx = mp.get_context("spawn")
        self.q: mp.Queue = self._mp_ctx.Queue()
        self._closed = False

    def put(self, header, payload: Optional[memoryview] = None):
        with tracer.start_as_current_span("SimpleMPWorkerQueue.put") as span:
            header_serialized = cloudpickle.dumps(header)
            payload_serialized = cloudpickle.dumps(payload)
            span.set_attributes({
                'libertem.pickle.header_size': len(header_serialized),
                'libertem.pickle.payload_size': len(payload_serialized),
            })
            self.q.put((header_serialized, payload_serialized))

    @contextmanager
    def put_nocopy(self, header: Any, size: int) -> Generator[memoryview, None, None]:
        payload = np.zeros(size, dtype=np.uint8)
        yield memoryview(payload)
        header_serialized = cloudpickle.dumps(header)
        payload_serialized = cloudpickle.dumps(payload)
        self.q.put((header_serialized, payload_serialized))

    @contextmanager
    def get(self, block: bool = True, timeout: Optional[float] = None):
        # Only the dequeue operation is guarded: a `queue.Empty` raised by
        # the caller's with-body must propagate unchanged.
        try:
            header_serialized, payload_serialized = self.q.get(
                block=block, timeout=timeout,
            )
        except queue.Empty as e:
            raise WorkerQueueEmpty() from e
        yield (
            cloudpickle.loads(header_serialized),
            cloudpickle.loads(payload_serialized),
        )

    def close(self, drain: bool = True, force: bool = False):
        if self._closed:
            return
        if drain:
            # Remove any items that are still queued up.
            while True:
                try:
                    self.q.get_nowait()
                except queue.Empty:
                    break
        self.q.close()
        if force:
            self.q.cancel_join_thread()
        else:
            self.q.join_thread()
        self._closed = True

    def size(self) -> int:
        return self.q.qsize()


class WorkerContext:
    """
    A :class:`WorkerContext` is used to manage streaming communication
    between the main process and the workers.
    """
    def get_worker_queue(self) -> WorkerQueue:
        raise NotImplementedError()

    def signal(self, ident: str, topic: str, msg_dict: dict[str, Any]):
        raise NotImplementedError()


class TaskCommHandler:
    """
    This is the interface that is implemented by the acquisition object to
    allow streaming communication with workers.
    """
    def handle_task(self, task: "TaskProtocol", queue: WorkerQueue):
        """
        Handle the streaming communication for the given :code:`task`
        using the provided :code:`queue`.

        This function should block until the communication for the given
        task has finished. It may be run in a background thread on the main
        node, or synchronously, depending on the :class:`JobExecutor`.

        May raise :class:`JobCancelledError` to signal that the acquisition
        has been cancelled for some reason.

        Parameters
        ----------
        task : TaskProtocol
        queue : WorkerQueue
            A queue used to communicate with the worker process using an
            acquisition-specific protocol. The `Partition` in the worker
            has access to this queue, too, and can communicate using a
            data-source specific protocol. The `TaskCommHandler` and the
            `Partition` are tightly coupled by this protocol, and the queue
            needs to be "clean" after the data for a given task has been
            exchanged.
        """
        ...

    def start(self):
        """
        A lifecycle method that is called before any task has been run.
        """
        ...

    def done(self):
        """
        A lifecycle method that is called after all tasks have been run.
        """
        ...

    @property
    def subscriptions(self) -> dict[str, list[Callable]]:
        """
        Mapping of topic -> list of callbacks registered via :meth:`subscribe`.
        """
        # Instantiate on first get to avoid creating __init__
        try:
            return self._subscriptions
        except AttributeError:
            self._subscriptions = {}
            return self._subscriptions

    def subscribe(self, topic: str, callback: Callable[[str, dict], None]):
        """
        Register a callback to run in response to messages matching
        the topic string identifier

        The callback should accept the arguments (topic, message_dict)
        message_dict will contain an 'ident' key with the identity of
        the message sender
        """
        self.subscriptions.setdefault(topic, []).append(callback)

    @contextmanager
    def monitor(self, queue: WorkerQueue):
        """
        Monitor queue in a background thread and run the
        callbacks in subscriptions in response to messages
        """
        # Avoid circular import
        from libertem.common.progress import CommsDispatcher
        with CommsDispatcher(queue, self.subscriptions):
            yield


class NoopCommHandler(TaskCommHandler):
    """
    A :class:`TaskCommHandler` that doesn't perform any action, and doesn't
    stream any data.
    """
    def handle_task(self, task: "TaskProtocol", queue: WorkerQueue):
        pass

    def start(self):
        pass

    def done(self):
        pass