"""UMAP API providing the interface to fit and run UMAP visualizations."""
import typing
from typing import Literal
from openprotein.base import APISession
from openprotein.common import Feature, FeatureType, Reduction, ReductionType
from openprotein.data import AssayDataset, AssayMetadata
from openprotein.embeddings import EmbeddingModel, EmbeddingsAPI
from openprotein.errors import InvalidParameterError
from openprotein.jobs import JobsAPI
from openprotein.svd import SVDAPI, SVDModel
from . import api
from .models import UMAPModel
[docs]
class UMAPAPI:
"""UMAP API providing the interface to fit and run UMAP visualizations."""
def __init__(
self,
session: APISession,
):
self.session = session
@typing.overload
def fit_umap(
self,
model: EmbeddingModel,
reduction: Reduction | ReductionType,
feature_type: Literal["PLM"] = "PLM",
sequences: list[bytes] | list[str] | None = None,
assay: AssayDataset | AssayMetadata | str | None = None,
n_components: int = 2,
n_neighbors: int = 15,
min_dist: float = 0.1,
) -> UMAPModel: ...
@typing.overload
def fit_umap(
self,
model: EmbeddingModel,
) -> UMAPModel: ...
[docs]
def fit_umap(
self,
model: EmbeddingModel | SVDModel | str,
reduction: Reduction | ReductionType | None = None,
feature_type: Feature | FeatureType | None = None,
sequences: list[bytes] | list[str] | None = None,
assay: AssayDataset | AssayMetadata | str | None = None,
n_components: int = 2,
n_neighbors: int = 15,
min_dist: float = 0.1,
**kwargs,
) -> UMAPModel:
"""
Fit an UMAP on the sequences with the specified model_id and hyperparameters (n_components).
Parameters
----------
sequences: list of bytes or None, optional
Optional sequences to fit UMAP with. Either use sequences or
assay_id. sequences is preferred.
assay : AssayMetadata or AssayDataset or str or None, optional
Optional assay containing sequences to fit UMAP with.
Or its assay_id. Either use sequences or assay.
Ignored if sequences are provided.
model : EmbeddingModel or SVDModel or str
Instance of either EmbeddingModel or SVDModel to use depending
on feature type. Can also be a str specifying the model id,
but then feature_type would have to be specified.
feature_type : str or FeatureType or None, optional
Type of features to use for encoding sequences. "SVD" or "PLM".
None would require model to be EmbeddingModel or SVDModel.
n_components : int, optional
Number of UMAP components to fit. Defaults to 2.
n_neighbors : int, optional
Number of neighbors to use for fitting. Defaults to 15.
min_dist : float, optional
Minimum distance in UMAP fitting. Defaults to 0.1.
reduction : str or ReductionType or None, optional
Type of embedding reduction to use for computing features.
E.g. "MEAN" or "SUM". Useful when dealing with variable length
sequence. Defaults to None.
kwargs :
Additional keyword arguments to be passed to foundational models, e.g. prompt_id for PoET models.
Returns
-------
UMAPModel
The UMAP model being fit.
"""
# 1. Check assay data input - just need the id
# get assay_id
assay_id = (
assay.assay_id
if isinstance(assay, AssayMetadata)
else assay.id if isinstance(assay, AssayDataset) else assay
)
# extract feature type
feature_type = (
FeatureType.PLM
if isinstance(model, EmbeddingModel)
else FeatureType.SVD if isinstance(model, SVDModel) else feature_type
)
if feature_type is None:
raise InvalidParameterError(
"Expected feature_type to be provided if passing str model_id as model"
)
if isinstance(feature_type, str):
feature_type = FeatureType(feature_type)
if isinstance(reduction, str):
reduction = ReductionType(reduction)
# get model if model_id
if feature_type == FeatureType.PLM:
if reduction is None:
raise InvalidParameterError(
"Expected reduction if using embedding model"
)
if isinstance(model, str):
embeddings_api = getattr(self.session, "embedding", None)
assert isinstance(embeddings_api, EmbeddingsAPI)
model = embeddings_api.get_model(model)
assert isinstance(model, EmbeddingModel), "Expected EmbeddingModel"
model_id = model.id
elif feature_type == FeatureType.SVD:
if reduction is not None:
raise InvalidParameterError("Unexpected reduction when using SVD model")
if isinstance(model, str):
svd_api = getattr(self.session, "svd", None)
assert isinstance(svd_api, SVDAPI)
model = svd_api.get_svd(model)
assert isinstance(model, SVDModel), "Expected SVDModel"
model_id = model.id
return UMAPModel(
session=self.session,
job=api.umap_fit_post(
session=self.session,
model_id=model_id,
feature_type=feature_type,
sequences=sequences,
assay_id=assay_id,
n_components=n_components,
n_neighbors=n_neighbors,
min_dist=min_dist,
reduction=reduction,
**kwargs,
),
)
[docs]
def get_umap(self, umap_id: str) -> UMAPModel:
"""
Get UMAP job results. Including UMAP dimension and sequence lengths.
Requires a successful UMAP job from fit_umap.
Parameters
----------
umap_id : str
The ID of the UMAP job.
Returns
-------
UMAPModel
The model with the UMAP fit.
"""
metadata = api.umap_get(self.session, umap_id)
return UMAPModel(session=self.session, metadata=metadata)
def __delete_umap(self, umap_id: str) -> bool:
"""
Delete UMAP model.
Parameters
----------
umap_id : str
The ID of the UMAP job.
Returns
-------
bool
True: successful deletion
"""
return api.umap_delete(self.session, umap_id)
[docs]
def list_umap(self) -> list[UMAPModel]:
"""
List UMAP models made by user.
Takes no args.
Returns
-------
list[UMAPModel]
UMAPModels
"""
jobs_api = getattr(self.session, "jobs", None)
assert isinstance(jobs_api, JobsAPI)
return [
UMAPModel(session=self.session, metadata=metadata)
for metadata in api.umap_list_get(self.session)
]