Source code for libertem.common.progress

import threading
from typing import TYPE_CHECKING, Callable, Any, NamedTuple, Optional
from collections.abc import Iterable
import time

from libertem.common.executor import WorkerQueueEmpty


if TYPE_CHECKING:
    from libertem.common.executor import WorkerQueue, TaskCommHandler
    from libertem.udf.base import UDFTask
    from libertem.io.dataset.base.tiling import DataTile
    from libertem.io.dataset.base.partition import Partition


class CommsDispatcher:
    """
    Monitors a :code:`WorkerQueue` in a background thread
    and launches callbacks in response to messages received.

    Callbacks are registered as a dictionary of subscriptions
        {topic_name: [callback, ...]}
    and are called in the order the messages were received, in the same
    thread that is doing the monitoring. The callbacks should be lightweight
    in order to not build up too many messages in the queue.

    The functionality of this class mirrors Dask's structured logs
    feature, which has a similar message => topic => callback
    model running in the client event loop.

    The class is a context manager: entering starts the monitoring
    thread, exiting stops it and drains whatever is left in the queue.

    Parameters
    ----------
    queue : WorkerQueue
        Queue to monitor. Each item is expected to be a
        :code:`(topic, message)` pair.
    subscriptions : dict[str, list[Callable]]
        Mapping from topic name to the callbacks invoked as
        :code:`callback(topic, message)` for that topic.
    """
    def __init__(self, queue: 'WorkerQueue', subscriptions: dict[str, list[Callable]]):
        self._message_q = queue
        self._subscriptions = subscriptions
        # Set by __enter__, reset to None by __exit__
        self._thread: Optional[threading.Thread] = None

    def __enter__(self, *args, **kwargs):
        """
        Start the background monitoring thread

        Returns
        -------
        CommsDispatcher
            This instance, so that :code:`with CommsDispatcher(...) as d:`
            binds the dispatcher rather than :code:`None`.

        Raises
        ------
        RuntimeError
            If the dispatcher was already entered and not yet exited.
        """
        if self._thread is not None:
            raise RuntimeError('Cannot re-enter CommsDispatcher')
        self._thread = threading.Thread(
            target=self.monitor_queue,
            name="CommsDispatcher",
        )
        self._thread.daemon = True
        self._thread.start()
        return self

    def __exit__(self, *args, **kwargs):
        """
        Stop the monitoring thread and empty the queue

        A :code:`('STOP', {})` message is queued so that the thread leaves
        its loop once it has processed everything queued before it.
        Exceptions raised inside the :code:`with` body are not suppressed.
        """
        if self._thread is None:
            return
        self._message_q.put(('STOP', {}))
        self._thread.join()
        self._thread = None
        # Drain queue just in case; callbacks are not run for these messages
        while True:
            try:
                with self._message_q.get(block=False) as _:
                    ...
            except WorkerQueueEmpty:
                break

    def monitor_queue(self):
        """
        Monitor the queue for messages

        If there are no subscribers this should drain
        messages from the queue as fast as they are received.
        Runs until a message with the topic :code:`'STOP'` is read.
        """
        while True:
            with self._message_q.get(block=True) as ((topic, msg), _):
                if topic == 'STOP':
                    break
                try:
                    for callback in self._subscriptions[topic]:
                        callback(topic, msg)
                except KeyError:
                    # Topic without subscribers: drop the message.
                    # NOTE(review): this also swallows a KeyError raised
                    # inside a callback and skips the remaining callbacks
                    # for that message - confirm this is intended.
                    pass


class ProgressState(NamedTuple):
    """
    Snapshot of the progress of one job

    Instances are built by ProgressManager and handed to a
    ProgressReporter, which decides how to display them.
    """
    #: float: Number of frames processed
    num_frames_complete: float

    #: int: Total number of frames to process
    num_frames_total: int

    #: int: Number of partitions completed
    num_part_complete: int

    #: int: Number of partitions in-progress
    num_part_in_progress: int

    #: int: Total number of partitions
    num_part_total: int

    #: str: A unique string identifier for the job associated
    #: with this progress message
    progress_id: str
