Source code for libertem_live.detectors.asi_tpx3.acquisition

from contextlib import contextmanager
import logging
import time
from typing import Optional

import scipy
import scipy.sparse
import numpy as np
from opentelemetry import trace
from sparseconverter import SCIPY_CSR

from libertem.common import Shape, Slice
from libertem.common.math import prod
from libertem.common.executor import (
    WorkerContext, TaskProtocol, WorkerQueue, TaskCommHandler,
    JobCancelledError,
)
from libertem.io.dataset.base import (
    DataTile, DataSetMeta, BasePartition, Partition, DataSet, TilingScheme,
)
from libertem.corrections.corrset import CorrectionSet

from libertem_live.detectors.base.acquisition import (
    AcquisitionMixin,
)
from libertem_live.detectors.base.controller import AcquisitionController
from libertem_live.hooks import ReadyForDataEnv, Hooks
from .connection import AsiTpx3DetectorConnection, AsiTpx3PendingAcquisition

from libertem_asi_tpx3 import CamClient, ChunkStackHandle

tracer = trace.get_tracer(__name__)
logger = logging.getLogger(__name__)


def get_chunk_stacks(worker_context: WorkerContext):
    """
    Consume all CHUNK_STACK messages from the request queue until we get an
    END_PARTITION message (which we also consume)
    """
    request_queue = worker_context.get_worker_queue()

    with request_queue.get() as msg:
        header, _ = msg
        header_type = header["type"]
        assert header_type == "BEGIN_TASK"
        socket_path = header["socket"]
        cam_client = CamClient(socket_path)
    try:
        while True:
            with request_queue.get() as msg:
                header, payload = msg
                header_type = header["type"]
                if header_type == "CHUNK_STACK":
                    chunk_stack = ChunkStackHandle.deserialize(payload)
                    chunks = cam_client.get_chunks(handle=chunk_stack)
                    try:
                        yield chunks
                    finally:
                        cam_client.done(chunk_stack)
                elif header_type == "END_PARTITION":
                    return
                else:
                    raise RuntimeError(
                        f"invalid header type {header['type']}; "
                        f"CHUNK_STACK or END_PARTITION expected"
                    )
    finally:
        cam_client.close()


class AsiCommHandler(TaskCommHandler):
    def __init__(
        self,
        conn: "AsiTpx3DetectorConnection",
    ):
        self.conn = conn.get_conn_impl()

    def handle_task(self, task: TaskProtocol, queue: WorkerQueue):
        with tracer.start_as_current_span("AsiCommHandler.handle_task") as span:
            span.set_attribute(
                "libertem_live.detectors.asi:socket",
                self.conn.get_socket_path(),
            )
            queue.put({
                "type": "BEGIN_TASK",
                "socket": self.conn.get_socket_path(),
            })
            put_time = 0.0
            recv_time = 0.0
            # send the data for this task to the given worker
            partition = task.get_partition()
            slice_ = partition.slice
            start_idx = slice_.origin[0]
            end_idx = slice_.origin[0] + slice_.shape[0]
            span.set_attributes({
                "libertem.partition.start_idx": start_idx,
                "libertem.partition.end_idx": end_idx,
            })
            stack_size = 1000000000000
            current_idx = start_idx
            while current_idx < end_idx:
                current_stack_size = min(stack_size, end_idx - current_idx)

                t0 = time.perf_counter()
                chunk_stack = self.conn.get_next_stack(
                    max_size=current_stack_size
                )
                if chunk_stack is None:
                    if current_idx != end_idx:
                        queue.put({
                            "type": "END_PARTITION",
                        })
                        raise JobCancelledError("premature end of frame iterator")
                    break
                assert len(chunk_stack) <= current_stack_size, \
                    f"{len(chunk_stack)} <= {current_stack_size}"
                t1 = time.perf_counter()
                recv_time += t1 - t0

                t0 = time.perf_counter()
                serialized = chunk_stack.serialize()
                queue.put({
                    "type": "CHUNK_STACK",
                }, payload=serialized)
                t1 = time.perf_counter()
                put_time += t1 - t0

                current_idx += len(chunk_stack)
            span.set_attributes({
                "total_put_time": put_time,
                "total_recv_time": recv_time,
            })

    def start(self):
        pass  # nothing to do here really

    def done(self):
        pass  # ... likewise here


