from openprotein.base import APISession
from openprotein.api.jobs import Job, MappedAsyncJobFuture
import openprotein.config as config
from openprotein.api.embedding import ModelMetadata
from openprotein.api.align import validate_msa, MSAFuture
import openprotein.pydantic as pydantic
from typing import Optional, List, Union, Tuple
from openprotein.futures import FutureBase, FutureFactory
from abc import ABC, abstractmethod
# Base path of the fold API; the endpoint strings below are built by appending to it.
PATH_PREFIX = "v1/fold"
class FoldModelBase:
    """
    Mixin recording which model id(s) a fold model class answers to.

    Subclasses override ``model_id`` with a single id string or a collection
    of ids. The base value ``None`` means that no id is declared.
    """

    # Overridden by subclasses so the correct fold model class can be looked up.
    model_id = None

    @classmethod
    def get_model(cls):
        """
        Return the ids declared by this class.

        A lone string is wrapped in a one-element list; any other value
        (including ``None``) is handed back unchanged.
        """
        declared = cls.model_id
        return [declared] if isinstance(declared, str) else declared
class FoldModelFactory:
    """Factory that builds the fold model object matching a model id."""

    @staticmethod
    def create_model(session, model_id, metadata=None, default=None):
        """
        Create the fold model instance registered for ``model_id``.

        The direct subclasses of ``FoldModelBase`` are searched for one whose
        ``get_model()`` contains ``model_id``. When none matches, ``default``
        is instantiated instead.

        Parameters
        ----------
        session :
            Session object handed to the model constructor.
        model_id : str
            Identifier of the requested model.
        metadata : optional
            Metadata handed to the model constructor.
        default : optional
            Fallback class (or callable) used when no subclass declares
            ``model_id``.

        Returns
        -------
        An instance of the matching model class, or of ``default``.

        Raises
        ------
        ValueError
            If no subclass matches and ``default`` is missing or cannot be
            instantiated.
        """
        # __subclasses__() only reports *direct* subclasses of FoldModelBase.
        for model_class in FoldModelBase.__subclasses__():
            # A subclass that never set model_id reports None; treat that as
            # "declares no ids" rather than failing on `model_id in None`.
            if model_id in (model_class.get_model() or ()):
                return model_class(
                    session=session, model_id=model_id, metadata=metadata
                )

        # No registered class matched: fall back to the caller-supplied default.
        if default is None:
            raise ValueError(f"Unsupported model_id type: {model_id}")
        try:
            return default(session=session, model_id=model_id, metadata=metadata)
        except Exception as err:
            # Callers expect ValueError; chain the original failure so the
            # real cause (e.g. an abstract default class) stays visible.
            raise ValueError(f"Unsupported model_id type: {model_id}") from err
def fold_models_list_get(session: APISession) -> List[str]:
    """
    Fetch the list of fold models from the API.

    Parameters
    ----------
    session : APISession
        Session object for API communication.

    Returns
    -------
    List[str]
        The decoded JSON body of the ``/models`` response.
    """
    return session.get(f"{PATH_PREFIX}/models").json()
def fold_model_get(session: APISession, model_id: str) -> ModelMetadata:
    """
    Fetch the metadata of a single fold model.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    model_id : str
        Identifier of the model to look up.

    Returns
    -------
    ModelMetadata
        Built from the decoded JSON body of the response.
    """
    payload = session.get(f"{PATH_PREFIX}/models/{model_id}").json()
    return ModelMetadata(**payload)
def fold_get_sequences(session: APISession, job_id: str) -> List[bytes]:
    """
    Fetch the sequences recorded for a fold job.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    job_id : str
        ID of the job whose sequences are requested.

    Returns
    -------
    sequences : List[bytes]
        The JSON response parsed as a list of bytes.
    """
    raw = session.get(f"{PATH_PREFIX}/{job_id}/sequences").json()
    return pydantic.parse_obj_as(List[bytes], raw)
def fold_get_sequence_result(
    session: APISession, job_id: str, sequence: bytes
) -> bytes:
    """
    Fetch the encoded fold result of one sequence of a job.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    job_id : str
        ID of the job to retrieve results from.
    sequence : bytes
        Sequence to retrieve results for; bytes are decoded before being
        placed in the request path, any other value is used as given.

    Returns
    -------
    result : bytes
        Raw content of the response.
    """
    sequence_text = sequence.decode() if isinstance(sequence, bytes) else sequence
    return session.get(f"{PATH_PREFIX}/{job_id}/{sequence_text}").content