class ProgressReporter:
    """
    Interface for progress bar display / updating

    An implementation is handed :class:`ProgressState` instances that
    announce the start, the progression and the end of a job submitted
    to the :code:`UDFRunner`, and displays or logs them as appropriate
    for the use case.

    Several jobs may be submitted to one executor at the same time.
    An implementation should therefore make sure that concurrent
    instances display correctly, or that a single instance copes with
    updates arriving from multiple threads concurrently. Every
    :class:`ProgressState` carries a :code:`progress_id` that is unique
    to its job and can be used to tell the sources apart.

    All methods of this base class, including the constructor, raise
    :code:`NotImplementedError` and have to be overridden.
    """
    def __init__(self):
        raise NotImplementedError()

    def start(self, state: ProgressState):
        """
        Signal the creation of a new job with the expected
        number of partitions and frames, and unique progress_id string.
        """
        raise NotImplementedError()

    def update(self, state: ProgressState):
        """
        Signal an intermediate update to the progress of a job
        """
        raise NotImplementedError()

    def end(self, state: ProgressState):
        """
        Signal the end of a given job

        This method will always be called, and any updates received
        after this message should be ignored.
        """
        raise NotImplementedError()
class TQDMProgressReporter(ProgressReporter):
    """
    Progress bar display via tqdm

    Supports concurrent usage of multiple instances of this class, but does
    not handle multi-threaded use of the same instance to report multiple
    jobs.
    """
    def __init__(self):
        # The tqdm bar, created in start()
        self._bar = None
        # Integers used to check if bar description should be changed
        # integers because faster than strcmp, updated in
        # _should_update_description
        self._desc_key = (-1, -1, -1)

    def start(self, state: ProgressState):
        """
        Create the bar with the description and frame total from state
        """
        # Imported at call time rather than at module import
        from tqdm.auto import tqdm
        self._bar = tqdm(desc=self._get_description(state),
                         total=state.num_frames_total,
                         leave=True)

    def update(self, state: ProgressState):
        """
        Advance the bar to match state, without forcing a redraw
        """
        return self._update(state, clip=True, refresh=False)

    def _update(self, state: ProgressState, *, clip: bool, refresh: bool):
        """
        Bring total, position and description of the bar in line with state

        Parameters
        ----------
        state : ProgressState
            The state to display
        clip : bool
            Forwarded to :meth:`_get_increment`
        refresh : bool
            Whether to force a redraw of the bar after updating it
        """
        if state.num_frames_total != self._bar.total:
            # Should never happen but handle just in case
            self._bar.total = state.num_frames_total
            self._bar.refresh()
        increment = self._get_increment(state, clip=clip)
        if increment > 0:
            self._bar.update(increment)
        if self._should_update_description(state):
            self._bar.set_description(self._get_description(state))
        if refresh:
            self._bar.refresh()

    def end(self, state: ProgressState):
        """
        Apply a final update with a forced redraw, then close the bar
        """
        self._update(state, clip=True, refresh=True)
        self._bar.close()

    def _should_update_description(self, state: ProgressState) -> bool:
        """
        Check the state to see if the elements used by
        self._get_description have changed, and if so
        update our record (self._desc_key) and return True
        """
        new_desc_key = (
            state.num_part_complete,
            state.num_part_in_progress,
            state.num_part_total
        )
        should_update = new_desc_key != self._desc_key
        if should_update:
            self._desc_key = new_desc_key
        return should_update

    @staticmethod
    def _get_description(state: ProgressState) -> str:
        """
        Get the most recent description string for the
        bar, including partition information

        If we know that partitions are in progress
        include this in parentheses after n_completed
        """
        if state.num_part_in_progress:
            return (f'Partitions {state.num_part_complete}({state.num_part_in_progress})'
                    f'/{state.num_part_total}, Frames')
        else:
            return (f'Partitions {state.num_part_complete}'
                    f'/{state.num_part_total}, Frames')

    def _get_increment(self, state: ProgressState, clip: bool = True):
        """
        Get the increment to apply to the progress bar
        based on current state and the state of the bar itself
        (bar.n is the total as-tracked by tqdm)

        Assumes self._bar.total has first been updated
        to state.num_frames_total if this is necessary

        With clip=True the increment never moves the bar past its total.
        The result is never negative.
        """
        increment = int(state.num_frames_complete) - self._bar.n
        if clip:
            max_update = self._bar.total - self._bar.n
        else:
            max_update = increment + 1
        return max(0, min(increment, max_update))


