Source code for openprotein.app.models.predictor.predictor

from openprotein.api import assaydata
from openprotein.api import job as job_api
from openprotein.api import predictor, svd
from openprotein.base import APISession
from openprotein.errors import InvalidParameterError
from openprotein.schemas import (
    EnsembleJob,
    ModelCriterion,
    PredictorMetadata,
    PredictorType,
    TrainJob,
)
from openprotein.schemas.job import JobType

from ..assaydata import AssayDataset
from ..embeddings import EmbeddingModel
from ..futures import Future
from ..svd import SVDModel
from .predict import PredictionResultFuture
from .validate import CVResultFuture


class PredictorModel(Future):
    """
    Class providing predict endpoint for fitted predictor models.

    Also implements a Future that waits for train job.
    """

    job: TrainJob | None

    def __init__(
        self,
        session: APISession,
        job: TrainJob | EnsembleJob | None = None,
        metadata: PredictorMetadata | None = None,
    ):
        """Initializes with either job get or predictor get."""
        self._training_assay = None
        # initialize the metadata
        if metadata is None:
            if job is None or job.job_id is None:
                raise ValueError("Expected predictor metadata or job")
            metadata = predictor.predictor_get(session, job.job_id)
        self._metadata = metadata
        if job is None:
            if metadata.model_spec.type != PredictorType.ENSEMBLE:
                job = TrainJob.create(
                    job_api.job_get(session=session, job_id=metadata.id)
                )
            else:
                job = EnsembleJob(
                    created_date=self._metadata.created_date,
                    status=self._metadata.status,
                    job_type=JobType.predictor_train,
                )
        super().__init__(session, job)

    def __str__(self) -> str:
        return str(self.metadata)

    def __repr__(self) -> str:
        return repr(self.metadata)

    def __or__(self, model: "PredictorModel") -> "PredictorModelGroup":
        if self.sequence_length is not None:
            if model.sequence_length != self.sequence_length:
                raise ValueError(
                    "Expected sequence lengths to either match or be None."
                )
        return PredictorModelGroup(
            session=self.session,
            models=[self, model],
            sequence_length=self.sequence_length or model.sequence_length,
            check_sequence_length=False,
        )

    def __lt__(self, target: float) -> ModelCriterion:
        if len(self.training_properties) == 1:
            return ModelCriterion(
                model_id=self.id,
                measurement_name=self.training_properties[0],
                criterion=ModelCriterion.Criterion(
                    target=target,
                    direction=ModelCriterion.Criterion.DirectionEnum.lt,
                ),
            )
        raise self.InvalidMultitaskModelToCriterion()

    def __gt__(self, target: float) -> ModelCriterion:
        if len(self.training_properties) == 1:
            return ModelCriterion(
                model_id=self.id,
                measurement_name=self.training_properties[0],
                criterion=ModelCriterion.Criterion(
                    target=target,
                    direction=ModelCriterion.Criterion.DirectionEnum.gt,
                ),
            )
        raise self.InvalidMultitaskModelToCriterion()

    def __eq__(self, target: float) -> ModelCriterion:
        if len(self.training_properties) == 1:
            return ModelCriterion(
                model_id=self.id,
                measurement_name=self.training_properties[0],
                criterion=ModelCriterion.Criterion(
                    target=target,
                    direction=ModelCriterion.Criterion.DirectionEnum.eq,
                ),
            )
        raise self.InvalidMultitaskModelToCriterion()

    class InvalidMultitaskModelToCriterion(Exception):
        """
        Exception raised when trying to create a model criterion from a multitask predictor.

        :meta private:
        """

    @property
    def id(self):
        """ID of predictor."""
        return self._metadata.id

    @property
    def reduction(self):
        """The reduction of the embeddings used to train the predictor, if any."""
        return (
            self._metadata.model_spec.features.reduction
            if self._metadata.model_spec.features is not None
            else None
        )

    @property
    def sequence_length(self):
        """The sequence length constraint on the predictor, if any."""
        if (constraints := self._metadata.model_spec.constraints) is not None:
            return constraints.sequence_length
        return None

    @property
    def training_assay(self) -> AssayDataset:
        """The assay the predictor was trained on."""
        if self._training_assay is None:
            self._training_assay = self.get_assay()
        return self._training_assay

    @property
    def training_properties(self) -> list[str]:
        """The list of properties the predictor was trained on."""
        return self._metadata.training_dataset.properties

    @property
    def metadata(self):
        """The predictor metadata."""
        self._refresh_metadata()
        return self._metadata

    def _refresh_metadata(self):
        if not self._metadata.is_done():
            self._metadata = predictor.predictor_get(self.session, self._metadata.id)

    def get_model(self) -> EmbeddingModel | SVDModel | None:
        """Retrieve the embeddings or SVD model used to create embeddings to train on."""
        features = self._metadata.model_spec.features
        if features is None or (model_id := features.model_id) is None:
            return None
        elif features.type.upper() == "PLM":
            model = EmbeddingModel.create(session=self.session, model_id=model_id)
        elif features.type.upper() == "SVD":
            model = SVDModel(
                session=self.session,
                metadata=svd.svd_get(session=self.session, svd_id=model_id),
            )
        else:
            raise ValueError(f"Unexpected feature type {features.type}")
        return model

    @property
    def model(self) -> EmbeddingModel | SVDModel | None:
        """The embeddings or SVD model used to create embeddings to train on."""
        return self.get_model()

    def delete(self) -> bool:
        """
        Delete this predictor model.
        """
        return predictor.predictor_delete(self.session, self.id)

    def get(self, verbose: bool = False):
        """
        Returns the train loss curves.
        """
        return self.metadata.traingraphs

    def get_assay(self) -> AssayDataset:
        """
        Get assay used for train job.

        Returns
        -------
        AssayDataset
            Assay dataset used for train job.
        """
        return AssayDataset(
            session=self.session,
            metadata=assaydata.get_assay_metadata(
                self.session, self._metadata.training_dataset.assay_id
            ),
        )

    def crossvalidate(self, n_splits: int | None = None) -> CVResultFuture:
        """
        Run a crossvalidation on the trained predictor.
        """
        return CVResultFuture.create(
            session=self.session,
            job=predictor.predictor_crossvalidate_post(
                session=self.session,
                predictor_id=self.id,
                n_splits=n_splits,
            ),
        )

    def predict(self, sequences: list[bytes] | list[str]) -> PredictionResultFuture:
        """
        Make predictions about the trained properties for a list of sequences.
        """
        if self.sequence_length is not None:
            for sequence in sequences:
                # convert to string to check token length
                sequence = sequence if isinstance(sequence, str) else sequence.decode()
                if len(sequence) != self.sequence_length:
                    raise InvalidParameterError(
                        f"Expected sequences to predict to be of length {self.sequence_length}"
                    )
        return PredictionResultFuture.create(
            session=self.session,
            job=predictor.predictor_predict_post(
                session=self.session, predictor_id=self.id, sequences=sequences
            ),
        )

    def single_site(self, sequence: bytes | str) -> PredictionResultFuture:
        """
        Compute the single-site mutated predictions of a base sequence.
        """
        if self.sequence_length is not None:
            # convert to string to check token length
            seq = sequence if isinstance(sequence, str) else sequence.decode()
            if len(seq) != self.sequence_length:
                raise InvalidParameterError(
                    f"Expected sequence to predict to be of length {self.sequence_length}"
                )
        return PredictionResultFuture.create(
            session=self.session,
            job=predictor.predictor_predict_single_site_post(
                session=self.session, predictor_id=self.id, base_sequence=sequence
            ),
        )
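
# Usage sketch (illustrative comment, not part of the module): a fitted
# PredictorModel submits predictions as asynchronous jobs and returns a
# PredictionResultFuture. How the model instance is obtained, and the
# `wait()` call on the future, are assumptions here; `my_predictor` and the
# sequences are placeholders.
#
#   my_predictor: PredictorModel = ...  # e.g. from a completed train workflow
#   future = my_predictor.predict(sequences=["MSKGEELFTGV", "MSKGEDLFTGV"])
#   results = future.wait()             # block until the prediction job finishes
#   cv_future = my_predictor.crossvalidate(n_splits=5)
#   criterion = my_predictor < 1.0      # ModelCriterion for a single-task predictor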


class PredictorModelGroup(Future):
    """
    Class providing predict endpoint for fitted predictor models.

    Also implements a Future that waits for train job.
    """

    __models__: list[PredictorModel]

    def __init__(
        self,
        session: APISession,
        models: list[PredictorModel],
        sequence_length: int | None = None,
        # turn off checking to prevent an n^2 operation when chaining many models
        check_sequence_length: bool = True,
    ):
        if len(models) == 0:
            raise ValueError("Expected at least one model to group")
        # calculate and check sequence length compatibility
        if check_sequence_length:
            for m in models:
                if m.sequence_length is not None:
                    if sequence_length is None:
                        sequence_length = m.sequence_length
                    elif sequence_length != m.sequence_length:
                        raise ValueError(
                            "Expected sequence lengths of all models to either match or be None."
                        )
        self.sequence_length = sequence_length
        self.session = session
        self.__models__ = models

    def __str__(self) -> str:
        return repr(self.__models__)

    def __repr__(self) -> str:
        return repr(self.__models__)

    def __or__(self, model: PredictorModel) -> "PredictorModelGroup":
        if self.sequence_length is not None:
            if model.sequence_length != self.sequence_length:
                raise ValueError(
                    "Expected sequence lengths to either match or be None."
                )
        return PredictorModelGroup(
            session=self.session,
            models=self.__models__ + [model],
            sequence_length=self.sequence_length or model.sequence_length,
            check_sequence_length=False,
        )

    def predict(self, sequences: list[bytes] | list[str]) -> PredictionResultFuture:
        """
        Make predictions about the trained properties for a list of sequences.
        """
        if self.sequence_length is not None:
            for sequence in sequences:
                # convert to string to check token length
                sequence = sequence if isinstance(sequence, str) else sequence.decode()
                if len(sequence) != self.sequence_length:
                    raise InvalidParameterError(
                        f"Expected sequences to predict to be of length {self.sequence_length}"
                    )
        return PredictionResultFuture.create(
            session=self.session,
            job=predictor.predictor_predict_multi_post(
                session=self.session,
                predictor_ids=[m.id for m in self.__models__],
                sequences=sequences,
            ),
        )

    def single_site(self, sequence: bytes | str) -> PredictionResultFuture:
        """
        Compute the single-site mutated predictions of a base sequence.
        """
        if self.sequence_length is not None:
            # convert to string to check token length
            seq = sequence if isinstance(sequence, str) else sequence.decode()
            if len(seq) != self.sequence_length:
                raise InvalidParameterError(
                    f"Expected sequence to predict to be of length {self.sequence_length}"
                )
        return PredictionResultFuture.create(
            session=self.session,
            job=predictor.predictor_predict_single_site_post(
                session=self.session, predictor_id=self.id, base_sequence=sequence
            ),
        )

    def get(self, verbose: bool = False):
        """
        Returns the predictor model.

        :meta private:
        """
        return self

    def delete(self):
        return predictor.predictor_delete(session=self.session, predictor_id=self.id)
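
# Usage sketch (illustrative comment, not part of the module): PredictorModel
# instances can be chained with `|` into a PredictorModelGroup, which submits a
# single multi-predictor prediction job. How the two models below are obtained
# is assumed; the names are placeholders.
#
#   stability_model: PredictorModel = ...
#   expression_model: PredictorModel = ...
#   group = stability_model | expression_model   # builds a PredictorModelGroup
#   future = group.predict(sequences=["MSKGEELFTGV"])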