class FoldResultFuture(MappedAsyncJobFuture, FutureBase):
    """Future Job for manipulating results"""

    # Job types this future is registered for.
    job_type = ["/embeddings/fold"]

    def __init__(
        self,
        session: APISession,
        job: Job,
        sequences=None,
        max_workers=config.MAX_CONCURRENT_WORKERS,
    ):
        """
        Parameters
        ----------
        session : APISession
            Session object for API communication.
        job : Job
            The fold job this future tracks.
        sequences : optional
            Sequences of the job. When omitted they are fetched from the API
            using the job ID.
        max_workers : int
            Forwarded to the parent future; defaults to
            ``config.MAX_CONCURRENT_WORKERS``.
        """
        super().__init__(session, job, max_workers)
        if sequences is None:
            sequences = fold_get_sequences(self.session, job_id=job.job_id)
        self._sequences = sequences

    @property
    def sequences(self):
        """Sequences of the job, fetched from the API if not yet stored."""
        if self._sequences is None:
            self._sequences = fold_get_sequences(self.session, self.job.job_id)
        return self._sequences

    @property
    def id(self):
        """ID of the underlying job."""
        return self.job.job_id

    def keys(self):
        """Return the job's sequences, which act as the keys of the results."""
        return self.sequences

    def get(self, verbose=False) -> List[Tuple[str, str]]:
        """Delegate to the parent implementation's ``get``."""
        return super().get(verbose=verbose)

    def get_item(self, sequence: bytes) -> bytes:
        """
        Get fold results for the specified sequence.

        Parameters
        ----------
        sequence : bytes
            Sequence to fetch results for.

        Returns
        -------
        bytes
            Raw response content returned by the API for this sequence.
        """
        return fold_get_sequence_result(self.session, self.job.job_id, sequence)
def fold_models_esmfold_post(
    session: APISession,
    sequences: List[bytes],
    num_recycles: Optional[int] = None,
):
    """
    POST a structure prediction request to the ESMFold endpoint.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    sequences : List[bytes]
        Sequences to request results for. Entries that are not already
        ``str`` are decoded before being sent.
    num_recycles : Optional[int]
        Number of recycles for structure prediction; left out of the request
        body when ``None``.

    Returns
    -------
    The future built by ``FutureFactory.create_future`` from the response,
    carrying the sequences exactly as they were passed in.
    """
    payload = {
        "sequences": [s if isinstance(s, str) else s.decode() for s in sequences],
    }
    if num_recycles is not None:
        payload["num_recycles"] = num_recycles

    response = session.post(f"{PATH_PREFIX}/models/esmfold", json=payload)
    return FutureFactory.create_future(
        session=session, response=response, sequences=sequences
    )
def fold_models_alphafold2_post(
    session: APISession,
    msa: Union[str, MSAFuture],
    num_recycles: Optional[int] = None,
    num_models: Optional[int] = 1,
    num_relax: Optional[int] = 0,
):
    """
    POST a structure prediction request to the AlphaFold2 endpoint.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    msa : Union[str, MSAFuture]
        MSA to use for structure prediction, given as an id or as an
        ``MSAFuture`` (whose ``msa_id`` is sent). The first sequence in the
        MSA is the query sequence.
    num_recycles : Optional[int] = None
        Number of recycles for structure prediction; left out of the request
        body when ``None``.
    num_models : Optional[int] = 1
        Number of models to predict.
    num_relax : Optional[int] = 0
        Number of relaxation iterations to run.

    Returns
    -------
    The future built by ``FutureFactory.create_future`` from the response.
    """
    msa_id = msa.msa_id if isinstance(msa, MSAFuture) else msa
    payload = {
        "msa_id": msa_id,
        "num_models": num_models,
        "num_relax": num_relax,
    }
    if num_recycles is not None:
        payload["num_recycles"] = num_recycles

    response = session.post(f"{PATH_PREFIX}/models/alphafold2", json=payload)
    # The AF2 GET endpoint expects the query sequence (first sequence of the
    # MSA). It is not known here, so no sequences are passed to the future.
    return FutureFactory.create_future(session=session, response=response)
