Source code for openprotein.svd.svd
"""SVD API providing the interface for creating and using SVD models."""
from openprotein.base import APISession
from openprotein.common import ReductionType
from openprotein.data import AssayDataset, AssayMetadata
from openprotein.embeddings import EmbeddingModel, EmbeddingsAPI
from . import api
from .models import SVDModel
[docs]
class SVDAPI:
    """SVD API providing the interface for creating and using SVD models."""

    def __init__(
        self,
        session: APISession,
    ):
        # Every API call made by this object goes through this session.
        self.session = session

    def fit_svd(
        self,
        model_id: str,
        sequences: list[bytes] | list[str] | None = None,
        assay: AssayMetadata | AssayDataset | str | None = None,
        n_components: int = 1024,
        reduction: ReductionType | None = None,
        **kwargs,
    ) -> SVDModel:
        """
        Fit an SVD on sequences embedded with the specified embeddings model.

        Parameters
        ----------
        model_id : str
            ID of embeddings model to use.
        sequences : list of bytes or list of str or None, optional
            Optional sequences to fit SVD with. Either use sequences or
            assay. sequences is preferred.
        assay : AssayMetadata or AssayDataset or str or None, optional
            Optional assay containing sequences to fit SVD with.
            Or its assay_id. Either use sequences or assay.
            Ignored if sequences are provided.
        n_components : int, optional
            The number of components for the SVD. Defaults to 1024.
        reduction : str or None, optional
            Type of embedding reduction to use for computing features.
            E.g. "MEAN" or "SUM". Useful when dealing with variable length
            sequence. Defaults to None.
        kwargs :
            Additional keyword arguments to be passed to foundational
            models, e.g. prompt_id for PoET models.

        Returns
        -------
        SVDModel
            The SVD model being fit.

        Raises
        ------
        AssertionError
            If the session has no ``embedding`` attribute that is an
            EmbeddingsAPI, or if the object resolved for ``model_id`` is
            not an EmbeddingModel.
        """
        # These checks are raised explicitly instead of using `assert` so
        # they still run under `python -O`; AssertionError is kept so the
        # exception type seen by existing callers does not change.
        embeddings_api = getattr(self.session, "embedding", None)
        if not isinstance(embeddings_api, EmbeddingsAPI):
            raise AssertionError("Expected session.embedding to be an EmbeddingsAPI")

        model = embeddings_api.get_model(model_id)
        if not isinstance(model, EmbeddingModel):
            raise AssertionError("Expected EmbeddingModel")

        # Normalise `assay` to a plain ID: AssayMetadata exposes `assay_id`,
        # AssayDataset exposes `id`; a str (or None) is passed through as-is.
        if isinstance(assay, AssayMetadata):
            assay_id = assay.assay_id
        elif isinstance(assay, AssayDataset):
            assay_id = assay.id
        else:
            assay_id = assay

        job = api.svd_fit_post(
            session=self.session,
            model_id=model.id,
            sequences=sequences,
            assay_id=assay_id,
            n_components=n_components,
            reduction=reduction,
            **kwargs,
        )
        return SVDModel(session=self.session, job=job)

    def get_svd(self, svd_id: str) -> SVDModel:
        """
        Get SVD job results. Including SVD dimension and sequence lengths.

        Requires a successful SVD job from fit_svd.

        Parameters
        ----------
        svd_id : str
            The ID of the SVD job.

        Returns
        -------
        SVDModel
            The model with the SVD fit.
        """
        metadata = api.svd_get(self.session, svd_id)
        return SVDModel(
            session=self.session,
            metadata=metadata,
        )

    def __delete_svd(self, svd_id: str) -> bool:
        """
        Delete SVD model.

        The double-underscore prefix name-mangles this method, so it is
        not reachable as ``SVDAPI.__delete_svd`` from outside the class.

        Parameters
        ----------
        svd_id : str
            The ID of the SVD job.

        Returns
        -------
        bool
            Whether or not the SVD was successfully deleted.
        """
        return api.svd_delete(self.session, svd_id)

    def list_svd(self) -> list[SVDModel]:
        """
        List SVD models made by user.

        Returns
        -------
        list of SVDModel
            List of SVDs that the user has access to.
        """
        # Wrap each metadata record returned by the API in an SVDModel
        # bound to this session.
        return [
            SVDModel(
                session=self.session,
                metadata=metadata,
            )
            for metadata in api.svd_list_get(self.session)
        ]