Source code for beaker._dataset

import io
import os
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Generator, Iterable, Literal
from urllib.parse import urlparse

import grpc
from requests import Response

from . import beaker_pb2 as pb2
from ._service_client import RpcMethod, ServiceClient
from .exceptions import *
from .types import *


@dataclass
class _DatasetStorage:
    id: str
    token: str
    tokenExpires: datetime
    address: str | None = None
    url: str | None = None
    urlv2: str | None = None
    totalSize: int | None = None
    numFiles: int | None = None

    def __post_init__(self):
        if self.address is not None and self.address.startswith("fh://"):
            # HACK: fix prior to https://github.com/allenai/beaker/pull/2962
            self.address = self.address.replace("fh://", "https://", 1)

    @property
    def scheme(self) -> str | None:
        return "fh" if self.urlv2 is None else urlparse(self.urlv2).scheme

    @property
    def base_url(self) -> str:
        if self.address is not None:
            return self.address
        elif self.urlv2 is not None:
            return f"https://{urlparse(self.urlv2).netloc}"
        else:
            raise ValueError("Missing field 'urlv2' or 'address'")


class DatasetClient(ServiceClient):
    """
    Methods for interacting with Beaker `Datasets <https://beaker-docs.apps.allenai.org/concept/datasets.html>`_.
    Accessed via the :data:`Beaker.dataset <beaker.Beaker.dataset>` property.

    .. warning::
        Do not instantiate this class directly! The :class:`~beaker.Beaker` client will create
        one automatically which you can access through the corresponding property.
    """

    HEADER_UPLOAD_ID = "Upload-ID"
    HEADER_UPLOAD_LENGTH = "Upload-Length"
    HEADER_UPLOAD_OFFSET = "Upload-Offset"
    HEADER_DIGEST = "Digest"
    HEADER_LAST_MODIFIED = "Last-Modified"
    HEADER_CONTENT_LENGTH = "Content-Length"

    REQUEST_SIZE_LIMIT = 32 * 1024 * 1024
    DOWNLOAD_CHUNK_SIZE = 10 * 1024

    def get(self, dataset: str) -> pb2.Dataset:
        """
        Get a dataset by name or ID.

        :examples:

        >>> with Beaker.from_env() as beaker:
        ...     dataset = beaker.dataset.get(dataset_name)

        :returns: A :class:`~beaker.types.BeakerDataset`.

        :raises ~beaker.exceptions.BeakerDatasetNotFound: If the dataset doesn't exist.
        """
        return self.rpc_request(
            RpcMethod[pb2.GetDatasetResponse](self.service.GetDataset),
            pb2.GetDatasetRequest(dataset_id=self.resolve_dataset_id(dataset)),
            exceptions_for_status={grpc.StatusCode.NOT_FOUND: BeakerDatasetNotFound(dataset)},
        ).dataset

    def _get_storage(self, dataset: pb2.Dataset) -> _DatasetStorage:
        dataset_info = self.http_request(
            f"datasets/{self._url_quote(dataset.id)}",
            exceptions_for_status={404: BeakerDatasetNotFound(dataset.id)},
        ).json()
        return _DatasetStorage(**dataset_info["storage"])

    def create(
        self,
        name: str,
        *sources: PathOrStr,
        target: PathOrStr | None = None,
        workspace: pb2.Workspace | None = None,
        description: str | None = None,
        force: bool = False,
        max_workers: int | None = None,
        commit: bool = True,
        strip_paths: bool = False,
    ) -> pb2.Dataset:
        """
        Create a dataset from local source files.

        :param name: The name to assign to the new dataset.
        :param sources: Local source files or directories to upload to the dataset.
        :param target: If specified, all source files/directories will be uploaded under
            a directory of this name.
        :param workspace: The workspace to upload the dataset to. If not specified your default
            workspace is used.
        :param description: Text description for the dataset.
        :param force: If ``True`` and a dataset by the given name already exists, it will be overwritten.
        :param max_workers: The maximum number of thread pool workers to use to upload files concurrently.
        :param commit: Whether to commit the dataset after successfully uploading source files.
        :param strip_paths: If ``True``, all source files and directories will be uploaded under their
            name, not their path. E.g. the file "docs/source/index.rst" would be uploaded as just
            "index.rst", instead of "docs/source/index.rst".

            .. note::
                This only applies to source paths that are children of the current working directory.
                If a source path is outside of the current working directory, it will always be
                uploaded under its name only.

        :returns: A new :class:`~beaker.types.BeakerDataset` object.

        :raises ~beaker.exceptions.BeakerDatasetConflict: If a dataset with the given name already exists.
        """
        self._validate_beaker_name(name)
        workspace_id = self.resolve_workspace_id(workspace)

        # Create the dataset.
        def make_dataset() -> tuple[pb2.Dataset, _DatasetStorage]:
            dataset_info = self.http_request(
                "datasets",
                method="POST",
                query={"name": name},
                data=dict(workspace=workspace_id, description=description),
                exceptions_for_status={409: BeakerDatasetConflict(name)},
            ).json()
            return self.get(dataset_info["id"]), _DatasetStorage(**dataset_info["storage"])

        with self.beaker.http_session():
            try:
                dataset, storage = make_dataset()
            except BeakerDatasetConflict:
                if force:
                    self.delete(self.get(f"{self.beaker.user_name}/{name}"))
                    dataset, storage = make_dataset()
                else:
                    raise

            # Upload the file(s).
            if sources:
                self._sync(
                    dataset,
                    storage=storage,
                    source_paths=sources,
                    target=target,
                    max_workers=max_workers,
                    strip_paths=strip_paths,
                )

        if commit:
            return self.commit(dataset)
        else:
            return dataset
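
    # Usage sketch (illustrative only, not part of the library source; the dataset
    # name and file paths below are hypothetical):
    #
    #     from beaker import Beaker
    #
    #     with Beaker.from_env() as beaker:
    #         dataset = beaker.dataset.create(
    #             "my-dataset",
    #             "data/train.jsonl",
    #             "data/val.jsonl",
    #             description="training and validation splits",
    #             strip_paths=True,  # upload as "train.jsonl" / "val.jsonl"
    #         )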

    def commit(self, dataset: pb2.Dataset) -> pb2.Dataset:
        """
        Commit a dataset.

        :returns: The updated :class:`~beaker.types.BeakerDataset` object.
        """
        if dataset.HasField("committed"):
            return dataset

        @self._retriable()
        def commit():
            # It's okay to retry this because committing a dataset multiple times does nothing.
            self.http_request(
                f"datasets/{self._url_quote(self.resolve_dataset_id(dataset))}",
                method="PATCH",
                data={"commit": True},
                exceptions_for_status={404: BeakerDatasetNotFound(dataset.id)},
            )

        commit()

        return self.get(self.resolve_dataset_id(dataset))
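
    # Usage sketch (illustrative only; assumes "my-dataset" is an existing,
    # uncommitted dataset):
    #
    #     with Beaker.from_env() as beaker:
    #         dataset = beaker.dataset.get("my-dataset")
    #         dataset = beaker.dataset.commit(dataset)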

    def _sync(
        self,
        dataset: pb2.Dataset,
        *,
        storage: _DatasetStorage,
        source_paths: Iterable[PathOrStr],
        target: PathOrStr | None = None,
        max_workers: int | None = None,
        strip_paths: bool = False,
    ) -> int:
        if dataset.HasField("committed"):
            raise BeakerDatasetWriteError(f"Dataset '{dataset.id}' has already been committed")

        total_bytes = 0
        # map source path to (target_path, size)
        path_info: dict[Path, tuple[Path, int]] = {}
        for source in source_paths:
            source = Path(source)
            strip_path = strip_paths or not source.is_relative_to(".")
            if source.is_file():
                target_path = Path(source.name) if strip_path else source
                if target is not None:
                    target_path = Path(str(target)) / target_path
                size = source.lstat().st_size
                path_info[source] = (target_path, size)
                total_bytes += size
            elif source.is_dir():
                for path in source.glob("**/*"):
                    if path.is_dir():
                        continue
                    target_path = path.relative_to(source) if strip_path else path
                    if target is not None:
                        target_path = Path(str(target)) / target_path
                    size = path.lstat().st_size
                    if size == 0:
                        continue
                    path_info[path] = (target_path, size)
                    total_bytes += size
            else:
                raise FileNotFoundError(source)

        import concurrent.futures

        # Now upload.
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Dispatch tasks to thread pool executor.
            future_to_path = {}
            for path, (target_path, size) in path_info.items():
                future = executor.submit(
                    self._upload_file,
                    dataset,
                    storage=storage,
                    size=size,
                    source=path,
                    target=target_path,
                    ignore_errors=True,
                )
                future_to_path[future] = path

            # Collect completed tasks.
            for future in concurrent.futures.as_completed(future_to_path):
                path = future_to_path[future]
                original_size = path_info[path][1]
                actual_size = future.result()
                if actual_size != original_size:
                    # If the size of the file has changed since we started, adjust total.
                    total_bytes += actual_size - original_size

        return total_bytes

    def upload(self, dataset: pb2.Dataset, source: PathOrStr | bytes, target: PathOrStr) -> int:
        """
        Upload a file to a dataset.

        :param dataset: The dataset to upload to (must be uncommitted).
        :param source: Path to the local source file or the contents as bytes.
        :param target: The path within the dataset to upload the file to.

        :returns: The number of bytes uploaded.

        :raises ~beaker.exceptions.BeakerDatasetWriteError: If the dataset is already committed.
        """
        if dataset.HasField("committed"):
            raise BeakerDatasetWriteError(f"Dataset '{dataset.id}' has already been committed")

        size = len(source) if isinstance(source, bytes) else Path(source).stat().st_size
        storage = self._get_storage(dataset)
        return self._upload_file(dataset, storage=storage, size=size, source=source, target=target)
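
    # Usage sketch (illustrative only; names are hypothetical). Uploading raw bytes
    # to a dataset created with ``commit=False``:
    #
    #     with Beaker.from_env() as beaker:
    #         dataset = beaker.dataset.create("my-dataset", commit=False)
    #         beaker.dataset.upload(dataset, b"hello, world\n", "greeting.txt")
    #         beaker.dataset.commit(dataset)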

    def _upload_file(
        self,
        dataset: pb2.Dataset,
        *,
        storage: _DatasetStorage,
        size: int,
        source: PathOrStr | bytes,
        target: PathOrStr,
        ignore_errors: bool = False,
    ) -> int:
        if storage.scheme != "fh":
            raise NotImplementedError(
                f"Datasets API is not implemented for '{storage.scheme}' backend yet"
            )

        source_file: io.BufferedReader | io.BytesIO
        if isinstance(source, (str, Path, os.PathLike)):
            source = Path(source)
            if ignore_errors and not source.exists():
                self.logger.warning(f"Skipping uploading '{source}' since it doesn't exist")
                return 0
            source_file = source.open("rb")
        elif isinstance(source, bytes):
            source_file = io.BytesIO(source)
        else:
            raise ValueError(f"Expected path-like or raw bytes, got {type(source)}")

        try:
            body: io.BufferedReader | io.BytesIO | None = source_file
            digest: str | None = None

            self.logger.info(f"Uploading {size} bytes to '{target}'...")

            with self.beaker.http_session():
                if size > self.REQUEST_SIZE_LIMIT:

                    @self._retriable()
                    def get_upload_id() -> str:
                        response = self.http_request(
                            "uploads",
                            method="POST",
                            token=storage.token,
                            base_url=storage.base_url,
                        )
                        return response.headers[self.HEADER_UPLOAD_ID]

                    upload_id = get_upload_id()

                    written = 0
                    while written < size:
                        chunk = source_file.read(self.REQUEST_SIZE_LIMIT)
                        if not chunk:
                            break

                        @self._retriable()
                        def upload() -> Response:
                            return self.http_request(
                                f"uploads/{upload_id}",
                                method="PATCH",
                                data=chunk,
                                token=storage.token,
                                base_url=storage.base_url,
                                headers={
                                    self.HEADER_UPLOAD_LENGTH: str(size),
                                    self.HEADER_UPLOAD_OFFSET: str(written),
                                },
                            )

                        response = upload()

                        written += len(chunk)

                        digest = response.headers.get(self.HEADER_DIGEST)
                        if digest:
                            break

                    if written != size:
                        raise BeakerUnexpectedEOFError(str(source))

                    body = None

                @self._retriable()
                def finalize():
                    self.http_request(
                        f"datasets/{storage.id}/files/{str(target)}",
                        method="PUT",
                        data=body if size > 0 else b"",
                        token=storage.token,
                        base_url=storage.base_url,
                        headers=None if not digest else {self.HEADER_DIGEST: digest},
                        stream=body is not None and size > 0,
                        exceptions_for_status={
                            403: BeakerDatasetWriteError(dataset.id),
                            404: BeakerDatasetNotFound(dataset.id),
                        },
                    )

                finalize()

            return size
        finally:
            source_file.close()

    def stream_file(
        self,
        dataset: pb2.Dataset,
        file_path: str,
        *,
        offset: int = 0,
        length: int = -1,
        chunk_size: int | None = None,
        validate_checksum: bool = True,
    ) -> Generator[bytes, None, None]:
        """
        Stream download the bytes content of a file from a dataset.
        """
        file = self.get_file_info(dataset, file_path)
        yield from self._stream_file(
            dataset,
            file,
            offset=offset,
            length=length,
            chunk_size=chunk_size,
            validate_checksum=validate_checksum,
        )
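
    # Usage sketch (illustrative only; the dataset and file path are hypothetical).
    # Streaming a file to local disk without holding it all in memory:
    #
    #     with Beaker.from_env() as beaker:
    #         dataset = beaker.dataset.get("my-dataset")
    #         with open("train.jsonl", "wb") as f:
    #             for chunk in beaker.dataset.stream_file(dataset, "data/train.jsonl"):
    #                 f.write(chunk)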

    def _stream_file(
        self,
        dataset: pb2.Dataset,
        file: pb2.DatasetFile,
        chunk_size: int | None = None,
        offset: int = 0,
        length: int = -1,
        validate_checksum: bool = True,
    ) -> Generator[bytes, None, None]:
        def stream_file() -> Generator[bytes, None, None]:
            headers = {}
            if offset > 0 and length > 0:
                headers["Range"] = f"bytes={offset}-{offset + length - 1}"
            elif offset > 0:
                headers["Range"] = f"bytes={offset}-"
            response = self.http_request(
                f"datasets/{dataset.id}/files/{self._url_quote(file.path)}",
                method="GET",
                stream=True,
                headers=headers,
                exceptions_for_status={404: FileNotFoundError(file.path)},
            )
            for chunk in response.iter_content(chunk_size=chunk_size or self.DOWNLOAD_CHUNK_SIZE):
                yield chunk

        contents_hash = None
        if offset == 0 and validate_checksum and file.HasField("digest"):
            contents_hash = BeakerDatasetFileAlgorithmType(file.digest.algorithm).hasher()

        retries = 0
        while True:
            try:
                for chunk in stream_file():
                    offset += len(chunk)
                    if contents_hash is not None:
                        contents_hash.update(chunk)
                    yield chunk
                break
            except RequestException as err:
                if retries < self.beaker.MAX_RETRIES:
                    self._log_and_wait(retries, err)
                    retries += 1
                else:
                    raise

        # Validate digest.
        if file.HasField("digest") and contents_hash is not None:
            import binascii

            actual_digest = binascii.hexlify(contents_hash.digest()).decode()
            expected_digest = binascii.hexlify(file.digest.value).decode()
            if actual_digest != expected_digest:
                raise BeakerChecksumFailedError(
                    f"Checksum for '{file.path}' failed. "
                    f"Expected '{expected_digest}', got '{actual_digest}'."
                )

    def list_files(
        self, dataset: pb2.Dataset, *, prefix: str | None = None
    ) -> Iterable[pb2.DatasetFile]:
        """
        List files in a dataset.

        :returns: An iterator over :class:`~beaker.types.BeakerDatasetFile` protobuf objects.
        """
        for response in self.rpc_paged_request(
            RpcMethod[pb2.ListDatasetFilesResponse](self.service.ListDatasetFiles),
            pb2.ListDatasetFilesRequest(
                options=pb2.ListDatasetFilesRequest.Opts(
                    dataset_id=self.resolve_dataset_id(dataset),
                    prefix=prefix,
                )
            ),
        ):
            yield from response.dataset_files
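
    # Usage sketch (illustrative only; names are hypothetical). Listing files under
    # a prefix:
    #
    #     with Beaker.from_env() as beaker:
    #         dataset = beaker.dataset.get("my-dataset")
    #         for file in beaker.dataset.list_files(dataset, prefix="data/"):
    #             print(file.path)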

    def get_file_info(self, dataset: pb2.Dataset, file_path: str) -> pb2.DatasetFile:
        """
        Get metadata about a file in a dataset.

        :returns: A :class:`~beaker.types.BeakerDatasetFile` protobuf object.
        """
        prefix = os.path.dirname(file_path)
        for f in self.list_files(dataset, prefix=prefix):
            if f.path == file_path:
                return f
        raise FileNotFoundError(file_path)
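
    # Usage sketch (illustrative only; names are hypothetical):
    #
    #     with Beaker.from_env() as beaker:
    #         dataset = beaker.dataset.get("my-dataset")
    #         file = beaker.dataset.get_file_info(dataset, "data/train.jsonl")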

    def update(self, dataset: pb2.Dataset, *, description: str | None = None) -> pb2.Dataset:
        """
        Update fields of a dataset.

        :returns: The updated :class:`~beaker.types.BeakerDataset` object.
        """
        return self.rpc_request(
            RpcMethod[pb2.UpdateDatasetResponse](self.service.UpdateDataset),
            pb2.UpdateDatasetRequest(
                dataset_id=self.resolve_dataset_id(dataset),
                description=description,
            ),
        ).dataset
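
    # Usage sketch (illustrative only; the dataset name is hypothetical):
    #
    #     with Beaker.from_env() as beaker:
    #         dataset = beaker.dataset.get("my-dataset")
    #         dataset = beaker.dataset.update(dataset, description="updated description")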

    def delete(self, *datasets: pb2.Dataset):
        """
        Delete datasets.
        """
        self.rpc_request(
            RpcMethod[pb2.DeleteDatasetsResponse](self.service.DeleteDatasets),
            pb2.DeleteDatasetsRequest(
                dataset_ids=[self.resolve_dataset_id(dataset) for dataset in datasets]
            ),
        )
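
    # Usage sketch (illustrative only; deletion is permanent, so the name below is
    # hypothetical):
    #
    #     with Beaker.from_env() as beaker:
    #         beaker.dataset.delete(beaker.dataset.get("my-old-dataset"))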

    def list(
        self,
        *,
        org: pb2.Organization | None = None,
        author: pb2.User | None = None,
        workspace: pb2.Workspace | None = None,
        created_before: datetime | None = None,
        created_after: datetime | None = None,
        results: bool | None = None,
        committed: bool | None = None,
        name_or_description: str | None = None,
        sort_order: BeakerSortOrder | None = None,
        sort_field: Literal["created", "name"] = "name",
        limit: int | None = None,
    ) -> Iterable[pb2.Dataset]:
        """
        List datasets.

        :returns: An iterator over :class:`~beaker.types.BeakerDataset` protobuf objects.
        """
        Opts = pb2.ListDatasetsRequest.Opts

        if limit is not None and limit <= 0:
            raise ValueError("'limit' must be a positive integer")

        dataset_type = None
        if results is True:
            dataset_type = Opts.DatasetType.DATASET_TYPE_IS_RESULT
        elif results is False:
            dataset_type = Opts.DatasetType.DATASET_TYPE_IS_NOT_RESULT

        committed_status = None
        if committed is True:
            committed_status = Opts.CommittedStatus.STATUS_COMMITTED
        elif committed is False:
            committed_status = Opts.CommittedStatus.STATUS_UNCOMMITTED

        count = 0
        for response in self.rpc_paged_request(
            RpcMethod[pb2.ListDatasetsResponse](self.service.ListDatasets),
            pb2.ListDatasetsRequest(
                options=Opts(
                    sort_clause=Opts.SortClause(
                        sort_order=None if sort_order is None else sort_order.as_pb2(),
                        created={} if sort_field == "created" else None,
                        name={} if sort_field == "name" else None,
                    ),
                    created_before=created_before,  # type: ignore[arg-type]
                    created_after=created_after,  # type: ignore[arg-type]
                    dataset_type=dataset_type,
                    committed_status=committed_status,
                    name_or_description_substring=name_or_description,
                    organization_id=self.resolve_org_id(org),
                    author_id=None if author is None else self.resolve_user_id(author),
                    workspace_id=None if workspace is None else self.resolve_workspace_id(workspace),
                    page_size=self.MAX_PAGE_SIZE if limit is None else min(self.MAX_PAGE_SIZE, limit),
                )
            ),
        ):
            for dataset in response.datasets:
                count += 1
                yield dataset
                if limit is not None and count >= limit:
                    return
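
    # Usage sketch (illustrative only). Listing up to 10 committed datasets sorted
    # by creation time:
    #
    #     with Beaker.from_env() as beaker:
    #         for dataset in beaker.dataset.list(
    #             committed=True,
    #             sort_field="created",
    #             limit=10,
    #         ):
    #             print(dataset.id)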

    def url(self, dataset: pb2.Dataset) -> str:
        """
        Get the URL to the dataset on the Beaker dashboard.
        """
        dataset_id = self.resolve_dataset_id(dataset)
        return f"{self.config.agent_address}/ds/{self._url_quote(dataset_id)}"
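
    # Usage sketch (illustrative only; the dataset name is hypothetical):
    #
    #     with Beaker.from_env() as beaker:
    #         print(beaker.dataset.url(beaker.dataset.get("my-dataset")))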