Source code for gantry.recipe

import dataclasses
from dataclasses import dataclass
from typing import Any, Literal, Sequence

from beaker import Beaker, BeakerWorkload

from . import constants, utils
from .aliases import PathOrStr
from .callbacks import Callback
from .exceptions import *
from .git_utils import GitRepoState
from .launch import launch_experiment


@dataclass
class Recipe:
    """
    A recipe defines how Gantry creates a Beaker workload. It can be used to
    programmatically launch Gantry runs from Python instead of from the
    command line. See the usage sketch at the bottom of this module for a
    minimal example.
    """

    # Workload settings.
    args: Sequence[str]
    name: str | None = None
    description: str | None = None
    workspace: str | None = None
    budget: str | None = None
    group_names: Sequence[str] | None = None

    # Launch settings.
    allow_dirty: bool = False
    yes: bool | None = None
    save_spec: PathOrStr | None = None

    # Callbacks.
    callbacks: Sequence[Callback] | None = None

    # Constraints.
    clusters: Sequence[str] | None = None
    gpu_types: Sequence[str] | None = None
    interconnect: Literal["ib", "tcpxo"] | None = None
    tags: Sequence[str] | None = None
    hostnames: Sequence[str] | None = None

    # Resources.
    cpus: float | None = None
    gpus: int | None = None
    memory: str | None = None
    shared_memory: str | None = None

    # Inputs.
    beaker_image: str | None = None
    docker_image: str | None = None
    datasets: Sequence[str] | None = None
    env_vars: Sequence[str | tuple[str, str]] | None = None
    env_secrets: Sequence[str | tuple[str, str]] | None = None
    dataset_secrets: Sequence[str | tuple[str, str]] | None = None
    mounts: Sequence[str | tuple[str, str]] | None = None
    weka: Sequence[str | tuple[str, str]] | None = None
    uploads: Sequence[str | tuple[str, str]] | None = None
    ref: str | None = None
    branch: str | None = None
    git_repo: GitRepoState | None = None
    gh_token_secret: str = constants.GITHUB_TOKEN_SECRET
    aws_config_secret: str | None = None
    aws_credentials_secret: str | None = None
    google_credentials_secret: str | None = None

    # Outputs.
    results: str = constants.RESULTS_DIR

    # Task settings.
    task_name: str = "main"
    priority: str | None = None
    task_timeout: str | None = None
    preemptible: bool | None = None
    retries: int | None = None

    # Multi-node config.
    replicas: int | None = None
    leader_selection: bool | None = None
    host_networking: bool | None = None
    propagate_failure: bool | None = None
    propagate_preemption: bool | None = None
    synchronized_start_timeout: str | None = None
    skip_tcpxo_setup: bool = dataclasses.field(default=False, repr=False)  # deprecated
    skip_nccl_setup: bool = False

    # Runtime.
    runtime_dir: str = constants.RUNTIME_DIR
    exec_method: Literal["exec", "bash"] = "exec"
    torchrun: bool = False

    # Setup hooks.
    pre_setup: str | None = None
    post_setup: str | None = None

    # Python settings.
    python_manager: Literal["uv", "conda"] | None = None
    default_python_version: str = utils.get_local_python_version()
    system_python: bool = False
    install: str | None = None
    no_python: bool = False

    # Python UV settings.
    uv_venv: str | None = None
    uv_extras: Sequence[str] | None = None
    uv_all_extras: bool | None = None
    uv_torch_backend: str | None = None

    # Python Conda settings.
    conda_file: PathOrStr | None = None
    conda_env: str | None = None
    @classmethod
    def multi_node_torchrun(
        cls,
        cmd: Sequence[str],
        gpus_per_node: int,
        num_nodes: int,
        shared_memory: str | None = "10GiB",
        **kwargs,
    ) -> "Recipe":
        """
        Create a multi-node recipe that runs ``cmd`` with torchrun across
        ``num_nodes`` replicas, each with ``gpus_per_node`` GPUs.
        See the multi-node sketch at the bottom of this module.
        """
        return cls(
            args=cmd,
            gpus=gpus_per_node,
            replicas=num_nodes,
            shared_memory=shared_memory,
            torchrun=True,
            **kwargs,
        )
    def _get_launch_args(self) -> Sequence[str]:
        if isinstance(self.args, str):
            raise ConfigurationError("args must be a sequence of strings, not a single string")
        return self.args

    def _get_launch_kwargs(self) -> dict[str, Any]:
        kwargs = dataclasses.asdict(self)
        # 'dataclasses.asdict()' recursively converts nested dataclasses into plain
        # dicts, so restore the original callback and git repo objects.
        kwargs["callbacks"] = self.callbacks
        kwargs["git_repo"] = self.git_repo
        # 'args' is passed positionally to 'launch_experiment()'.
        kwargs.pop("args")
        return kwargs
    def dry_run(self, client: Beaker | None = None) -> None:
        """
        Do a dry-run to validate options.
        """
        launch_experiment(
            self._get_launch_args(),
            **self._get_launch_kwargs(),
            client=client,
            dry_run=True,
        )
    def launch(
        self,
        show_logs: bool | None = None,
        timeout: int | None = None,
        start_timeout: int | None = None,
        inactive_timeout: int | None = None,
        inactive_soft_timeout: int | None = None,
        client: Beaker | None = None,
    ) -> BeakerWorkload:
        """
        Launch an experiment on Beaker. Same as the ``gantry run`` command.

        :returns: The Beaker workload.
        """
        workload = launch_experiment(
            self._get_launch_args(),
            **self._get_launch_kwargs(),
            show_logs=show_logs,
            timeout=timeout,
            start_timeout=start_timeout,
            inactive_timeout=inactive_timeout,
            inactive_soft_timeout=inactive_soft_timeout,
            client=client,
        )
        assert workload is not None
        return workload
    def with_replicas(
        self,
        replicas: int,
        leader_selection: bool = True,
        host_networking: bool = True,
        propagate_failure: bool = True,
        propagate_preemption: bool = True,
        synchronized_start_timeout: str = "5m",
        skip_nccl_setup: bool = False,
    ) -> "Recipe":
        """
        Return a copy of this recipe configured to run with ``replicas``
        (at least 2) multi-node replicas.
        """
        if replicas < 2:
            raise ConfigurationError("replicas must be at least 2")
        return dataclasses.replace(
            self,
            replicas=replicas,
            leader_selection=leader_selection,
            host_networking=host_networking,
            propagate_failure=propagate_failure,
            propagate_preemption=propagate_preemption,
            synchronized_start_timeout=synchronized_start_timeout,
            skip_nccl_setup=skip_nccl_setup,
        )
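

# --- Usage sketch (illustrative; not part of the module's public API) ------
# A minimal sketch of launching a run programmatically with ``Recipe``,
# equivalent to the ``gantry run`` CLI. The workspace, budget, cluster, and
# script names below are hypothetical placeholders; substitute your own.
# Kept as comments so importing this module has no side effects.
#
#     from gantry.recipe import Recipe
#
#     recipe = Recipe(
#         args=["python", "train.py"],   # command to run inside the workload
#         name="my-run",                 # hypothetical workload name
#         workspace="ai2/my-workspace",  # hypothetical Beaker workspace
#         budget="ai2/my-budget",        # hypothetical budget
#         clusters=["ai2/my-cluster"],   # hypothetical cluster constraint
#         gpus=1,
#     )
#     recipe.dry_run()                   # validate options first
#     workload = recipe.launch(show_logs=True)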
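
# --- Multi-node sketch (illustrative, same hypothetical names as above) ----
# Two ways to configure multi-node runs with the methods defined above:
# ``multi_node_torchrun()`` builds a torchrun-based recipe directly, while
# ``with_replicas()`` returns a copy of an existing recipe with replica
# settings applied.
#
#     recipe = Recipe.multi_node_torchrun(
#         cmd=["train.py"],              # torchrun spawns the per-GPU processes
#         gpus_per_node=8,
#         num_nodes=2,
#         workspace="ai2/my-workspace",  # hypothetical
#         budget="ai2/my-budget",        # hypothetical
#     )
#
#     # Or scale out an existing single-node recipe:
#     scaled = recipe.with_replicas(replicas=4)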