Source code for openprotein.embeddings.models

"""Embeddings model representations which can be used directly for creating embeddings."""

from typing import TYPE_CHECKING

from openprotein.base import APISession
from openprotein.common import FeatureType, ModelMetadata, ReductionType
from openprotein.data import AssayDataset, AssayMetadata, DataAPI
from openprotein.errors import InvalidParameterError

from . import api
from .future import EmbeddingsResultFuture

if TYPE_CHECKING:
    from openprotein.predictor import PredictorModel
    from openprotein.svd import SVDModel
    from openprotein.umap import UMAPModel


[docs] class EmbeddingModel: """Base embeddings model used to understand and provide embeddings from sequences.""" # overridden by subclasses # used to get correct emb model during factory create model_id: list[str] | str = "protembed" def __init__( self, session: APISession, model_id: str, metadata: ModelMetadata | None = None, ): self.session = session self.id = model_id self._metadata = metadata self.__doc__ = self.__fmt_doc() def __fmt_doc(self): summary = str(self.metadata.description.summary) return f"""\t{summary} \t max_sequence_length = {self.metadata.max_sequence_length} \t supported outputs = {self.metadata.output_types} \t supported tokens = {self.metadata.input_tokens} """ def __str__(self) -> str: return self.id def __repr__(self) -> str: return self.id @classmethod def get_model(cls): """ Get the model_id(s) for this EmbeddingModel subclass. Returns ------- list of str List of model_id strings associated with this class. """ if isinstance(cls.model_id, str): return [cls.model_id] return cls.model_id @classmethod def create( cls, session: APISession, model_id: str, default: type["EmbeddingModel"] | None = None, **kwargs, ): """ Create and return an instance of the appropriate EmbeddingModel subclass based on the model_id. Parameters ---------- session : APISession The API session to use. model_id : str The model identifier. default : type variable of EmbeddingModel or None, optional Default EmbeddingModel subclass to use if no match is found. kwargs : Additional keyword arguments to pass to the model constructor. Returns ------- EmbeddingModel An instance of the appropriate EmbeddingModel subclass. Raises ------ ValueError If no suitable EmbeddingModel subclass is found and no default is provided. """ # Dynamically discover all subclasses of EmbeddingModel model_classes = EmbeddingModel.__subclasses__() # Find the EmbeddingModel class that matches the model_id for model_class in model_classes: if model_id in model_class.get_model(): return model_class(session=session, model_id=model_id, **kwargs) # default to ProtembedModel if default is not None: try: return default(session=session, model_id=model_id, **kwargs) except: # continue to throw error as unsupported pass raise ValueError(f"Unsupported model_id type: {model_id}") @property def metadata(self): """ ModelMetadata for this model. Returns ------- ModelMetadata The metadata associated with this model. """ if self._metadata is None: self._metadata = self.get_metadata() return self._metadata def get_metadata(self) -> ModelMetadata: """ Get model metadata for this model. Returns ------- ModelMetadata The metadata associated with this model. """ return api.get_model(self.session, self.id) def embed( self, sequences: list[bytes] | list[str], reduction: ReductionType | None = ReductionType.MEAN, **kwargs, ) -> EmbeddingsResultFuture: """ Embed sequences using this model. Parameters ---------- sequences : list of bytes or list of str Sequences to embed. reduction : ReductionType or None, optional Reduction to use (e.g. mean). Defaults to mean embedding. kwargs: Additional keyword arguments to be used from foundational models, e.g. prompt_id for PoET models. Returns ------- EmbeddingsResultFuture Future object representing the embedding result. """ return EmbeddingsResultFuture.create( session=self.session, job=api.request_post( session=self.session, model_id=self.id, sequences=sequences, reduction=reduction, **kwargs, ), sequences=sequences, ) def logits( self, sequences: list[bytes] | list[str], **kwargs ) -> EmbeddingsResultFuture: """ Compute logit embeddings for sequences using this model. Parameters ---------- sequences : list of bytes or list of str Sequences to compute logits for. kwargs : Additional keyword arguments to be used from foundational models, e.g. prompt_id for PoET models. Returns ------- EmbeddingsResultFuture Future object representing the logits result. """ return EmbeddingsResultFuture.create( session=self.session, job=api.request_logits_post( session=self.session, model_id=self.id, sequences=sequences, **kwargs ), sequences=sequences, ) def fit_svd( self, sequences: list[bytes] | list[str] | None = None, assay: AssayDataset | None = None, n_components: int = 1024, reduction: ReductionType | None = None, **kwargs, ) -> "SVDModel": """ Fit an SVD on the embedding results of this model. This function will create an SVDModel based on the embeddings from this model as well as the hyperparameters specified in the arguments. Parameters ---------- sequences : list of bytes or list of str or None, optional Sequences to fit SVD on. assay : AssayDataset or None, optional Assay containing sequences to fit SVD on. n_components : int, optional Number of components in SVD. Determines output shapes. Default is 1024. reduction : ReductionType or None, optional Embeddings reduction to use (e.g. mean). kwargs : Additional keyword arguments to be used from foundational models, e.g. prompt_id for PoET models. Returns ------- SVDModel The fitted SVD model. Raises ------ InvalidParameterError If neither or both of `assay` and `sequences` are provided. """ # local import for cyclic dep from openprotein.svd import SVDAPI svd_api = getattr(self.session, "svd", None) assert isinstance(svd_api, SVDAPI) # Ensure either or if (assay is None and sequences is None) or ( assay is not None and sequences is not None ): raise InvalidParameterError( "Expected either assay or sequences to fit SVD on!" ) model_id = self.id return svd_api.fit_svd( model_id=model_id, sequences=sequences, assay=assay, n_components=n_components, reduction=reduction, **kwargs, ) def fit_umap( self, sequences: list[bytes] | list[str] | None = None, assay: AssayDataset | None = None, n_components: int = 2, reduction: ReductionType | None = ReductionType.MEAN, **kwargs, ) -> "UMAPModel": """ Fit a UMAP on the embedding results of this model. This function will create a UMAPModel based on the embeddings from this model as well as the hyperparameters specified in the arguments. Parameters ---------- sequences : list of bytes or list of str or None, optional Optional sequences to fit UMAP with. Either use sequences or assay. Sequences is preferred. assay : AssayDataset or None, optional Optional assay containing sequences to fit UMAP with. Either use sequences or assay. Ignored if sequences are provided. n_components : int, optional Number of components in UMAP fit. Determines output shapes. Default is 2. reduction : ReductionType or None, optional Embeddings reduction to use (e.g. mean). Defaults to MEAN. kwargs : Additional keyword arguments to be used from foundational models, e.g. prompt_id for PoET models. Returns ------- UMAPModel The fitted UMAP model. Raises ------ InvalidParameterError If neither or both of `assay` and `sequences` are provided. """ # local import for cyclic dep from openprotein.umap import UMAPAPI umap_api = getattr(self.session, "umap", None) assert isinstance(umap_api, UMAPAPI) # Ensure either or if (assay is None and sequences is None) or ( assay is not None and sequences is not None ): raise InvalidParameterError( "Expected either assay or sequences to fit UMAP on!" ) model_id = self.id return umap_api.fit_umap( model_id=model_id, feature_type=FeatureType.PLM, sequences=sequences, assay_id=assay.id if assay is not None else None, n_components=n_components, reduction=reduction, **kwargs, ) def fit_gp( self, assay: AssayMetadata | AssayDataset | str, properties: list[str], reduction: ReductionType, name: str | None = None, description: str | None = None, **kwargs, ) -> "PredictorModel": """ Fit a Gaussian Process (GP) on an assay using this embedding model and hyperparameters. Parameters ---------- assay : AssayMetadata, AssayDataset, or str Assay to fit GP on. properties : list of str Properties in the assay to fit the GP on. reduction : ReductionType Type of embedding reduction to use for computing features. PLM must use reduction. name : str or None, optional Optional name for the predictor model. description : str or None, optional Optional description for the predictor model. kwargs : Additional keyword arguments to be used from foundational models, e.g. prompt_id for PoET models. Returns ------- PredictorModel The fitted predictor model. Raises ------ InvalidParameterError If no properties are provided, properties are not a subset of assay measurements, or multitask GP is requested. """ # local import to resolve cyclic from openprotein.predictor import PredictorAPI data_api = getattr(self.session, "data", None) assert isinstance(data_api, DataAPI) predictor_api = getattr(self.session, "predictor", None) assert isinstance(predictor_api, PredictorAPI) # get assay if str assay = data_api.get(assay_id=assay) if isinstance(assay, str) else assay # extract assay_id if len(properties) == 0: raise InvalidParameterError("Expected (at-least) 1 property to train") if not set(properties) <= set(assay.measurement_names): raise InvalidParameterError( f"Expected all provided properties to be a subset of assay's measurements: {assay.measurement_names}" ) # TODO - support multitask if len(properties) > 1: raise InvalidParameterError( "Training a multitask GP is not yet supported (i.e. number of properties should only be 1 for now)" ) # inject into predictor api return predictor_api.fit_gp( assay=assay, properties=properties, feature_type=FeatureType.PLM, model=self, reduction=reduction, name=name, description=description, **kwargs, )
class AttnModel(EmbeddingModel): """Embeddings model that provides attention computation.""" def attn( self, sequences: list[bytes] | list[str], **kwargs ) -> EmbeddingsResultFuture: """ Compute attention embeddings for sequences using this model. Parameters ---------- sequences : list of bytes or list of str Sequences to compute attention embeddings for. kwargs : Additional keyword arguments to be used from foundational models. Returns ------- EmbeddingsResultFuture Future object representing the attention result. """ return EmbeddingsResultFuture.create( session=self.session, job=api.request_attn_post( session=self.session, model_id=self.id, sequences=sequences, **kwargs ), sequences=sequences, )