# Source code for beaker._queue

from __future__ import annotations

import logging
import threading
import time
from contextlib import AbstractContextManager as ContextManager
from contextlib import contextmanager
from dataclasses import dataclass
from queue import Empty as QueueEmpty
from queue import SimpleQueue
from typing import Generator, Iterable, Literal, overload

import grpc
from google.protobuf.duration_pb2 import Duration
from google.protobuf.empty_pb2 import Empty
from google.protobuf.struct_pb2 import Struct
from google.protobuf.timestamp_pb2 import Timestamp

from . import beaker_pb2 as pb2
from ._service_client import (
    RpcBidirectionalStreamingMethod,
    RpcMethod,
    RpcStreamingMethod,
    ServiceClient,
)
from .exceptions import *
from .types import *
from .utils import pb2_to_dict


class QueueClient(ServiceClient):
    """
    Methods for interacting with Beaker `Queues <https://beaker-docs.apps.allenai.org/concept/queues.html>`_.
    Accessed via the :data:`Beaker.queue <beaker.Beaker.queue>` property.

    .. warning::
        Do not instantiate this class directly! The :class:`~beaker.Beaker` client will create
        one automatically which you can access through the corresponding property.
    """

    def get(self, queue: str) -> pb2.Queue:
        """
        Get a queue object.

        :examples:

        >>> with Beaker.from_env() as beaker:
        ...     queue = beaker.queue.get(queue_id)

        :param queue: The queue to look up. It is passed through ``resolve_queue_id()``
            to obtain the ID sent with the request.

        :returns: A :class:`~beaker.types.BeakerQueue` protobuf object.

        :raises ~beaker.exceptions.BeakerQueueNotFound: If the queue doesn't exist.
        """
        return self.rpc_request(
            RpcMethod[pb2.GetQueueResponse](self.service.GetQueue),
            pb2.GetQueueRequest(queue_id=self.resolve_queue_id(queue)),
            exceptions_for_status={grpc.StatusCode.NOT_FOUND: BeakerQueueNotFound(queue)},
        ).queue

    def create(
        self,
        name: str | None = None,
        workspace: pb2.Workspace | None = None,
        input_schema: dict | None = None,
        output_schema: dict | None = None,
        batch_size: int | None = 1,
        max_claimed_entries: int | None = None,
        wait_timeout_ms: int | None = 0,
    ) -> pb2.Queue:
        """
        Create a new queue.

        :param name: The name sent with the create request.
        :param workspace: The workspace to create the queue in. It is passed through
            ``resolve_workspace_id()``.
        :param input_schema: Copied into the request's ``input_schema`` struct.
            ``None`` (the default) and ``{}`` both send an empty struct.
        :param output_schema: Copied into the request's ``output_schema`` struct.
            ``None`` (the default) and ``{}`` both send an empty struct.
        :param batch_size: Sent as the request's ``batch_size``.
        :param max_claimed_entries: Sent as the request's ``max_claimed_entries``.
            Falls back to ``batch_size`` when ``None``.
        :param wait_timeout_ms: Converted to a protobuf ``Duration`` (from milliseconds).
            When ``None`` no wait timeout is set on the request.

        :returns: A new :class:`~beaker.types.BeakerQueue` object.
        """
        # NOTE: the schema defaults are `None` rather than `{}` so that no mutable object
        # is shared between calls. Both values produce an empty `Struct` below.
        wait_timeout = None
        if wait_timeout_ms is not None:
            wait_timeout = Duration()
            wait_timeout.FromMilliseconds(wait_timeout_ms)

        input_schema_struct = Struct()
        if input_schema is not None:
            input_schema_struct.update(input_schema)

        output_schema_struct = Struct()
        if output_schema is not None:
            output_schema_struct.update(output_schema)

        return self.rpc_request(
            RpcMethod[pb2.CreateQueueResponse](self.service.CreateQueue),
            pb2.CreateQueueRequest(
                workspace_id=self.resolve_workspace_id(workspace),
                name=name,
                input_schema=input_schema_struct,
                output_schema=output_schema_struct,
                batch_size=batch_size,
                max_claimed_entries=max_claimed_entries
                if max_claimed_entries is not None
                else batch_size,
                wait_timeout=wait_timeout,
            ),
        ).queue

    def delete(
        self,
        *queues: pb2.Queue,
    ):
        """
        Delete queues.

        :param queues: The queues to delete. Each one is passed through
            ``resolve_queue_id()`` and all IDs are sent in a single request.
        """
        self.rpc_request(
            RpcMethod[pb2.DeleteQueuesResponse](self.service.DeleteQueues),
            pb2.DeleteQueuesRequest(queue_ids=[self.resolve_queue_id(q) for q in queues]),
        )

    def create_worker(self, queue: pb2.Queue) -> pb2.QueueWorker:
        """
        Create a new queue worker.

        :param queue: The queue to create the worker for (its ``id`` is sent with the request).

        :returns: A new :class:`~beaker.types.BeakerQueueWorker` object.

        :raises ~beaker.exceptions.BeakerQueueNotFound: If the server responds with ``NOT_FOUND``.
        """
        return self.rpc_request(
            RpcMethod[pb2.CreateQueueWorkerResponse](self.service.CreateQueueWorker),
            pb2.CreateQueueWorkerRequest(queue_id=queue.id),
            exceptions_for_status={grpc.StatusCode.NOT_FOUND: BeakerQueueNotFound(queue.id)},
        ).queue_worker

    def list_workers(self, queue: pb2.Queue, limit: int | None = None) -> Iterable[pb2.QueueWorker]:
        """
        List queue workers.

        :param queue: The queue whose workers should be listed.
        :param limit: Stop after yielding this many workers. Must be positive if given.

        :returns: An iterator over :class:`~beaker.types.BeakerQueueWorker` objects.

        :raises ValueError: If ``limit`` is not positive. Since this is a generator the error
            is raised when iteration starts.
        :raises ~beaker.exceptions.BeakerQueueNotFound: If the server responds with ``NOT_FOUND``.
        """
        if limit is not None and limit <= 0:
            raise ValueError("'limit' must be a positive integer")

        count = 0
        for response in self.rpc_paged_request(
            RpcMethod[pb2.ListQueueWorkersResponse](self.service.ListQueueWorkers),
            pb2.ListQueueWorkersRequest(
                options=pb2.ListQueueWorkersRequest.Opts(
                    queue_id=queue.id,
                    # Never request a bigger page than the caller asked for in total.
                    page_size=self.MAX_PAGE_SIZE
                    if limit is None
                    else min(self.MAX_PAGE_SIZE, limit),
                )
            ),
            exceptions_for_status={grpc.StatusCode.NOT_FOUND: BeakerQueueNotFound(queue.id)},
        ):
            for worker in response.queue_workers:
                count += 1
                yield worker
                if limit is not None and count >= limit:
                    return

    def get_entry(self, entry_id: str) -> pb2.QueueEntry:
        """
        Get a queue entry object.

        :param entry_id: The ID of the entry.

        :returns: A :class:`~beaker.types.BeakerQueueEntry` object.

        :raises ~beaker.exceptions.BeakerQueueEntryNotFound: If the entry doesn't exist or has expired.
        """
        return self.rpc_request(
            RpcMethod[pb2.GetQueueEntryResponse](self.service.GetQueueEntry),
            pb2.GetQueueEntryRequest(queue_entry_id=entry_id),
            exceptions_for_status={
                grpc.StatusCode.NOT_FOUND: BeakerQueueEntryNotFound(
                    f"queue entry '{entry_id}' not found or expired"
                )
            },
        ).queue_entry

    def create_entry(
        self,
        queue: pb2.Queue,
        *,
        input: dict | None = None,
        expires_in_sec: int = 3600 * 24,
        block: bool = True,
    ) -> Iterable[pb2.CreateQueueEntryResponse]:
        """
        Submit an entry to a queue and stream response events as they happen.

        .. important::
            This method will block until the entry has been finalized. If you expect the entry
            will take a while to process, you should use :meth:`create_entry_async()` instead
            and periodically poll the entry status with :meth:`get_entry()`.

        :param queue: The queue to submit the entry to.
        :param input: The input data. ``None`` (the default) and ``{}`` both send an empty struct.
        :param expires_in_sec: Time until the entry expires (in seconds). Defaults to 24 hours.
        :param block: If ``True`` (the default), this method will block until new responses become
            available and continue streaming until the entry is finalized.
            If ``False`` this method will only yield the ``pending_entry`` response and then return.

        :raises ~beaker.exceptions.BeakerQueueNotFound: If the server responds with ``NOT_FOUND``.
        """
        # The expiry is "now" plus `expires_in_sec`, at whole-second resolution.
        expiry = Timestamp()
        expiry.GetCurrentTime()
        expiry.FromSeconds(expiry.seconds + expires_in_sec)

        # NOTE: the `input` default is `None` rather than `{}` so that no mutable object
        # is shared between calls. Both values produce an empty `Struct` here.
        input_struct = Struct()
        if input is not None:
            input_struct.update(input)

        request = pb2.CreateQueueEntryRequest(
            queue_id=self.resolve_queue_id(queue),
            input=input_struct,
            expiry=expiry,
            # NOTE: 'async' is a reserved keyword in Python so we have to do this.
            **({} if block else {"async": True}),
        )

        yield from self.rpc_streaming_request(
            RpcStreamingMethod[pb2.CreateQueueEntryResponse](self.service.CreateQueueEntry),
            request,
            exceptions_for_status={grpc.StatusCode.NOT_FOUND: BeakerQueueNotFound(queue.id)},
        )

    def create_entry_async(
        self,
        queue: pb2.Queue,
        *,
        input: dict | None = None,
        expires_in_sec: int = 3600 * 24,
    ) -> pb2.QueueEntry:
        """
        A convenience wrapper for :meth:`create_entry()` with ``block=False``.
        Returns the created entry right away.

        :param queue: The queue to submit the entry to.
        :param input: The input data, forwarded to :meth:`create_entry()`.
        :param expires_in_sec: Time until the entry expires (in seconds). Defaults to 24 hours.

        :returns: A new :class:`~beaker.types.BeakerQueueEntry` object.

        :raises ~beaker.exceptions.BeakerCreateQueueEntryFailedError: If the response stream ends
            without producing a ``pending_entry`` status.
        """
        # Counted only so that the error message can report how many statuses were seen.
        status_count = 0
        for status in self.create_entry(
            queue, input=input, expires_in_sec=expires_in_sec, block=False
        ):
            status_count += 1
            if status.HasField("pending_entry"):
                return status.pending_entry

        raise BeakerCreateQueueEntryFailedError(
            f"Failed to create queue entry (no 'pending_entry' status was produced out of {status_count} statuses)"
        )

    def list_entries(self, queue: pb2.Queue, limit: int | None = None) -> Iterable[pb2.QueueEntry]:
        """
        List entries within a queue.

        :param queue: The queue whose entries should be listed.
        :param limit: Stop after yielding this many entries. Must be positive if given.

        :returns: An iterator over :class:`~beaker.types.BeakerQueueEntry` objects.

        :raises ValueError: If ``limit`` is not positive. Since this is a generator the error
            is raised when iteration starts.
        :raises ~beaker.exceptions.BeakerQueueNotFound: If the server responds with ``NOT_FOUND``.
        """
        if limit is not None and limit <= 0:
            raise ValueError("'limit' must be a positive integer")

        count = 0
        Opts = pb2.ListQueueEntriesRequest.Opts
        for response in self.rpc_paged_request(
            RpcMethod[pb2.ListQueueEntriesResponse](self.service.ListQueueEntries),
            pb2.ListQueueEntriesRequest(
                options=Opts(
                    queue_id=queue.id,
                    # Never request a bigger page than the caller asked for in total.
                    page_size=self.MAX_PAGE_SIZE
                    if limit is None
                    else min(self.MAX_PAGE_SIZE, limit),
                )
            ),
            exceptions_for_status={grpc.StatusCode.NOT_FOUND: BeakerQueueNotFound(queue.id)},
        ):
            for entry in response.entries:
                count += 1
                yield entry
                if limit is not None and count >= limit:
                    return

    def worker_channel(
        self,
        queue: pb2.Queue,
        worker: pb2.QueueWorker,
    ) -> ContextManager[tuple[BeakerEntrySender, BeakerEntryReceiver]]:
        """
        A context manager for opening a bidirectional worker channel for consuming and responding
        to entries. The channel returned is a tuple of a :class:`BeakerEntrySender` and a
        :class:`BeakerEntryReceiver`, respectively.

        Example:

        >>> with beaker.queue.worker_channel(queue, worker) as (tx, rx):
        ...     for batch in rx.recv(max_batches=2, time_limit=10.0):
        ...         for entry_id, entry_input in batch:
        ...             tx.send(entry_id, output=entry_input)
        ...             tx.send(entry_id, done=True)

        :param queue: The queue the worker belongs to.
        :param worker: The worker to process entries as.

        :raises ~beaker.exceptions.BeakerWorkerThreadError: On a clean exit from the ``with`` block
            if the background channel thread has flagged an error.
        """
        # NOTE: the extra indirection here is just to make the type hints on the public method
        # more concrete/clear.
        return self._worker_channel(queue, worker)

    @contextmanager
    def _worker_channel(
        self,
        queue: pb2.Queue,
        worker: pb2.QueueWorker,
    ) -> Generator[tuple[BeakerEntrySender, BeakerEntryReceiver], None, None]:
        # `tx` carries outgoing requests from the caller to the background thread and `rx`
        # carries received batches back. `None` is the end-of-stream sentinel on both queues.
        tx: SimpleQueue[pb2.ProcessQueueEntriesRequest | None] = SimpleQueue()
        rx: SimpleQueue[list[pb2.QueueWorkerInput] | None] = SimpleQueue()
        done_event = threading.Event()
        error_event = threading.Event()

        thread = threading.Thread(
            target=self._process_queue_entries,
            args=(worker, tx, rx, done_event, error_event),
            name=f"beaker-queue-worker-{worker.id}",
        )
        thread.start()

        try:
            yield BeakerEntrySender(
                queue=queue,
                worker=worker,
                tx=tx,
            ), BeakerEntryReceiver(
                queue=queue,
                worker=worker,
                rx=rx,
                error=error_event,
                logger=self.logger,
            )
            # Only reached when the `with` body finished without raising.
            if error_event.is_set():
                raise BeakerWorkerThreadError("channel thread died unexpectedly")
        finally:
            self.logger.debug(
                f"Closing down {self.__class__.__name__} queues and worker threads..."
            )
            # Unblock the request iterator, tell the thread to stop, then wait for it.
            tx.put(None)
            done_event.set()
            thread.join()

    def _process_queue_entries(
        self,
        worker: pb2.QueueWorker,
        tx: SimpleQueue[pb2.ProcessQueueEntriesRequest | None],
        rx: SimpleQueue[list[pb2.QueueWorkerInput] | None],
        done: threading.Event,
        error: threading.Event,
    ):
        # Thread target started by `_worker_channel()`. It forwards requests taken from `tx` to
        # the bidirectional `ProcessQueueEntries` stream and puts each received batch on `rx`.
        # A `None` is put on `rx` when the stream finishes or when an error is about to be
        # re-raised, and `error` is set in the latter case.
        #
        # NOTE (epwalsh): For reasons I don't fully understand we need to be very careful
        # when retrying these streaming requests to ensure that the `request_iter`
        # generator function (defined below) from the failed request (that we're about to retry)
        # gets exhausted before we restart the request with another `request_iter`.
        # Otherwise we end up in a bad state where we stop sending or receiving new streaming messages.
        #
        # Hence these extra bookkeeping flags:
        #
        # We set `iter_done` to `True` within the `request_iter` function when it gets exhausted
        # in order to signal to the outer while-loop that we can safely recreate a new `request_iter` function.
        iter_done = True
        # We set `iter_canceled` to `True` in the outer while-loop below each time we intercept a retriable
        # error in order to signal to the `request_iter` function that it should complete early.
        iter_canceled = False

        retries = 0
        while not done.is_set():
            try:
                if not iter_done:
                    self.logger.debug("Waiting for previous entry requests iterator to exit...")
                    while not iter_done:
                        time.sleep(0.5)

                iter_canceled = False
                iter_done = False

                def request_iter() -> Generator[pb2.ProcessQueueEntriesRequest, None, None]:
                    nonlocal iter_done

                    # The first message on the stream identifies the worker.
                    yield pb2.ProcessQueueEntriesRequest(
                        init=pb2.ProcessQueueEntriesRequest.Init(worker_id=worker.id)
                    )

                    self.logger.debug("Waiting for new entry process requests from thread")
                    while not iter_canceled:
                        # Poll with a timeout so that `iter_canceled` gets re-checked regularly.
                        try:
                            request = tx.get(
                                block=True,
                                timeout=0.5,
                            )
                        except QueueEmpty:
                            continue

                        if request is None:
                            break

                        self.logger.debug("Sending new entry process request from thread")
                        yield request

                    self.logger.debug("Exhausted entry process requests from thread")
                    iter_done = True

                for response in self.rpc_bidirectional_streaming_request(
                    RpcBidirectionalStreamingMethod[pb2.ProcessQueueEntriesResponse](
                        self.service.ProcessQueueEntries
                    ),
                    request_iter(),
                    exceptions_for_status={
                        grpc.StatusCode.NOT_FOUND: BeakerQueueWorkerNotFound(worker.id)
                    },
                ):
                    batch_inputs = list(response.batch.worker_input)
                    self.logger.debug("Received new entry batch from thread")
                    rx.put(batch_inputs)

                    if done.is_set():
                        break

                rx.put(None)
                return
            except BeakerStreamConnectionClosedError as err:
                # These errors are expected, see https://github.com/allenai/beaker/issues/6532
                iter_canceled = True
                self._log_and_wait(1, err, log_level=logging.DEBUG)
            except BeakerServerError as err:
                iter_canceled = True
                if retries < self.beaker.MAX_RETRIES:
                    self._log_and_wait(retries, err)
                    retries += 1
                else:
                    error.set()
                    rx.put(None)
                    raise
            except BaseException:
                iter_canceled = True
                error.set()
                rx.put(None)
                raise

    def list(
        self,
        *,
        org: pb2.Organization | None = None,
        workspace: pb2.Workspace | None = None,
        sort_order: BeakerSortOrder | None = BeakerSortOrder.descending,
        sort_field: Literal["created"] = "created",
        limit: int | None = None,
    ) -> Iterable[pb2.Queue]:
        """
        List queues.

        :param org: The organization to list queues for. It is passed through ``resolve_org_id()``.
        :param workspace: Only list queues from this workspace. When ``None`` no workspace
            filter is sent.
        :param sort_order: The sort order. When ``None`` no sort order is sent.
        :param sort_field: The field to sort by.
        :param limit: Stop after yielding this many queues. Must be positive if given.

        :returns: An iterator over :class:`~beaker.types.BeakerQueue` objects.

        :raises ValueError: If ``limit`` is not positive. Since this is a generator the error
            is raised when iteration starts.
        """
        if limit is not None and limit <= 0:
            raise ValueError("'limit' must be a positive integer")

        count = 0
        for response in self.rpc_paged_request(
            RpcMethod[pb2.ListQueuesResponse](self.service.ListQueues),
            pb2.ListQueuesRequest(
                options=pb2.ListQueuesRequest.Opts(
                    sort_clause=pb2.ListQueuesRequest.Opts.SortClause(
                        sort_order=None if sort_order is None else sort_order.as_pb2(),
                        created={} if sort_field == "created" else None,
                    ),
                    organization_id=self.resolve_org_id(org),
                    workspace_id=None
                    if workspace is None
                    else self.resolve_workspace_id(workspace),
                    # Never request a bigger page than the caller asked for in total.
                    page_size=self.MAX_PAGE_SIZE
                    if limit is None
                    else min(self.MAX_PAGE_SIZE, limit),
                ),
            ),
        ):
            for queue in response.queues:
                count += 1
                yield queue
                if limit is not None and count >= limit:
                    return