class FoldModel(ABC):
    """
    Abstract base for the protein fold models served by OpenProtein.

    Concrete subclasses must implement ``fold()``.
    """

    def __init__(self, session, model_id, metadata=None):
        self.session = session
        self.id = model_id
        # Cached metadata; filled on first request when not supplied here.
        self._metadata = metadata

    def __str__(self) -> str:
        return self.id

    def __repr__(self) -> str:
        return self.id

    @property
    def metadata(self):
        """Model metadata, same value as ``get_metadata()``."""
        return self.get_metadata()

    def get_metadata(self) -> ModelMetadata:
        """
        Return the metadata of this model, fetching it once if needed.

        Returns
        -------
        ModelMetadata
        """
        if self._metadata is None:
            self._metadata = fold_model_get(self.session, self.id)
        return self._metadata

    @abstractmethod
    def fold(self, sequence: str, **kwargs):
        """Run structure prediction; implemented by subclasses."""
class ESMFoldModel(FoldModel, FoldModelBase):
    """Fold model bound to the ``esmfold`` model id."""

    model_id = "esmfold"

    def __init__(self, session, model_id, metadata=None):
        super().__init__(session, model_id, metadata)
        # The instance id is always the class-level model id, whatever
        # ``model_id`` argument was passed in.
        self.id = self.model_id

    def fold(self, sequences: List[bytes], num_recycles: int = 1) -> FoldResultFuture:
        """
        Fold sequences using this model.

        Parameters
        ----------
        sequences : List[bytes]
            Sequences to fold.
        num_recycles : int
            Number of times to recycle models.

        Returns
        -------
        FoldResultFuture
        """
        future = fold_models_esmfold_post(
            self.session, sequences, num_recycles=num_recycles
        )
        return future
class AlphaFold2Model(FoldModel, FoldModelBase):
    """Fold model bound to the ``alphafold2`` model id."""

    model_id = "alphafold2"

    def __init__(self, session, model_id, metadata=None):
        super().__init__(session, model_id, metadata)
        # The instance id is always the class-level model id, whatever
        # ``model_id`` argument was passed in.
        self.id = self.model_id

    def fold(
        self,
        msa: Union[str, MSAFuture],
        num_recycles: Optional[int] = None,
        num_models: int = 1,
        num_relax: Optional[int] = 0,
    ):
        """
        Post an MSA to the AlphaFold2 model for structure prediction.

        Parameters
        ----------
        msa : Union[str, MSAFuture]
            MSA to fold. A truthy non-string value is first passed through
            ``validate_msa``; strings are forwarded untouched.
        num_recycles : Optional[int]
            Number of times to recycle models.
        num_models : int
            Number of models requested.
        num_relax : Optional[int]
            Number of relaxation iterations requested.

        Returns
        -------
        The value returned by ``fold_models_alphafold2_post``.
        """
        if msa and not isinstance(msa, str):
            msa = validate_msa(msa)
        request_options = {
            "num_recycles": num_recycles,
            "num_models": num_models,
            "num_relax": num_relax,
        }
        return fold_models_alphafold2_post(self.session, msa, **request_options)
def validate_fold_id(fold):
    """Return ``fold`` itself when it is a string, otherwise its ``id`` attribute."""
    return fold if isinstance(fold, str) else fold.id
class FoldAPI:
    """
    High-level interface for accessing the fold API.
    """

    esmfold: ESMFoldModel
    alphafold2: AlphaFold2Model

    def __init__(self, session: APISession):
        self.session = session
        self._load_models()

    @property
    def af2(self):
        """Alias for AlphaFold2"""
        return self.alphafold2

    def _load_models(self):
        # Attach every listed model as an attribute so the attributes follow
        # what the API reports; hyphens are replaced so the names are valid
        # identifiers.
        for model in self.list_models():
            setattr(self, model.id.replace("-", "_"), model)

    def list_models(self) -> List[FoldModel]:
        """List the models available for folding your sequences."""
        return [
            FoldModelFactory.create_model(self.session, model_id, default=FoldModel)
            for model_id in fold_models_list_get(self.session)
        ]

    def get_model(self, model_id: str) -> FoldModel:
        """
        Get model by model_id.

        Parameters
        ----------
        model_id : str
            The model identifier.

        Returns
        -------
        FoldModel
            The model object built by ``FoldModelFactory.create_model``.

        Raises
        ------
        ValueError
            If the factory cannot build a model for ``model_id``.
        """
        return FoldModelFactory.create_model(
            session=self.session, model_id=model_id, default=FoldModel
        )

    def get_results(self, job) -> FoldResultFuture:
        """
        Retrieve the results of a fold job.

        Parameters
        ----------
        job : Job
            The fold job whose results are to be retrieved.

        Returns
        -------
        FoldResultFuture
            The future built by ``FutureFactory.create_future`` for the job.
        """
        return FutureFactory.create_future(job=job, session=self.session)