Source code for openprotein.api.poet

from typing import Iterator, Optional, List, Literal, Dict
from openprotein.pydantic import BaseModel, validator
from io import BytesIO
import random
import requests

from openprotein.base import APISession
from openprotein.api.jobs import AsyncJobFuture, StreamingAsyncJobFuture
import numpy as np
from openprotein.jobs import ResultsParser, Job, register_job_type, JobType
import openprotein.config as config

from openprotein.errors import (
    InvalidParameterError,
    MissingParameterError,
    APIError,
)
from openprotein.api.align import csv_stream, AlignFutureMixin
from openprotein.futures import FutureBase, FutureFactory


class PoetSSPResult(BaseModel):
    """
    One row of a PoET single-site analysis result.

    Two pre-validators normalise the incoming data:

    * any ``sequence`` value whose string form contains ``X0X`` (the code
      under which the input variant itself is reported) becomes ``b"WT"``;
    * a missing ``name`` is replaced by ``Mutant<k>``, where ``k`` comes
      from the counter ``_n`` stored on the class.
    """

    sequence: bytes
    score: List[float]
    name: Optional[str] = None
    # Counter bumped by ``incrementname`` each time an unnamed result is seen.
    _n: int = 0

    @validator("sequence", pre=True, always=True)
    def replacename(cls, value):
        """Return ``b"WT"`` when ``X0X`` occurs in ``str(value)``, else ``value``."""
        return b"WT" if "X0X" in str(value) else value

    @validator("name", pre=True, always=True)
    def incrementname(cls, value):
        """Keep an explicit name; otherwise generate the next ``Mutant<k>`` name."""
        if value is not None:
            return value
        # No name supplied: advance the class-level counter and use it.
        cls._n += 1
        return f"Mutant{cls._n}"


class PoetScoreResult(BaseModel):
    """A single PoET result row: a sequence, its score values and an optional name."""

    # Sequence the scores belong to, held as raw bytes.
    sequence: bytes
    # One or more score values for the sequence.
    score: List[float]
    # Optional label; stays None when not provided.
    name: Optional[str] = None


@register_job_type(JobType.poet_score)
class PoetScoreJob(Job):
    """
    Job model registered for ``JobType.poet_score``.

    Extends ``Job`` with optional paging fields and an optional list of
    ``PoetScoreResult`` rows. Every field added here defaults to ``None``.
    """

    parent_id: Optional[str] = None
    s3prefix: Optional[str] = None
    # Paging metadata -- presumably mirrors the page_size/page_offset query
    # parameters of the score GET endpoint (TODO confirm against the API).
    page_size: Optional[int] = None
    page_offset: Optional[int] = None
    num_rows: Optional[int] = None
    # Result rows carried by this payload, if any.
    result: Optional[List[PoetScoreResult]] = None
    n_completed: Optional[int] = None

    # Pins this model to the poet_score job type.
    job_type: Literal[JobType.poet_score] = JobType.poet_score


@register_job_type(JobType.poet_single_site)
class PoetSSPJob(PoetScoreJob):
    """
    Job model registered for ``JobType.poet_single_site``.

    Subclasses ``PoetScoreJob`` and redeclares its fields; the differences are
    that ``result`` holds ``PoetSSPResult`` rows and ``job_type`` is pinned to
    ``JobType.poet_single_site``.
    """

    parent_id: Optional[str] = None
    s3prefix: Optional[str] = None
    # Paging metadata -- presumably mirrors the page_size/page_offset query
    # parameters of the single-site GET endpoint (TODO confirm against the API).
    page_size: Optional[int] = None
    page_offset: Optional[int] = None
    num_rows: Optional[int] = None
    # Rows are parsed as PoetSSPResult, so its validators run on each row.
    result: Optional[List[PoetSSPResult]] = None
    n_completed: Optional[int] = None

    # Pins this model to the poet_single_site job type.
    job_type: Literal[JobType.poet_single_site] = JobType.poet_single_site


