Source code for openprotein.jobs.futures

"""Application futures for waiting for results from jobs."""

import concurrent.futures
import logging
import time
from abc import ABC, abstractmethod
from datetime import datetime
from types import UnionType
from typing import Collection, Generator

import tqdm
from requests import Response
from typing_extensions import Self

from openprotein import config
from openprotein.base import APISession
from openprotein.errors import TimeoutException
from openprotein.jobs.schemas import Job, JobStatus, JobType

from . import api

logger = logging.getLogger(__name__)


class Future(ABC):
    """
    Base class for all Futures returning results from a job.
    """

    # NOTE: This base class should be directly inherited for class discovery
    # by the factory `create` method.

    session: APISession
    job: Job

    def __init__(self, session: APISession, job: Job):
        self.session = session
        self.job = job

    @classmethod
    def create(
        cls: type[Self],
        session: APISession,
        job_id: str | None = None,
        job: Job | None = None,
        response: Response | dict | None = None,
        **kwargs,
    ) -> Self:
        """Create an instance of the appropriate Future class based on the job type.

        Parameters
        ----------
        session : APISession
            Session for API interactions.
        job_id : str | None, optional
            The ID of the Job to initialize this future with.
        job : Job | None, optional
            The Job object to initialize this future with.
        response : Response | dict | None, optional
            The response from a job request returning a job-like object.
        **kwargs
            Additional keyword arguments to pass to the Future class constructor.

        Returns
        -------
        Self
            An instance of the appropriate Future class.

        Raises
        ------
        ValueError
            If `job_id`, `job`, and `response` are all None.
        ValueError
            If an appropriate Future subclass cannot be found for the job type.

        :meta private:
        """
        # parse job
        # default to use job_id first
        if job_id is not None:
            # get job
            job = api.job_get(session=session, job_id=job_id)
        # set obj to parse using job or response
        obj = job or response
        if obj is None:
            raise ValueError("Expected job_id, job or response")
        # parse specific job
        job = Job.create(obj, **kwargs)

        # Dynamically discover all subclasses of Future
        future_classes = Future.__subclasses__()

        # Find the Future class that matches the job
        for future_class in future_classes:
            if (
                type(job) == (future_type := future_class.__annotations__.get("job"))
                or isinstance(future_type, UnionType)
                and type(job) in future_type.__args__
            ):
                if isinstance(future_class.__dict__.get("create"), classmethod):
                    future = future_class.create(session=session, job=job, **kwargs)
                else:
                    future = future_class(session=session, job=job, **kwargs)
                return future  # type: ignore - needed since type checker doesn't know subclass
        raise ValueError(f"Unsupported job type: {job.job_type}")

    def __str__(self) -> str:
        return str(self.job)

    def __repr__(self):
        return repr(self.job)

    @property
    def id(self) -> str:
        """The unique identifier of the job."""
        return self.job.job_id

    job_id = id

    @property
    def job_type(self) -> str:
        """The type of the job."""
        return self.job.job_type

    @property
    def status(self) -> JobStatus:
        """The current status of the job."""
        return self.job.status

    @property
    def created_date(self) -> datetime:
        """The creation timestamp of the job."""
        return self.job.created_date

    @property
    def start_date(self) -> datetime | None:
        """The start timestamp of the job."""
        return self.job.start_date

    @property
    def end_date(self) -> datetime | None:
        """The end timestamp of the job."""
        return self.job.end_date

    @property
    def progress_counter(self) -> int:
        """The progress counter of the job."""
        return self.job.progress_counter or 0

    def done(self) -> bool:
        """Check if the job has completed.

        Returns
        -------
        bool
            True if the job is done, False otherwise.
        """
        return self.status.done()

    def cancelled(self) -> bool:
        """Check if the job has been cancelled.

        Returns
        -------
        bool
            True if the job is cancelled, False otherwise.
        """
        return self.status.cancelled()

    def _update_progress(self, job: Job) -> int:
        """Update progress for jobs that may not have explicit counters.

        Parameters
        ----------
        job : Job
            The job object to update progress from.

        Returns
        -------
        int
            The calculated progress value (0-100).
        """
        progress = job.progress_counter
        if progress is None:  # check None before comparison
            if job.status == JobStatus.PENDING:
                progress = 5
            if job.status == JobStatus.RUNNING:
                progress = 25
        if job.status in [JobStatus.SUCCESS, JobStatus.FAILURE]:
            progress = 100
        return progress or 0  # never None

    def _refresh_job(self) -> Job:
        """Refresh and return the internal job object.

        Returns
        -------
        Job
            The refreshed job object.
        """
        # dump extra kwargs to keep on refresh
        kwargs = {
            k: v for k, v in self.job.model_dump().items() if k not in Job.model_fields
        }
        job = Job.create(
            api.job_get(session=self.session, job_id=self.job_id), **kwargs
        )
        return job

    def refresh(self):
        """Refresh the job status and internal job object."""
        self.job = self._refresh_job()

    @abstractmethod
    def get(self, verbose: bool = False, **kwargs):
        """
        Return the results from this job.

        Parameters
        ----------
        verbose : bool, optional
            Flag to enable verbose output, by default False.
        **kwargs
            Additional keyword arguments.
        """
        raise NotImplementedError()

    def _wait_job(
        self,
        interval: float = config.POLLING_INTERVAL,
        timeout: int | None = None,
        verbose: bool = False,
    ) -> Job:
        """Wait for a job to finish and return the final job object.

        Parameters
        ----------
        interval : float, optional
            Time in seconds to wait between polls. Defaults to `config.POLLING_INTERVAL`.
        timeout : int | None, optional
            Maximum time in seconds to wait before raising an error. Defaults to None (unlimited).
        verbose : bool, optional
            If True, print status updates. Defaults to False.

        Returns
        -------
        Job
            The completed job object.

        Raises
        ------
        TimeoutException
            If the wait time exceeds the specified timeout.
        """
        start_time = time.time()

        def is_done(job: Job):
            if timeout is not None:
                elapsed_time = time.time() - start_time
                if elapsed_time >= timeout:
                    raise TimeoutException(
                        f"Wait time exceeded timeout {timeout}, waited {elapsed_time}"
                    )
            return job.status.done()

        pbar = None
        if verbose:
            pbar = tqdm.tqdm(total=100, desc="Waiting", position=0)

        job = self._refresh_job()
        while not is_done(job):
            if pbar is not None:
                progress = self._update_progress(job)
                pbar.n = progress
                pbar.set_postfix({"status": job.status})
            time.sleep(interval)
            job = self._refresh_job()

        if pbar is not None:
            progress = self._update_progress(job)
            pbar.n = progress
            pbar.set_postfix({"status": job.status})

        return job

    def wait_until_done(
        self,
        interval: float = config.POLLING_INTERVAL,
        timeout: int | None = None,
        verbose: bool = False,
    ):
        """Wait for the job to complete.

        Parameters
        ----------
        interval : float, optional
            Time in seconds between polling. Defaults to `config.POLLING_INTERVAL`.
        timeout : int | None, optional
            Maximum time in seconds to wait. Defaults to None.
        verbose : bool, optional
            Verbosity flag. Defaults to False.

        Returns
        -------
        bool
            True if the job completed successfully.

        Notes
        -----
        This method does not fetch the job results, unlike `wait()`.
        """
        job = self._wait_job(interval=interval, timeout=timeout, verbose=verbose)
        self.job = job
        return self.done()

    def wait(
        self,
        interval: float = config.POLLING_INTERVAL,
        timeout: int | None = None,
        verbose: bool = False,
    ):
        """Wait for the job to complete, then fetch results.

        Parameters
        ----------
        interval : float, optional
            Time in seconds between polling. Defaults to `config.POLLING_INTERVAL`.
        timeout : int | None, optional
            Maximum time in seconds to wait. Defaults to None.
        verbose : bool, optional
            Verbosity flag. Defaults to False.

        Returns
        -------
        Any
            The results of the job.
        """
        time.sleep(1)  # buffer for backend to register the job
        job = self._wait_job(interval=interval, timeout=timeout, verbose=verbose)
        self.job = job
        return self.get()
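
# Illustrative sketch (not part of this module): how a concrete subclass hooks
# into the `Future.create` factory above. The factory matches the `job` type
# annotation declared on each direct subclass, so a concrete future only needs
# to annotate `job` with its Job schema and implement `get`. The names
# `MyJob`, `MyFuture`, and `api.my_job_results` below are hypothetical
# placeholders, not part of the openprotein API.
#
#     class MyFuture(Future):
#         job: MyJob  # concrete Job subclass used for factory dispatch
#
#         def get(self, verbose: bool = False, **kwargs):
#             # block until the backend reports a terminal status, then fetch
#             self.wait_until_done(verbose=verbose)
#             return api.my_job_results(session=self.session, job_id=self.id)
#
# With such a subclass defined, `Future.create(session=session, job=job)`
# returns a `MyFuture` whenever `job` parses to a `MyJob`.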

class StreamingFuture(ABC):
    """Abstract base class for Futures that support streaming results."""

    @abstractmethod
    def stream(self, **kwargs) -> Generator:
        """Return the results from this job as a generator.

        Parameters
        ----------
        **kwargs
            Keyword arguments passed to the streaming implementation.

        Returns
        -------
        Generator
            A generator that yields job results.

        Raises
        ------
        NotImplementedError
            This is an abstract method and must be implemented by a subclass.
        """
        raise NotImplementedError()

    def get(self, verbose: bool = False, **kwargs) -> list:
        """Return all results from the job by consuming the stream.

        Parameters
        ----------
        verbose : bool, optional
            If True, display a progress bar. Defaults to False.
        **kwargs
            Keyword arguments passed to the `stream` method.

        Returns
        -------
        list
            A list containing all results from the job.
        """
        generator = self.stream(**kwargs)
        if verbose:
            total = None
            if hasattr(self, "__len__"):
                total = len(self)  # type: ignore - static type checker doesn't know generator
            generator = tqdm.tqdm(
                generator, desc="Retrieving", total=total, position=0, mininterval=1.0
            )
        return [entry for entry in generator]


class MappedFuture(StreamingFuture, ABC):
    """Base future for jobs with a key-to-result mapping.

    This class provides methods to retrieve results from jobs where each
    result is associated with a unique key (e.g., sequence to embedding).
    """

    def __init__(
        self,
        session: APISession,
        job: Job,
        max_workers: int = config.MAX_CONCURRENT_WORKERS,
    ):
        """Initialize the MappedFuture.

        Parameters
        ----------
        session : APISession
            The session for API interactions.
        job : Job
            The job to retrieve results from.
        max_workers : int, optional
            The number of workers for concurrent result retrieval.
            Defaults to `config.MAX_CONCURRENT_WORKERS`.

        Notes
        -----
        Use `max_workers` > 0 to enable concurrent retrieval.
        """
        self.session = session
        self.job = job
        self.max_workers = max_workers
        self._cache = {}

    @abstractmethod
    def __keys__(self):
        """Return the keys for the mapped results.

        Raises
        ------
        NotImplementedError
            This is an abstract method and must be implemented by a subclass.
        """
        raise NotImplementedError()

    @abstractmethod
    def get_item(self, k):
        """Retrieve a single item by its key.

        Parameters
        ----------
        k
            The key of the item to retrieve.

        Raises
        ------
        NotImplementedError
            This is an abstract method and must be implemented by a subclass.
        """
        raise NotImplementedError()

    def stream_sync(self):
        """Stream the results synchronously.

        Yields
        ------
        tuple
            A tuple of (key, value) for each result.

        :meta private:
        """
        for k in self.__keys__():
            v = self[k]
            yield k, v

    def stream_parallel(self):
        """Stream the results in parallel using a thread pool.

        Yields
        ------
        tuple
            A tuple of (key, value) for each result.

        :meta private:
        """
        num_workers = self.max_workers

        def process(k):
            v = self[k]
            return k, v

        with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = []
            for k in self.__keys__():
                if k in self._cache:
                    yield k, self._cache[k]
                else:
                    f = executor.submit(process, k)
                    futures.append(f)

            for f in futures:
                yield f.result()

    def stream(self):
        """Retrieve results for this job as a stream.

        Returns
        -------
        Generator
            A generator that yields (key, value) tuples.
        """
        if self.max_workers > 0:
            return self.stream_parallel()
        return self.stream_sync()

    def __getitem__(self, k):
        """Get an item by key, using the cache if available.

        Parameters
        ----------
        k
            The key of the item to retrieve.

        Returns
        -------
        Any
            The value associated with the key.
        """
        if k in self._cache:
            return self._cache[k]
        v = self.get_item(k)
        self._cache[k] = v
        return v

    def __len__(self):
        """Return the total number of items."""
        return len(self.__keys__())

    def __iter__(self):
        """Return an iterator over the results."""
        return self.stream()


class PagedFuture(StreamingFuture, ABC):
    """Base future class for jobs which have paged results."""

    DEFAULT_PAGE_SIZE = 1024

    def __init__(
        self,
        session: APISession,
        job: Job,
        page_size: int | None = None,
        num_records: int | None = None,
        max_workers: int = config.MAX_CONCURRENT_WORKERS,
    ):
        """Initialize the PagedFuture.

        Parameters
        ----------
        session : APISession
            The session for API interactions.
        job : Job
            The job to retrieve results from.
        page_size : int | None, optional
            The number of records per page. Defaults to `DEFAULT_PAGE_SIZE`.
        num_records : int | None, optional
            The total number of records expected.
        max_workers : int, optional
            Number of workers for concurrent page retrieval.
            Defaults to `config.MAX_CONCURRENT_WORKERS`.

        Notes
        -----
        Use `max_workers` > 0 to enable concurrent retrieval of multiple pages.
        """
        if page_size is None:
            page_size = self.DEFAULT_PAGE_SIZE
        self.session = session
        self.job = job
        self.page_size = page_size
        self.max_workers = max_workers
        self._num_records = num_records

    @abstractmethod
    def get_slice(self, start: int, end: int, **kwargs) -> Collection:
        """Retrieve a slice of results.

        Parameters
        ----------
        start : int
            The starting index of the slice.
        end : int
            The ending index of the slice.
        **kwargs
            Additional keyword arguments.

        Returns
        -------
        Collection
            A collection of results for the specified slice.

        Raises
        ------
        NotImplementedError
            This is an abstract method and must be implemented by a subclass.
        """
        raise NotImplementedError()

    def stream_sync(self):
        """Stream results by fetching pages synchronously.

        Yields
        ------
        Any
            Individual results from the paged endpoint.

        :meta private:
        """
        step = self.page_size
        num_returned = step
        offset = 0
        while num_returned >= step:
            result_page = self.get_slice(start=offset, end=offset + step)
            for result in result_page:
                yield result
            num_returned = len(result_page)
            offset += num_returned

    def stream_parallel(self):
        """Stream results by fetching pages in parallel.

        Yields
        ------
        Any
            Individual results from the paged endpoint.

        Notes
        -----
        The number of results should be checked, or stored somehow, so that
        we don't need to check the number of returned entries to see if we're
        finished (very awkward when using concurrency).

        :meta private:
        """
        step = self.page_size
        offset = 0
        num_workers = self.max_workers
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
            # submit the paged requests
            futures: dict[concurrent.futures.Future, int] = {}
            index: int = 0
            for _ in range(num_workers * 2):
                f = executor.submit(self.get_slice, offset, offset + step)
                futures[f] = index
                index += 1
                offset += step
            # until we've retrieved all pages (known by retrieving a page with
            # less than the requested number of records)
            done = False
            while not done:
                results: list[list | None] = [None] * len(futures)
                futures_next: dict[concurrent.futures.Future, int] = {}
                index_next: int = 0
                next_result_index = 0
                # iterate the futures and submit new requests as needed
                for f in concurrent.futures.as_completed(futures):
                    index = futures[f]
                    result_page = f.result()
                    results[index] = result_page
                    # check if we're done, meaning the result page is not full
                    done = done or len(result_page) < step
                    # if we aren't done, submit another request
                    if not done:
                        f = executor.submit(self.get_slice, offset, offset + step)
                        futures_next[f] = index_next
                        index_next += 1
                        offset += step
                    # yield the results from this page
                    while (
                        next_result_index < len(results)
                        and results[next_result_index] is not None
                    ):
                        result_page = results[next_result_index]
                        assert result_page is not None  # checked above
                        for result in result_page:
                            yield result
                        next_result_index += 1
                # update the list of futures and wait on them again
                futures = futures_next

    def stream(self):
        """Retrieve results for this job as a stream.

        Returns
        -------
        Generator
            A generator that yields job results.
        """
        if self.max_workers > 0:
            return self.stream_parallel()
        return self.stream_sync()


class InvalidFutureError(Exception):
    """Error for when an unexpected future is created from a job."""

    def __init__(self, future: Future, expected: type[Future]):
        """Initialize the InvalidFutureError.

        Parameters
        ----------
        future : Future
            The future instance that was created.
        expected : type[Future]
            The type of future that was expected.
        """
        self.future = future
        self.expected = expected
        self.message = f"Expected future of type {expected}, got {type(future)}"
        super().__init__(self.message)
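
# Illustrative usage sketch (not part of this module). It assumes an already
# authenticated `APISession` and an existing job id; `"example-job-id"` is a
# placeholder. `Future.create` resolves the concrete Future subclass from the
# job type, `wait_until_done` polls the backend until a terminal status, and
# `get` retrieves the results (page by page for `PagedFuture` subclasses, or
# key by key for `MappedFuture` subclasses).
#
#     from openprotein.base import APISession
#     from openprotein.jobs.futures import Future
#
#     session: APISession = ...  # an authenticated session
#     future = Future.create(session=session, job_id="example-job-id")
#     if future.wait_until_done(timeout=600, verbose=True):
#         results = future.get(verbose=True)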