Source code for openprotein.prompt.prompt

"""Prompt API providing the interface to create prompts for use with PoET models."""

from typing import List, Sequence

from openprotein.base import APISession
from openprotein.molecules import Protein, Complex
from openprotein.utils import uuid

from . import api
from .models import Prompt, Query
from .schemas import Context


class PromptAPI:
    """Prompt API providing the interface to create prompts for use with PoET models."""

    def __init__(self, session: APISession):
        # Passed to every ``api.*`` call below and to every Prompt/Query
        # object this API returns.
        self.session = session

    def create_prompt(
        self,
        context: Context | Sequence[Context],
        name: str | None = None,
        description: str | None = None,
    ) -> Prompt:
        """
        Create a prompt.

        Parameters
        ----------
        context : Context | Sequence[Context]
            Context or list of contexts. Each context is a sequence of entries
            where each entry is a raw sequence (``bytes``/``str``, optionally
            with ``:`` chain breaks for multichain), :py:class:`Protein`, or
            :py:class:`Complex`. Currently only protein chains are accepted;
            passing a Complex with DNA, RNA, or Ligand chains raises
            :py:class:`InvalidParameterError`. This restriction may be relaxed
            in the future.
        name : str | None, optional
            Name of the prompt. Defaults to ``None``.
        description : str | None, optional
            Description of the prompt. Defaults to ``None``.

        Returns
        -------
        Prompt
            The created prompt.
        """
        return Prompt(
            session=self.session,
            metadata=api.create_prompt(
                session=self.session,
                context=context,
                name=name,
                description=description,
            ),
        )

    def get_prompt(self, prompt_id: str) -> Prompt:
        """
        Get the prompt for a given prompt ID.

        Parameters
        ----------
        prompt_id : str
            The prompt ID.

        Returns
        -------
        Prompt
            The prompt.
        """
        return Prompt(
            session=self.session,
            metadata=api.get_prompt_metadata(session=self.session, prompt_id=prompt_id),
        )

    def list_prompts(self) -> List[Prompt]:
        """
        List all prompts.

        Returns
        -------
        List[Prompt]
            List of prompts, one per metadata record returned by
            ``api.list_prompts``.
        """
        return [
            Prompt(session=self.session, metadata=p)
            for p in api.list_prompts(session=self.session)
        ]

    def create_query(
        self,
        query: str | bytes | Protein | Complex,
        force_structure: bool = False,
    ) -> Query:
        """
        Create a query.

        Parameters
        ----------
        query : bytes or str or Protein or Complex
            A query protein or complex. Raw ``bytes``/``str`` inputs may include
            ``:`` chain breaks to denote a multichain protein. Currently only
            protein chains are accepted; passing a Complex with DNA, RNA, or
            Ligand chains raises :py:class:`InvalidParameterError`. This
            restriction may be relaxed in the future.
        force_structure : bool, optional
            Optionally force a query to be interpreted with a structure. Useful
            for creating structure prediction queries which can have no
            structure. Defaults to ``False``.

        Returns
        -------
        Query
            The created query.
        """
        return Query(
            session=self.session,
            metadata=api.create_query(
                session=self.session,
                query=query,
                force_structure=force_structure,
            ),
        )

    def get_query(self, query_id: str) -> Query:
        """
        Get the query for a given query ID.

        Parameters
        ----------
        query_id : str
            The query ID.

        Returns
        -------
        Query
            The query.
        """
        return Query(
            session=self.session,
            metadata=api.get_query_metadata(session=self.session, query_id=query_id),
        )

    def _resolve_query(
        self,
        query: (
            str
            | bytes
            | Protein
            | Complex
            | Query
            | list[str | bytes | Protein | Complex | Query]
            | None
        ) = None,
        force_structure: bool = False,
    ) -> str | list[str] | None:
        """
        Resolve a query-like value into a query ID, creating queries as needed.

        Parameters
        ----------
        query : str | bytes | Protein | Complex | Query | list | None, optional
            The value to resolve:

            - ``None`` resolves to ``None``.
            - A ``list`` is resolved element-wise (recursively) into a list.
            - A :py:class:`Protein`, :py:class:`Complex`, ``bytes``, or a
              ``str`` that is *not* a valid UUID is submitted through
              :py:meth:`create_query`, and the new query's ID is returned.
            - A ``str`` that *is* a valid UUID is returned unchanged, i.e. it is
              taken to be the ID of an existing query (no existence check is
              made here).
            - Anything else (expected to be a :py:class:`Query`) resolves to
              its ``id`` attribute.
        force_structure : bool, optional
            Forwarded to :py:meth:`create_query` for every query that has to
            be created. Defaults to ``False``.

        Returns
        -------
        str | list[str] | None
            The resolved query ID, a list of IDs for list input, or ``None``.
        """
        if query is None:
            return None
        if isinstance(query, list):
            return [
                self._resolve_query(query=q, force_structure=force_structure)
                for q in query
            ]
        # Molecule objects, raw bytes, and non-UUID strings are sequence
        # content, so a new query must be created to obtain an ID.
        if isinstance(query, (Protein, Complex, bytes)) or (
            isinstance(query, str) and not uuid.is_valid_uuid(query)
        ):
            return self.create_query(query=query, force_structure=force_structure).id
        # Remaining cases: a UUID string (already an ID) or a Query object.
        return query if isinstance(query, str) else query.id