@register_job_type(JobType.poet_generate)
class PoetGenerateJob(Job):
    """
    Job model registered for ``JobType.poet_generate``.

    Extends ``Job`` with optional paging fields and an optional list of
    ``PoetScoreResult`` rows. Every field added here defaults to ``None``.
    """

    parent_id: Optional[str] = None
    s3prefix: Optional[str] = None
    # Paging metadata -- meaning not visible in this module; TODO confirm
    # against the API response schema.
    page_size: Optional[int] = None
    page_offset: Optional[int] = None
    num_rows: Optional[int] = None
    # Generated rows reuse the PoetScoreResult shape (sequence, score, name).
    result: Optional[List[PoetScoreResult]] = None
    n_completed: Optional[int] = None

    # Pins this model to the poet_generate job type.
    job_type: Literal[JobType.poet_generate] = JobType.poet_generate


def poet_score_post(
    session: APISession, prompt_id: str, queries: List[bytes]
) -> FutureFactory:
    """
    Submits a job to score sequences based on the given prompt.

    The queries are joined with newlines and uploaded as ``variant_file``.

    Parameters
    ----------
    session : APISession
        An instance of APISession to manage interactions with the API.
    prompt_id : str
        The ID of the prompt.
    queries : List[bytes] or List[str]
        A list of query sequences to be scored. ``str`` entries are encoded
        to bytes individually, so a list may mix ``str`` and ``bytes``.

    Raises
    ------
    MissingParameterError
        If ``queries`` is empty or ``prompt_id`` is falsy.
    APIError
        If there is an issue with the API request.

    Returns
    -------
    The future produced by ``FutureFactory.create_future`` for the
    submitted scoring job.
    """
    endpoint = "v1/poet/score"

    # len() rather than truthiness so array-like inputs are handled too.
    if len(queries) == 0:
        raise MissingParameterError("Must include queries for scoring!")
    if not prompt_id:
        raise MissingParameterError("Must include prompt_id in request!")

    # Encode per element: checking only the first entry breaks on lists that
    # mix str and bytes (AttributeError on bytes.encode, or a failing join).
    queries = [
        query.encode() if isinstance(query, str) else query for query in queries
    ]
    try:
        variant_file = BytesIO(b"\n".join(queries))
        params = {"prompt_id": prompt_id}
        response = session.post(
            endpoint, files={"variant_file": variant_file}, params=params
        )
        return FutureFactory.create_future(session=session, response=response)
    except Exception as exc:
        raise APIError(f"Failed to post poet score: {exc}") from exc


def poet_score_get(
    session: APISession, job_id, page_size=config.POET_PAGE_SIZE, page_offset=0
):
    """
    Retrieve one page of results for a PoET score job.

    Parameters
    ----------
    session : APISession
        Session used to issue the GET request.
    job_id : str
        ID of the PoET scoring job whose results are requested.
    page_size : int, optional
        Number of results per page. Defaults to config.POET_PAGE_SIZE.
    page_offset : int, optional
        Number of results to skip before the page starts. Defaults to 0.

    Raises
    ------
    APIError
        If ``page_size`` exceeds config.POET_MAX_PAGE_SIZE.

    Returns
    -------
    The object produced by ``ResultsParser.parse_obj`` for the response.
    """
    max_page_size = config.POET_MAX_PAGE_SIZE
    if page_size > max_page_size:
        raise APIError(
            f"Page size must be less than the max for PoET: {max_page_size}"
        )

    query = {
        "job_id": job_id,
        "page_size": page_size,
        "page_offset": page_offset,
    }
    raw = session.get("v1/poet/score", params=query)

    # Parsing is left to ResultsParser; the caller assembles the pages.
    return ResultsParser.parse_obj(raw)