class ProgressManager:
    """
    Handle updating of a progress reporter for a set of :code:`UDFTasks`,
    to be completed in any order. By default constructs a
    :code:`TQDMProgressReporter`, if no instance is passed in.

    The default bar displays as such::

        Partitions n_complete(n_in_progress)/n_total, Frames: [XXXXX..] ...

    When processing tile stacks, stacks are treated as frames as such:
        (pseudo_frames = tile.size // sig_size)

    The bar will render in a Jupyter notebook as a JS widget
    automatically via tqdm.auto

    Parameters
    ----------
    tasks : Iterable[UDFTask]
        The tasks to track, any iterable is accepted. Each task must
        provide :code:`task_frames` and :code:`partition.get_ident()`.
    progress_id : str
        Identifier included in every :class:`ProgressState` sent
        to the reporter.
    reporter : ProgressReporter, optional
        Reporter instance, or a callable without arguments returning one.

    Raises
    ------
    ValueError
        If tasks is empty
    """
    def __init__(
        self,
        tasks: Iterable['UDFTask'],
        progress_id: str,
        reporter: Optional[ProgressReporter] = None,
    ):
        # Materialise once: we need a truthiness check, a full pass and
        # indexing, none of which work reliably on a one-shot iterable
        tasks = list(tasks)
        if not tasks:
            raise ValueError('Cannot display progress for empty tasks')
        self._progress_id = progress_id
        # the number of whole frames we expect each task to process
        self._task_max = {t.partition.get_ident(): t.task_frames
                          for t in tasks}
        # _counters is our record of progress on a task,
        # values are floating whole frames processed
        # as in tile mode we can process part of a frame
        self._counters = {k: 0. for k in self._task_max.keys()}
        self._total_frames = sum(self._task_max.values())
        # For converting tiles to pseudo-frames
        self._sig_size = tasks[0].partition.shape.sig.size
        # Counters for part display
        self._complete = set()
        self._in_progress = set()
        self._num_total = len(self._counters)

        if reporter is None:
            reporter = TQDMProgressReporter()
        elif not isinstance(reporter, ProgressReporter):
            # If not a ProgressReporter instance,
            # instantiate as if it has a bare __init__
            # Useful to be able to inject an instance
            # of ProgressReporter in case we need to setup
            # or access the reporter somehow (e.g. for testing)
            reporter = reporter()
        assert isinstance(reporter, ProgressReporter)
        self.reporter = reporter
        reporter.start(self.state)

    @property
    def state(self) -> ProgressState:
        """
        The current counters packed into a :class:`ProgressState`
        """
        return ProgressState(
            sum(self._counters.values()),
            self._total_frames,
            len(self._complete),
            len(self._in_progress),
            self._num_total,
            self._progress_id,
        )

    def finalize_task(self, task: 'UDFTask'):
        """
        When a task completes and we receive its results on the main node,
        this is called to update the partition progress counters
        and frame counter in case we didn't receive a complete history
        of the partition yet

        Tasks that are not tracked by this manager are ignored.
        """
        topic = 'partition_complete'
        ident = task.partition.get_ident()
        if ident in self._task_max:
            self.handle_end_task(topic, {'ident': ident})

    def close(self):
        """
        Send the final state to the reporter via its end() method
        """
        self.reporter.end(self.state)

    def connect(self, comms: 'TaskCommHandler'):
        """
        Register the callbacks on this class with the TaskCommHandler
        which will be dispatching messages received from the tasks
        """
        comms.subscribe('partition_start', self.handle_start_task)
        comms.subscribe('partition_complete', self.handle_end_task)
        comms.subscribe('tile_complete', self.handle_tile_update)

    def handle_start_task(self, topic: str, message: dict[str, Any]):
        """
        Increment the num_in_progress counter

        # NOTE An extension to this would be to track
        the identities of partitions in progress / completed
        for a richer display / more accurate accounting

        Raises
        ------
        RuntimeError
            If topic is not 'partition_start'
        """
        if topic != 'partition_start':
            raise RuntimeError('Unrecognized topic')
        t_id = message['ident']
        if t_id not in self._complete:
            # if not complete handles case task was completed
            # before we can process its start message,
            # completion is signalled in the main thread
            # while start messages are processed in the background
            self._in_progress.add(t_id)
            self.reporter.update(self.state)

    def handle_end_task(self, topic: str, message: dict[str, Any]):
        """
        Set the counter for the task to the max value
        and update the various counters / description

        Raises
        ------
        RuntimeError
            If topic is not 'partition_complete'
        KeyError
            If the task named in the message is not tracked
        """
        if topic != 'partition_complete':
            raise RuntimeError('Unrecognized topic')
        t_id = message['ident']
        # Always snap to the expected total: this fills in frames we
        # never got a message for, and also removes any fractional
        # overshoot accumulated from tile updates
        self._counters[t_id] = self._task_max[t_id]
        self._in_progress.discard(t_id)
        self._complete.add(t_id)
        self.reporter.update(self.state)

    def handle_tile_update(self, topic: str, message: dict[str, Any]):
        """
        Update the frame progress counter for the task
        and push the increment to the progress reporter

        Tile stacks are converted to pseudo-frames via the sig_size

        Raises
        ------
        RuntimeError
            If topic is not 'tile_complete'
        """
        if topic != 'tile_complete':
            raise RuntimeError('Unrecognized topic')
        t_id = message['ident']
        task_max = self._task_max[t_id]
        if self._counters[t_id] >= task_max:
            return
        elements = message['elements']
        pframes = elements / self._sig_size
        # Never report more frames than the task is expected to process
        self._counters[t_id] = min(self._counters[t_id] + pframes, task_max)
        self.reporter.update(self.state)


