# Source code for beaker.client

from __future__ import annotations

import logging
import os
import threading
import time
from contextlib import contextmanager
from functools import cached_property
from typing import ClassVar, Generator, TypeVar

import grpc
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from . import beaker_pb2_grpc
from ._cluster import ClusterClient
from ._dataset import DatasetClient
from ._experiment import ExperimentClient
from ._group import GroupClient
from ._image import ImageClient
from ._job import JobClient
from ._node import NodeClient
from ._organization import OrganizationClient
from ._queue import QueueClient
from ._secret import SecretClient
from ._user import UserClient
from ._workload import WorkloadClient
from ._workspace import WorkspaceClient
from .config import Config, InternalConfig
from .exceptions import *
from .version import VERSION

__all__ = ["Beaker"]


_LATEST_VERSION_CHECKED = False
T = TypeVar("T")


class Beaker:
    """
    A client for interacting with `Beaker <https://beaker.org>`_. This should be used as a context manager
    to ensure connections are properly closed on exit.

    .. tip::
        Use :meth:`from_env()` to create a client instance.

    :param config: The Beaker :class:`Config`.
    :param check_for_upgrades: Automatically check that beaker-py is up-to-date. You'll see
        a warning if it isn't.
    :param user_agent: Override the "User-Agent" header used in requests to the Beaker server.
    """

    API_VERSION: ClassVar[str] = "v3"
    CLIENT_VERSION: ClassVar[str] = VERSION
    VERSION_CHECK_INTERVAL: ClassVar[int] = 12 * 3600  # 12 hours
    RPC_MAX_SEND_MESSAGE_LENGTH: ClassVar[int] = 64 * 1024 * 1024  # 64MiB

    RECOVERABLE_SERVER_ERROR_CODES: ClassVar[tuple[int, ...]] = (429, 500, 502, 503, 504)
    MAX_RETRIES: ClassVar[int] = 5
    BACKOFF_FACTOR: ClassVar[int] = 1
    BACKOFF_MAX: ClassVar[int] = 120
    TIMEOUT: ClassVar[float] = 5.0
    POOL_MAXSIZE: ClassVar[int] = min(100, (os.cpu_count() or 16) * 6)

    logger = logging.getLogger("beaker")

    def __init__(
        self,
        config: Config,
        check_for_upgrades: bool = True,
        user_agent: str = f"beaker-py v{VERSION}",
    ):
        self.user_agent = user_agent
        self._config = config
        # The gRPC channel and stub are created lazily by the ``service`` property.
        self._channel: grpc.Channel | None = None
        self._service: beaker_pb2_grpc.BeakerStub | None = None
        # requests.Session is not thread safe, so each thread keeps its own session here.
        self._thread_local = threading.local()
        self._thread_local.http_session = None

        # See if there's a newer version, and if so, suggest that the user upgrades.
        if check_for_upgrades:
            self._check_for_upgrades()

    def _get_latest_version(self) -> str:
        """
        Return the highest beaker-py version listed by PyPI's JSON simple API.

        The ``versions`` array of that API is not guaranteed to be ordered, so the
        maximum is computed with :func:`packaging.version.parse` instead of taking
        the last entry.

        :raises requests.HTTPError: If PyPI responds with an error status.
        :raises ValueError: If no versions are listed (or one cannot be parsed).
        """
        import packaging.version

        response = requests.get(
            "https://pypi.org/simple/beaker-py",
            headers={"Accept": "application/vnd.pypi.simple.v1+json"},
            timeout=2,
        )
        response.raise_for_status()
        return max(response.json()["versions"], key=packaging.version.parse)

    def _check_for_upgrades(self, force: bool = False) -> Exception | bool | None:
        """
        Warn (via :func:`warnings.warn`) if a newer beaker-py release is available on PyPI.

        The check is best-effort and never raises. Unless ``force`` is set, it runs at most
        once per process, and it is skipped when the internal config records a check within
        the last :attr:`VERSION_CHECK_INTERVAL` seconds.

        :param force: Ignore both the per-process flag and the check interval.
        :returns: ``None`` if the check was skipped, ``True``/``False`` for whether an upgrade
            is recommended, or the exception that aborted the check.
        """
        global _LATEST_VERSION_CHECKED

        if not force and _LATEST_VERSION_CHECKED:
            return None

        try:
            # Imported inside the ``try`` so that a missing/broken dependency can't make
            # client construction fail; the whole check is best-effort.
            import warnings

            import packaging.version

            config = InternalConfig.load()
            if (
                not force
                and config is not None
                and config.version_checked is not None
                and (time.time() - config.version_checked <= self.VERSION_CHECK_INTERVAL)
            ):
                return None

            latest_version = packaging.version.parse(self._get_latest_version())
            current_version = packaging.version.parse(self.CLIENT_VERSION)
            # Only suggest a prerelease to users who are already running a prerelease.
            should_upgrade = latest_version > current_version and (
                not latest_version.is_prerelease or current_version.is_prerelease
            )
            if should_upgrade:
                warnings.warn(
                    f"You're using beaker-py v{current_version}, "
                    f"but a newer version (v{latest_version}) is available.\n\n"
                    f"Please upgrade with `pip install --upgrade beaker-py`.",
                    UserWarning,
                )

            _LATEST_VERSION_CHECKED = True
            if config is not None:
                config.version_checked = time.time()
                config.save()

            return should_upgrade
        except Exception as e:
            return e

    @classmethod
    def from_env(
        cls,
        check_for_upgrades: bool = True,
        user_agent: str = f"beaker-py v{VERSION}",
        **overrides,
    ) -> Beaker:
        """
        Initialize client from a config file and/or environment variables.

        :examples:

        >>> with Beaker.from_env(default_workspace="ai2/my-workspace") as beaker:
        ...     print(beaker.user_name)

        :param check_for_upgrades: Automatically check that beaker-py is up-to-date. You'll see
            a warning if it isn't.
        :param user_agent: Override the "User-Agent" header used in requests to the Beaker server.
        :param overrides: Fields in the :class:`Config` to override.

        .. note::
            This will use the same config file that the Beaker command-line client
            creates and uses, which is usually located at ``$HOME/.beaker/config.yml``.

            If you haven't configured the command-line client, then you can alternately just
            set the environment variable ``BEAKER_TOKEN`` to your Beaker
            `user token <https://beaker.org/user>`_.
        """
        return cls(
            Config.from_env(**overrides),
            check_for_upgrades=check_for_upgrades,
            user_agent=user_agent,
        )

    @property
    def service(self) -> beaker_pb2_grpc.BeakerStub:
        """
        The gRPC stub for the Beaker service.

        Created on first access over a TLS channel to ``config.rpc_address``;
        :meth:`close()` discards both the channel and the stub.
        """
        if self._service is None:
            self._channel = grpc.secure_channel(
                self.config.rpc_address,
                grpc.ssl_channel_credentials(),
                options=[
                    ("grpc.max_send_message_length", self.RPC_MAX_SEND_MESSAGE_LENGTH),
                    # ("grpc.keepalive_time_ms", 10_000),
                ],
            )
            self._service = beaker_pb2_grpc.BeakerStub(self._channel)
        return self._service

    @property
    def config(self) -> Config:
        """
        The client's :class:`Config`.
        """
        return self._config

    @cached_property
    def user_name(self) -> str:
        """
        Name of the user returned by ``self.user.get()`` (presumably the authenticated user).
        Cached after the first access.
        """
        return self.user.get().name

    @cached_property
    def org_name(self) -> str:
        """
        Name of the organization returned by ``self.organization.get()``.
        Cached after the first access.
        """
        return self.organization.get().name

    @cached_property
    def organization(self) -> OrganizationClient:
        """
        Manage organizations.
        """
        return OrganizationClient(self)

    @cached_property
    def user(self) -> UserClient:
        """
        Manage users.
        """
        return UserClient(self)

    @cached_property
    def workspace(self) -> WorkspaceClient:
        """
        Manage workspaces.
        """
        return WorkspaceClient(self)

    @cached_property
    def cluster(self) -> ClusterClient:
        """
        Manage clusters.
        """
        return ClusterClient(self)

    @cached_property
    def node(self) -> NodeClient:
        """
        Manage nodes.
        """
        return NodeClient(self)

    @cached_property
    def dataset(self) -> DatasetClient:
        """
        Manage datasets.
        """
        return DatasetClient(self)

    @cached_property
    def image(self) -> ImageClient:
        """
        Manage images.
        """
        return ImageClient(self)

    @cached_property
    def job(self) -> JobClient:
        """
        Manage jobs.
        """
        return JobClient(self)

    @cached_property
    def experiment(self) -> ExperimentClient:
        """
        Manage experiments.
        """
        return ExperimentClient(self)

    @cached_property
    def workload(self) -> WorkloadClient:
        """
        Manage workloads.
        """
        return WorkloadClient(self)

    @cached_property
    def secret(self) -> SecretClient:
        """
        Manage secrets.
        """
        return SecretClient(self)

    @cached_property
    def group(self) -> GroupClient:
        """
        Manage groups.
        """
        return GroupClient(self)

    @cached_property
    def queue(self) -> QueueClient:
        """
        Manage queues.
        """
        return QueueClient(self)

    @contextmanager
    def http_session(self) -> Generator[requests.Session, None, None]:
        """
        Yield the calling thread's HTTP session.

        If this thread already has an open session (e.g. inside ``with beaker:`` or an
        enclosing ``http_session()`` block), it is reused and left open. Otherwise a
        temporary session is created for the duration of the ``with`` block and closed
        when the block exits.
        """
        existing = getattr(self._thread_local, "http_session", None)
        if existing is not None:
            yield existing
            return

        session = self._init_http_session()
        self._thread_local.http_session = session
        try:
            yield session
        finally:
            # Close the session we created even if ``close()`` already cleared the
            # thread-local during the block (``Session.close()`` is safe to repeat).
            session.close()
            # Only clear the slot if it still holds our session, so we never discard
            # a different session that was installed while the block was running.
            if getattr(self._thread_local, "http_session", None) is session:
                self._thread_local.http_session = None

    def _init_http_session(self) -> requests.Session:
        """
        Create a :class:`requests.Session` whose ``https://`` adapter retries via
        urllib3's :class:`Retry`.

        Connection errors and the :attr:`RECOVERABLE_SERVER_ERROR_CODES` statuses are
        retried up to :attr:`MAX_RETRIES` times each, with :attr:`BACKOFF_FACTOR` backoff.
        """
        session = requests.Session()
        # NOTE(review): BACKOFF_MAX is not passed to Retry here; urllib3's own default cap applies.
        retries = Retry(
            total=self.MAX_RETRIES * 2,
            connect=self.MAX_RETRIES,
            status=self.MAX_RETRIES,
            backoff_factor=self.BACKOFF_FACTOR,
            status_forcelist=self.RECOVERABLE_SERVER_ERROR_CODES,
        )
        session.mount("https://", HTTPAdapter(max_retries=retries, pool_maxsize=self.POOL_MAXSIZE))
        return session

    def __enter__(self) -> "Beaker":
        # Open a session for this thread that lives until ``__exit__``/``close()``.
        if (
            not hasattr(self._thread_local, "http_session")
            or self._thread_local.http_session is None
        ):
            self._thread_local.http_session = self._init_http_session()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        del exc_type, exc_val, exc_tb
        self.close()
        return False

    def close(self):
        """
        Close down RPC channels and HTTP sessions. This will be called automatically
        when using the client as a context manager.

        Only the calling thread's HTTP session is closed; sessions held in other
        threads' thread-local storage are not touched.
        """
        # ``getattr`` defaults let this run safely on a partially-initialized instance
        # (e.g. from ``__del__`` when ``__init__`` never completed).

        # Close RPC channel.
        channel = getattr(self, "_channel", None)
        if channel is not None:
            channel.close()
            self._channel = None
            self._service = None

        # Close HTTP session.
        thread_local = getattr(self, "_thread_local", None)
        session = getattr(thread_local, "http_session", None)
        if session is not None:
            session.close()
            thread_local.http_session = None

    def __del__(self):
        # Best-effort cleanup when the client wasn't used as a context manager.
        self.close()