def poet_single_site_post(
    session: APISession, variant, parent_id=None, prompt_id=None
) -> FutureFactory:
    """
    Submit a PoET single-site analysis request for a variant.

    Per the API description, every position of the variant is mutated to
    every amino acid and scored. When ``parent_id`` is given, the prompt
    properties of that parent job are inherited.

    Parameters
    ----------
    session : APISession
        Session used to issue the POST request.
    variant : str or bytes
        The variant to analyze; ``str`` input is encoded to bytes.
    parent_id : str, optional
        ID of the parent job. Exactly one of ``parent_id`` and ``prompt_id``
        must be provided. Defaults to None.
    prompt_id : str, optional
        ID of the prompt. Exactly one of ``parent_id`` and ``prompt_id``
        must be provided. Defaults to None.

    Raises
    ------
    InvalidParameterError
        If both or neither of ``parent_id`` and ``prompt_id`` are given.
    APIError
        If posting the request or creating the future fails.

    Returns
    -------
    The future produced by ``FutureFactory.create_future`` for the job.
    Note that the input variant score is given as `X0X`.
    """
    # Both missing or both present are equally invalid.
    if (parent_id is None) == (prompt_id is None):
        raise InvalidParameterError("Either parent_id or prompt_id must be set.")

    payload = {"variant": variant.encode() if isinstance(variant, str) else variant}
    # The guard above guarantees exactly one of the two ids is set.
    if prompt_id is not None:
        payload["prompt_id"] = prompt_id
    else:
        payload["parent_id"] = parent_id

    try:
        return FutureFactory.create_future(
            session=session,
            response=session.post("v1/poet/single_site", params=payload),
        )
    except Exception as exc:
        raise APIError(f"Failed to post poet single-site analysis: {exc}") from exc


def poet_single_site_get(
    session: APISession, job_id: str, page_size: int = 100, page_offset: int = 0
) -> FutureFactory:
    """
    Retrieve one page of results for a PoET single-site analysis job.

    Parameters
    ----------
    session : APISession
        Session used to issue the GET request.
    job_id : str
        ID of the single-site analysis job whose results are requested.
    page_size : int, optional
        Number of results per page. Defaults to 100.
    page_offset : int, optional
        Number of results to skip before the page starts. Defaults to 0.

    Raises
    ------
    APIError
        If the GET request itself fails.

    Returns
    -------
    The object produced by ``ResultsParser.parse_obj`` for the response.
    """
    query = {
        "job_id": job_id,
        "page_size": page_size,
        "page_offset": page_offset,
    }

    try:
        raw = session.get("v1/poet/single_site", params=query)
    except Exception as exc:
        raise APIError(
            f"Failed to get poet single-site analysis results: {exc}"
        ) from exc
    else:
        # Parsing errors are deliberately not wrapped; only the request is.
        return ResultsParser.parse_obj(raw)


def poet_generate_post(
    session: APISession,
    prompt_id: str,
    num_samples=100,
    temperature=1.0,
    topk=None,
    topp=None,
    max_length=1000,
    random_seed=None,
) -> FutureFactory:
    """
    Generate protein sequences with a prompt.

    Parameters
    ----------
    session : APISession
        An instance of APISession for API interactions.
    prompt_id : str
        The ID of the prompt to generate samples from.
    num_samples : int, optional
        The number of samples to generate. Defaults to 100.
    temperature : float, optional
        The temperature for sampling; must lie in [0.1, 2]. Higher values
        produce more random outputs. Defaults to 1.0.
    topk : int, optional
        The number of top-k residues to consider during sampling; must lie
        in [2, 20] when given. Defaults to None.
    topp : float, optional
        The cumulative probability threshold for top-p sampling; must lie
        in [0, 1] when given. Defaults to None.
    max_length : int, optional
        The maximum length of generated proteins. Defaults to 1000.
    random_seed : int, optional
        Seed for random number generation; must lie in [0, 2**32] when
        given. Defaults to a random number drawn with ``random.randrange``.

    Raises
    ------
    InvalidParameterError
        If ``temperature``, ``topk``, ``topp`` or ``random_seed`` is out of range.
    APIError
        If there is an issue with the API request.

    Returns
    -------
    The future produced by ``FutureFactory.create_future`` for the
    submitted generation job.
    """
    endpoint = "v1/poet/generate"

    if not (0.1 <= temperature <= 2):
        raise InvalidParameterError("The 'temperature' must be between 0.1 and 2.")
    # Test against None, not truthiness: an out-of-range falsy value such as
    # topk=0 would otherwise skip validation yet still be sent below.
    if topk is not None and not (2 <= topk <= 20):
        raise InvalidParameterError("The 'topk' must be between 2 and 20.")
    if topp is not None and not (0 <= topp <= 1):
        raise InvalidParameterError("The 'topp' must be between 0 and 1.")

    if random_seed is None:
        random_seed = random.randrange(2**32)
    elif not (0 <= random_seed <= 2**32):
        raise InvalidParameterError(
            "The 'random_seed' must be between 0 and 2**32."
        )

    params = {
        "prompt_id": prompt_id,
        "generate_n": num_samples,
        "temperature": temperature,
        "maxlen": max_length,
        "seed": random_seed,
    }
    # Optional sampling controls are only sent when explicitly provided.
    if topk is not None:
        params["topk"] = topk
    if topp is not None:
        params["topp"] = topp

    try:
        response = session.post(endpoint, params=params)
        return FutureFactory.create_future(session=session, response=response)
    except Exception as exc:
        raise APIError(f"Failed to post PoET generation request: {exc}") from exc