class PartitionTrackerNoOp:
    """
    A no-op class matching the PartitionProgressTracker interface

    Used when progress == False to avoid any additional overhead
    """
    def signal_start(self, *args, **kwargs):
        ...

    def signal_tile_complete(self, *args, **kwargs):
        ...

    def signal_complete(self, *args, **kwargs):
        ...
def get_time():
    """
    Return the current time from :code:`time.time()`

    A separate function so that it can be replaced in tests.
    """
    return time.time()


class PartitionProgressTracker(PartitionTrackerNoOp):
    """
    Tracks the tile processing speed of a Partition and
    dispatches messages via the worker_context.signal() method
    under certain conditions

    Parameters
    ----------
    partition : Partition
        The partition to track progress for
    min_message_interval : float, optional
        The minimum time between messages, by default 1 second.
    """
    def __init__(
        self,
        partition: 'Partition',
        min_message_interval: float = 1.,
    ):
        self._ident = partition.get_ident()
        # Without a worker context there is nowhere to send messages to,
        # every signal_* method then returns without doing anything
        self._worker_context = getattr(partition, '_worker_context', None)
        # State used to accumulate progress and rate-limit messages
        self._min_message_interval = min_message_interval
        self._last_message_t = None
        self._elements_complete = 0

    def signal_start(self):
        """
        Signal that the partition has begun processing
        """
        context = self._worker_context
        if context is not None:
            context.signal(self._ident, 'partition_start', {})

    def signal_tile_complete(self, tile: 'DataTile'):
        """
        Register that tile.size more elements have been processed
        and, if should_send_progress allows it, send a signal
        carrying the number of elements processed since the last one
        """
        context = self._worker_context
        if context is None:
            return
        to_report = self.should_send_progress(tile.size)
        if not to_report:
            return
        context.signal(self._ident, 'tile_complete', {'elements': to_report})

    def signal_complete(self):
        """
        Signal that the partition has completed processing

        This is not currently called as partition completion
        is registered on the main node as a fallback
        """
        context = self._worker_context
        if context is not None:
            context.signal(self._ident, 'partition_complete', {})

    def should_send_progress(self, elements: int) -> int:
        """
        Given the number of elements of data that have been processed
        since the last call, decide if a signal should be sent to the
        main node about the partition progress

        Returns
        -------
        int
            The number of elements accumulated since the last message
            if a message should be sent now, otherwise 0
        """
        now = get_time()
        self._elements_complete += elements
        previous = self._last_message_t
        if previous is None:
            # Never send a message for the first tile stack
            # as this might have warmup overheads associated
            # Include the first elements in the history, however,
            # to give a better accounting. The first tile stack
            # is essentially treated as 'free'.
            self._last_message_t = now
            return 0
        if now - previous > self._min_message_interval:
            # Enough time has passed: hand over everything accumulated
            # so far and start a new interval
            pending, self._elements_complete = self._elements_complete, 0
            self._last_message_t = now
            return pending
        return 0