[docs] class AsiTpx3Acquisition(AcquisitionMixin, DataSet): ''' Acquisition from a ASI TPX3 detector Use :meth:`libertem_live.api.LiveContext.make_acquisition` to instantiate this class! Examples -------- >>> with ctx.make_connection('asi_tpx3').open( ... data_port=TPX3_PORT, ... ) as conn: ... pending_aq = conn.wait_for_acquisition(1.0) ... aq = ctx.make_acquisition( ... conn=conn, ... pending_aq=pending_aq, ... ) Parameters ---------- conn An existing `AsiDetectorConnection` instance nav_shape The number of scan positions as a 2-tuple :code:`(height, width)` frames_per_partition A tunable for configuring the feedback rate - more frames per partition means slower feedback, but less computational overhead. Might need to be tuned to adapt to the dwell time. pending_aq A pending acquisition in passive mode, obtained from :meth:`AsiTpx3DetectorConnection.wait_for_acquisition`. hooks Acquisition hooks to react to certain events ''' def __init__( self, conn: AsiTpx3DetectorConnection, hooks: Optional[Hooks] = None, # in passive mode, we get this: pending_aq: Optional[AsiTpx3PendingAcquisition] = None, # this is for future compatibility with an active mode: controller: Optional[AcquisitionController] = None, nav_shape: Optional[tuple[int, ...]] = None, frames_per_partition: Optional[int] = None, ): assert pending_aq is not None, "only supporting passive mode for now" assert controller is None, "only supporting passive mode for now" if frames_per_partition is None: frames_per_partition = 4096 super().__init__( conn=conn, frames_per_partition=frames_per_partition, nav_shape=nav_shape, controller=None, pending_aq=pending_aq, hooks=hooks, ) self._sig_shape: tuple[int, ...] = () self._acquisition_header = pending_aq.header def initialize(self, executor) -> "DataSet": '' self._sig_shape = self._acquisition_header.get_sig_shape() self._nav_shape = self._acquisition_header.get_nav_shape() self._meta = DataSetMeta( shape=Shape(self._nav_shape + self._sig_shape, sig_dims=2), array_backends=[SCIPY_CSR], raw_dtype=np.float32, # this is a lie, the dtype can vary by chunk! dtype=np.float32, ) return self @property def dtype(self): return self._meta.dtype @property def raw_dtype(self): return self._meta.raw_dtype @property def shape(self): return self._meta.shape @property def meta(self): return self._meta def get_correction_data(self): '' return CorrectionSet() @contextmanager def _acquire_active(self): raise NotImplementedError("") def start_acquisition(self): with tracer.start_as_current_span('start_acquisition'): if self._pending_aq is None: # active mode: with tracer.start_as_current_span("AsiAcquisition.on_ready_for_data"): self._hooks.on_ready_for_data(ReadyForDataEnv(aq=self)) else: pass # passive mode def end_acquisition(self): return def check_valid(self): "" pass def need_decode(self, read_dtype, roi, corrections): "" return True # FIXME: we just do this to get a large tile size def adjust_tileshape(self, tileshape, roi): "" depth = 512 # FIXME: hardcoded, hmm... return (depth, *self.meta.shape.sig) # return Shape((self._end_idx - self._start_idx, 256, 256), sig_dims=2) def get_max_io_size(self): "" # return 12*256*256*8 # FIXME magic numbers? return 1200000 * prod(self.meta.shape.sig) * 8 def get_base_shape(self, roi): "" return (1, *self.meta.shape.sig) def get_partitions(self): "" # FIXME: only works for inline executor or similar, as we are using a zeromq socket # which is not safe to be passed to other threads num_frames = np.prod(self._nav_shape, dtype=np.uint64) num_partitions = int(num_frames // self._frames_per_partition) slices = BasePartition.make_slices(self.shape, num_partitions) for part_slice, start, stop in slices: yield AsiLivePartition( start_idx=start, end_idx=stop, meta=self._meta, partition_slice=part_slice, ) def get_task_comm_handler(self) -> "AsiCommHandler": return AsiCommHandler( conn=self._conn, )
class AsiLivePartition(Partition): def __init__( self, start_idx, end_idx, partition_slice, meta, ): super().__init__(meta=meta, partition_slice=partition_slice, io_backend=None, decoder=None) self._start_idx = start_idx self._end_idx = end_idx def shape_for_roi(self, roi): return self.slice.adjust_for_roi(roi).shape @property def shape(self): return self.slice.shape @property def dtype(self): return self.meta.raw_dtype def set_corrections(self, corrections): self._corrections = corrections def set_worker_context(self, worker_context: WorkerContext): self._worker_context = worker_context def _preprocess(self, tile_data, tile_slice): if self._corrections is None: return self._corrections.apply(tile_data, tile_slice) def _get_tiles_straight( self, tiling_scheme: TilingScheme, dest_dtype="float32", roi=None, array_backend=None ): assert len(tiling_scheme) == 1, "only supports full frames tiling scheme for now" logger.debug("reading up to frame idx %d for this partition", self._end_idx) to_read = self._end_idx - self._start_idx # FIXME: use depth? # depth = tiling_scheme.depth tile_start = self._start_idx stacks = get_chunk_stacks(self._worker_context) sig_shape = tuple(self.slice.shape.sig) sig_dims = len(sig_shape) buf = None while to_read > 0: try: stack = next(stacks) except StopIteration: if to_read != 0: raise JobCancelledError( f"we were still expecting to read {to_read} frames more!" ) for ( chunk_layout, chunk_indptr, chunk_indices, chunk_values ) in stack: chunk_values_arr = np.frombuffer(chunk_values, dtype=chunk_layout.get_value_dtype()) chunk_indices_arr = np.frombuffer( chunk_indices, dtype=chunk_layout.get_indices_dtype() ) chunk_indptr_arr = np.frombuffer( chunk_indptr, dtype=chunk_layout.get_indptr_dtype() ) if chunk_layout.get_value_dtype() != np.dtype(dest_dtype): if buf is None or buf.shape[0] < chunk_values_arr.shape[0]: buf = np.zeros_like(chunk_values_arr, dtype=dest_dtype) buf[:chunk_values_arr.shape[0]] = chunk_values_arr chunk_values_arr = buf[:chunk_values_arr.shape[0]] arr = scipy.sparse.csr_matrix( (chunk_values_arr, chunk_indices_arr, chunk_indptr_arr), shape=(chunk_layout.get_nframes(), prod(self.slice.shape.sig)), ) frames_in_tile = chunk_layout.get_nframes() to_read -= frames_in_tile tile_slice = Slice( origin=(tile_start, ) + (0, ) * sig_dims, shape=Shape((arr.shape[0], ) + sig_shape, sig_dims=sig_dims), ) assert tile_slice.origin[0] < self._end_idx yield DataTile( data=arr, tile_slice=tile_slice, scheme_idx=0, ) tile_start += frames_in_tile logger.debug("LivePartition.get_tiles: end of method") def get_tiles(self, tiling_scheme, dest_dtype="float32", roi=None, array_backend=None): yield from self._get_tiles_straight(tiling_scheme, dest_dtype, roi, array_backend) def __repr__(self): return f"<AsiLivePartition {self._start_idx}:{self._end_idx}>"