@dataclass
class BeakerEntrySender:
    """
    Queue entry sender. Use this to respond to queue entries consumed by a worker.

    .. warning::
        Do not instantiate this class directly! Use :meth:`~QueueClient.worker_channel()`
        to create one.
    """

    # The queue and worker whose IDs are stamped on the metadata of every outgoing message.
    queue: pb2.Queue
    worker: pb2.QueueWorker
    # Outgoing request queue; `send()` puts one request on it per call.
    tx: SimpleQueue[pb2.ProcessQueueEntriesRequest | None]

    @overload
    def send(
        self,
        entry_id: str,
        *,
        output: dict,
    ): ...

    @overload
    def send(
        self,
        entry_id: str,
        *,
        rejection: str,
    ): ...

    @overload
    def send(
        self,
        entry_id: str,
        *,
        done: Literal[True],
    ): ...

    def send(
        self,
        entry_id: str,
        *,
        output: dict | None = None,
        rejection: str | None = None,
        done: bool = False,
    ):
        """
        Send output to an entry, reject, or mark the entry as done.

        .. important::
            Only one of ``output``, ``rejection``, or ``done`` can be specified at a time,
            and you should eventually set ``done=True`` (or ``rejection=...``) on every entry.

        :param entry_id: The ID of the entry.
        :param output: Worker response data for the entry.
            Mutually exclusive with the other keyword args.
        :param rejection: Marks the entry as rejected. This should be a human-readable reason for
            rejecting the entry.
            Mutually exclusive with the other keyword args.
        :param done: Mark the entry as done. Interpreted by truthiness.
            Mutually exclusive with the other keyword args.

        :raises ValueError: If zero or more than one of ``output``, ``rejection``, and ``done``
            is specified.
        """
        # Normalize once so that validation and request construction agree on what "done" means.
        # Otherwise a truthy non-`True` value could slip past validation alongside `output` or
        # `rejection` and still set the `done` field.
        done = bool(done)
        if sum([done, output is not None, rejection is not None]) != 1:
            raise ValueError("exactly one of `output`, `rejection`, or `done` can be specified")

        request = pb2.ProcessQueueEntriesRequest(
            worker_output=pb2.QueueWorkerOutput(
                metadata=pb2.QueueEntryMetadata(
                    queue_id=self.queue.id, entry_id=entry_id, worker_id=self.worker.id
                ),
                output=output,
                rejection=rejection,
                done=Empty() if done else None,
            ),
        )

        # Exactly one payload is set at this point, so the request is always sent.
        self.tx.put(request)