def poet_generate_get(session: APISession, job_id) -> requests.Response:
    """
    Request the results of a PoET generation job as a streamed response.

    Parameters
    ----------
    session : APISession
        Session used to issue the GET request.
    job_id : str
        Job ID from a poet/generate job.

    Raises
    ------
    APIError
        If the GET request fails.

    Returns
    -------
    requests.Response
        The response obtained with ``stream=True``; the body is not read here.
    """
    try:
        return session.get(
            "v1/poet/generate",
            params={"job_id": job_id},
            stream=True,
        )
    except Exception as exc:
        raise APIError(f"Failed to get poet generation results: {exc}") from exc


class PoetFuture(AlignFutureMixin, AsyncJobFuture):
    """Common base for PoET futures: shared result formatting plus ``get``."""

    def _fmt_results(self, results):
        """Convert result rows into ``(name, sequence, score_array)`` tuples."""
        formatted = []
        for row in results:
            formatted.append((row.name, row.sequence, np.asarray(row.score)))
        return formatted

    def get(self, verbose=False) -> List:
        """Delegate to the parent class ``get`` with the same ``verbose`` flag."""
        return super().get(verbose=verbose)


class PoetScoreFuture(PoetFuture, FutureBase):
    """
    Future for the results of a PoET scoring job.

    Attributes
    ----------
    session : APISession
        An instance of APISession for API interactions.
    job : Job
        The PoET scoring job.
    page_size : int
        The number of results to fetch in a single page.

    Methods
    -------
    get(verbose=False)
        Get the final results of the PoET job.
    """

    job_type = ["/poet", "/poet/score"]

    def __init__(
        self, session: APISession, job: Job, page_size=config.POET_PAGE_SIZE, **kwargs
    ):
        """
        Create a PoetScoreFuture.

        Parameters
        ----------
        session : APISession
            An instance of APISession for API interactions.
        job : Job
            The PoET scoring job.
        page_size : int, optional
            The number of results to fetch in a single page.
            Defaults to config.POET_PAGE_SIZE.
        **kwargs
            Accepted but not used by this constructor.
        """
        super().__init__(session, job)
        self.page_size = page_size

    def get(self, verbose=False) -> List[tuple]:
        """
        Page through the scoring results and return them formatted.

        Pages are requested until one comes back with fewer rows than
        ``page_size``. If a page request raises APIError, paging stops and
        whatever was collected so far is returned.

        Parameters
        ----------
        verbose : bool, optional
            If True, print a message when a page request fails.
            Defaults to False.

        Returns
        -------
        List[tuple]
            ``(name, sequence, score_array)`` tuples as built by
            ``_fmt_results``.
        """
        job_id = self.job.job_id
        page_size = self.page_size
        collected = []
        offset = 0
        while True:
            try:
                page = poet_score_get(
                    self.session,
                    job_id,
                    page_offset=offset,
                    page_size=page_size,
                )
                collected += page.result
                fetched = len(page.result)
                offset += fetched
            except APIError as exc:
                if verbose:
                    print(f"Failed to get results: {exc}")
                # Best effort: hand back the pages gathered before the failure.
                break
            if not fetched >= page_size:
                # A short page means there is nothing further to fetch.
                break
        return self._fmt_results(collected)