@dataclass
class BeakerEntryReceiver:
    """
    Queue entry receiver. Use this to consume queue entries as a worker.

    .. warning::
        Do not instantiate this class directly! Use :meth:`~QueueClient.worker_channel()`
        to create one.
    """

    queue: pb2.Queue
    worker: pb2.QueueWorker
    # Incoming batches; a `None` item marks the end of the stream.
    rx: SimpleQueue[list[pb2.QueueWorkerInput] | None]
    # Checked before every wait and once more after receiving stops.
    error: threading.Event
    logger: logging.Logger

    def recv(
        self,
        *,
        max_batches: int | None = None,
        time_limit: float | None = None,
    ) -> Generator[list[tuple[str, dict | None]], None, None]:
        """
        Receive batches of queue entries as they become available.

        Yields lists of tuples in the form ``(entry_id: str, input_data: dict | None)``.
        Without ``max_batches`` or ``time_limit`` this keeps waiting for more batches until
        the end-of-stream marker arrives or the error event is set.

        :param max_batches: Stop receiving after this many batches.
        :param time_limit: Stop receiving after this many seconds.

        :raises ~beaker.exceptions.BeakerWorkerThreadError: If the error event is set once
            receiving has stopped.
        """
        poll_interval = 1.0
        started_at = time.monotonic()
        num_batches = 0

        while not self.error.is_set():
            if max_batches is not None and num_batches >= max_batches:
                self.logger.debug(
                    f"{self.__class__.__name__}.receive() finished due to max batches"
                )
                break

            # Wait for at most `poll_interval` at a time so the stop conditions above get
            # re-evaluated regularly, and never wait past the overall time limit.
            if time_limit is None:
                timeout = poll_interval
            else:
                remaining = max(time_limit - (time.monotonic() - started_at), 0.0)
                if remaining <= 0:
                    self.logger.debug(
                        f"{self.__class__.__name__}.receive() finished due to time limit"
                    )
                    break
                timeout = min(remaining, poll_interval)

            try:
                batch = self.rx.get(block=True, timeout=timeout)
            except QueueEmpty:
                continue

            if batch is None:
                # End-of-stream marker.
                break

            num_batches += 1
            yield [
                (
                    item.metadata.entry_id,
                    pb2_to_dict(item.input) if item.HasField("input") else None,
                )
                for item in batch
            ]

        if self.error.is_set():
            raise BeakerWorkerThreadError("channel thread died unexpectedly")