class PoetSingleSiteFuture(PoetFuture, FutureBase):
    """
    Future for the results of a PoET single-site analysis job.

    Attributes
    ----------
    session : APISession
        An instance of APISession for API interactions.
    job : Job
        The PoET single-site analysis job.
    page_size : int
        The number of results to fetch in a single page.

    Methods
    -------
    get(verbose=False)
        Get the final results of the PoET job.
    """

    job_type = "/poet/single_site"

    def _fmt_results(self, results):
        """Map each row's ``sequence`` to its scores as a numpy array."""
        formatted = {}
        for row in results:
            formatted[row.sequence] = np.asarray(row.score)
        return formatted

    def __init__(
        self, session: APISession, job: Job, page_size=config.POET_PAGE_SIZE, **kwargs
    ):
        """
        Create a PoetSingleSiteFuture.

        Parameters
        ----------
        session : APISession
            An instance of APISession for API interactions.
        job : Job
            The PoET single-site analysis job.
        page_size : int, optional
            The number of results to fetch in a single page.
            Defaults to config.POET_PAGE_SIZE.
        **kwargs
            Accepted but not used by this constructor.
        """
        super().__init__(session, job)
        self.page_size = page_size

    def get(self, verbose=False) -> Dict:
        """
        Page through the single-site results and return them formatted.

        Pages are requested until one comes back with fewer rows than
        ``page_size``. If a page request raises APIError, paging stops and
        whatever was collected so far is returned.

        Parameters
        ----------
        verbose : bool, optional
            If True, print a message when a page request fails.
            Defaults to False.

        Returns
        -------
        Dict
            A dictionary mapping each row's ``sequence`` value (the mutation
            code) to a numpy array of its scores.
        """
        job_id = self.job.job_id
        page_size = self.page_size
        collected = []
        offset = 0
        while True:
            try:
                page = poet_single_site_get(
                    self.session,
                    job_id,
                    page_offset=offset,
                    page_size=page_size,
                )
                collected += page.result
                fetched = len(page.result)
                offset += fetched
            except APIError as exc:
                if verbose:
                    print(f"Failed to get results: {exc}")
                # Best effort: hand back the pages gathered before the failure.
                break
            if not fetched >= page_size:
                # A short page means there is nothing further to fetch.
                break
        return self._fmt_results(collected)
class PoetGenerateFuture(PoetFuture, StreamingAsyncJobFuture, FutureBase):
    """
    Future for the results of a PoET generation job.

    Attributes
    ----------
    session : APISession
        An instance of APISession for API interactions.
    job : Job
        The PoET generation job.

    Methods
    -------
    stream()
        Stream the results of the PoET generation job.
    get(verbose=False)
        Get the results via the parent class implementation.
    """

    job_type = "/poet/generate"

    def stream(self) -> Iterator[PoetScoreResult]:
        """
        Stream the generated samples from the response.

        Each CSV row is read as ``name, sequence, score...``. Rows that
        cannot be parsed are reported with ``print`` and skipped. An
        APIError raised while fetching or streaming is printed and ends
        the stream.

        Yields
        ------
        tuple
            ``(name, sequence, score_array)`` as built by ``_fmt_results``
            from a ``PoetScoreResult``.
        """
        try:
            raw = poet_generate_get(self.session, self.job.job_id)
            for tokens in csv_stream(raw):
                try:
                    # First two columns are name and sequence; the rest are scores.
                    name, sequence = tokens[:2]
                    score = [float(field) for field in tokens[2:]]
                    row = PoetScoreResult(
                        sequence=sequence.encode(), score=score, name=name
                    )
                    yield self._fmt_results([row])[0]
                except (IndexError, ValueError) as exc:
                    # Skip malformed or incomplete tokens
                    print(
                        f"Skipping malformed or incomplete tokens: {tokens} with {exc}"
                    )
        except APIError as exc:
            print(f"Failed to stream PoET generation results: {exc}")

    def get(self, verbose=False) -> List:
        """Delegate to the parent class ``get`` with the same ``verbose`` flag."""
        return super().get(verbose=verbose)