Training models#
This tutorial teaches you how to train models using OpenProtein.AI’s Property Regression models. These models can be used to make predictions for new sequence variants and to design libraries of optimized sequences.
What you need before getting started#
You need an uploaded dataset and an assay object to create training jobs. For more information, see Uploading your data.
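For example, once you have followed that guide, a session and assay might be set up as follows (a sketch; the credentials, file name, and dataset names here are placeholders):
[ ]:
import openprotein
import pandas as pd

# Authenticate against the platform (credentials are placeholders)
session = openprotein.connect(username="your-username", password="your-password")

# Upload a CSV containing a sequence column and one or more measured properties
table = pd.read_csv("my_assay.csv")
assay = session.data.create(table, "my assay", "Measured fitness data")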
Understanding Model Options#
Embeddings#
Embeddings capture essential patterns and relationships between protein sequences, shaping how the model interprets their properties. The recommended option is PoET, which generates embeddings tailored to specific fitness landscapes. The following embedding models are available for use:
PoET (Recommended)
Conditional protein language model that enables embedding, scoring and generating sequences conditioned on an input protein family of interest.
Selecting this option requires a prompt, which guides the model's predictions toward your protein family of interest. A sketch of one way to prepare a prompt follows the list of models below.
Reduction type defaults to 'None'.
Prot Seq
Masked protein language model (~300M parameters) trained on UniRef50 with contact and secondary structure prediction as secondary objectives. Uses random Fourier position embeddings and FlashAttention to enable fast inference.
Rotaprot Large UniRef50w
Masked protein language model (~900M parameters) trained on UniRef100 with sequences weighted inversely proportional to the number of UniRef50 homologs. Uses rotary relative position embeddings and FlashAttention to enable fast inference.
Rotaprot Large UniRef90 Finetuned
Rotaprot Large UniRef50w finetuned on UniRef100 with sequences weighted inversely proportional to the number of UniRef90 cluster members.
ESM1
Community-based ESM1 models, with different versions varying in model parameters and training data.
Models include:
esm1b_t33_650M_UR50S
esm1v_t33_650M_UR90S_1
esm1v_t33_650M_UR90S_2
esm1v_t33_650M_UR90S_3
esm1v_t33_650M_UR90S_4
esm1v_t33_650M_UR90S_5
ESM2
Community-based ESM2 models, with different versions varying in model parameters and training data.
Models include:
esm2_t8_50M_UR50D
esm2_t12_35M_UR50D
esm2_t30_150M_UR50D
esm2_t33_650M_UR50D
esm2_t36_3B_UR50D
ProtTrans
Community-based ProtTrans models
Models include:
prott5-xl
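As noted above, PoET requires a prompt. One way to prepare one is to build an MSA from a seed (parent) sequence and sample from it (a sketch assuming the align workflow covered in the MSA and prompt tutorials; the seed sequence and sampling size are illustrative):
[ ]:
# Build an MSA from a seed (parent) sequence, then sample a prompt from it
seed = b"MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGK"  # illustrative seed
msa = session.align.create_msa(seed)
prompt = msa.sample_prompt(num_sequences=100)  # sampling size is illustrative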
Reduction Type#
The reduction type determines how the output embeddings are summarized into a more manageable form. This option helps control how the model compresses information, enabling the capture of relevant features from your protein sequences while reducing dimensionality. The available options are:
Mean
Creates a sequence embedding by averaging the per-residue embeddings over the length dimension. Useful for working with variable length sequences, but loses specific position information in the embedding compared with the full LxN embedding or SVD reduction thereof.
When to use:
- Variable length sequence datasets when alignment is not a good option.
- Sometimes outperforms full embeddings on small or highly diverse datasets.
Sum
The same as the mean embedding, except that per-residue embeddings are summed over the length dimension rather than averaged. This can better preserve length and domain multiplicity information but tends to underperform mean embeddings for most use cases.
When to use:
- Variable length sequence datasets, as an alternative to mean embeddings. Mean embeddings should generally be preferred, but the two can be compared on a case-by-case basis.
None
Retains the full embeddings without any dimensionality reduction. This option captures the complete relationships between features, offering the most detailed representation at the cost of larger file sizes. Note that this differs from the web app, where None refers to using SVDs instead.
Note: Using full-sized embeddings may be unstable on the platform at this stage.
When to use:
- Most substitution-only design projects should prefer this option. It tends to outperform mean embeddings when sequences are fixed length or can be aligned.
- Use when you require the full feature set for detailed analysis or downstream processing.
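To make these options concrete, the relationship between the full embedding and its reductions can be sketched in numpy (a conceptual illustration only; the platform computes reductions server-side, and the dimensions here are illustrative):
[ ]:
import numpy as np

L, D = 120, 1024              # sequence length x embedding dimension (illustrative)
full = np.random.randn(L, D)  # reduction "None": the full L x D per-residue embedding

mean_embedding = full.mean(axis=0)  # "MEAN": one fixed-size vector, position info lost
sum_embedding = full.sum(axis=0)    # "SUM": like MEAN but preserves length/domain-count signal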
Training a model#
Use the assay object to create a training job:
[ ]:
# Option 1 (recommended): use PoET with a prepared prompt
predictor = session.embeddings.poet.fit_gp(
    assay=assay,
    properties=["isobutyramide_normalized_fitness"],
    name="mymodel",  # name the resulting model
    prompt=prompt,
    reduction="MEAN",
)

# Option 2: use prot-seq or any other model without a prompt
# predictor = session.embeddings.prot_seq.fit_gp(
#     assay=assay,
#     properties=["isobutyramide_normalized_fitness"],
#     name="mymodel",  # name the resulting model
#     reduction="MEAN",
# )

# Option 3: use a fitted SVD
# predictor = svd.fit_gp(
#     assay=assay,
#     properties=["isobutyramide_normalized_fitness"],
#     name="mymodel",  # name the resulting model
# )

predictor_id = predictor.id
predictor
PredictorMetadata(id='ae2db820-59c8-4902-8e75-e2d5196c98ad', name='mymodel', description=None, status=<JobStatus.PENDING: 'PENDING'>, model_spec=ModelSpec(kernel=Kernel(type='RBF', multitask=False), constraints=None, features=Features(type=<FeatureType.PLM: 'PLM'>, model_id='poet', reduction='MEAN')), training_dataset=Dataset(assay_id='99dc46be-5fb1-4771-a59a-65ec76d70765', properties=['isobutyramide_normalized_fitness']))
[ ]:
predictor.refresh()
predictor.status
<JobStatus.RUNNING: 'RUNNING'>
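The wait call below blocks until training completes. If you prefer to poll manually instead, a simple loop might look like this (a sketch; it assumes SUCCESS and FAILURE are the terminal states, as suggested by the status values shown in this tutorial):
[ ]:
import time

# Poll the training job until it reaches a terminal state
while predictor.status.value not in ("SUCCESS", "FAILURE"):
    time.sleep(30)       # avoid hammering the API
    predictor.refresh()  # re-fetch the job status from the server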
Wait for the results before proceeding:
[ ]:
results = predictor.wait(verbose=False)
You can display your results as a scatterplot:
[ ]:
import seaborn as sns
import matplotlib.pyplot as plt

# Plot the training loss curve
sns.scatterplot(x=range(len(results[0].losses)), y=results[0].losses)
plt.xlabel("Steps")
plt.ylabel("Loss");
Request a cross-validation job to see the training results in more detail:
[ ]:
cvjob = predictor.crossvalidate()
cvjob.status
<JobStatus.PENDING: 'PENDING'>
[ ]:
cvdata = cvjob.wait(verbose=True)
Waiting: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 252.60it/s, status=SUCCESS]
[ ]:
# Unpack measured values, predicted means, and predicted variances
ys, y_mus, y_vars = cvdata
sns.regplot(x=ys.astype(float), y=y_mus.astype(float))
plt.xlabel("Y")
plt.title("Cross validation results")
plt.ylabel("Y-hat");
Retrieving old models#
To retrieve and predict using older models trained with session.train (now session.deprecated.train), you will need to specify the train job id with load_job.
[ ]:
train = session.load_job("6f8dd9a8-0370-4f3b-9ff6-43281a6f3428")
train
TrainJob(job_id='6f8dd9a8-0370-4f3b-9ff6-43281a6f3428', job_type=<JobType.workflow_train: '/workflow/train'>, status=<JobStatus.SUCCESS: 'SUCCESS'>, created_date=datetime.datetime(2024, 12, 2, 8, 42, 54, 93414), start_date=datetime.datetime(2024, 12, 2, 8, 42, 54, 160768), end_date=datetime.datetime(2024, 12, 2, 8, 46, 40, 984249), prerequisite_job_id='0de826f5-7202-4ac8-a947-c232bc3a0556', progress_message=None, progress_counter=None, sequence_length=None, traingraph=None)
You can list previous train jobs using session.jobs.list with JobType.workflow_train.
[ ]:
from openprotein.schemas import JobType
train_jobs = session.jobs.list(job_type=JobType.workflow_train, limit=5)
train_jobs
[TrainJob(job_id='6f8dd9a8-0370-4f3b-9ff6-43281a6f3428', job_type=<JobType.workflow_train: '/workflow/train'>, status=<JobStatus.SUCCESS: 'SUCCESS'>, created_date=datetime.datetime(2024, 12, 2, 8, 42, 54, 93414), start_date=datetime.datetime(2024, 12, 2, 8, 42, 54, 160768), end_date=datetime.datetime(2024, 12, 2, 8, 46, 40, 984249), prerequisite_job_id='0de826f5-7202-4ac8-a947-c232bc3a0556', progress_message=None, progress_counter=None, sequence_length=None, traingraph=None),
TrainJob(job_id='86addd0d-5939-4116-bc4d-45f03582a540', job_type=<JobType.workflow_train: '/workflow/train'>, status=<JobStatus.SUCCESS: 'SUCCESS'>, created_date=datetime.datetime(2024, 12, 2, 8, 19, 29, 555199), start_date=datetime.datetime(2024, 12, 2, 8, 23, 6, 269712), end_date=datetime.datetime(2024, 12, 2, 8, 41, 36, 398557), prerequisite_job_id='0de826f5-7202-4ac8-a947-c232bc3a0556', progress_message=None, progress_counter=None, sequence_length=None, traingraph=None),
TrainJob(job_id='c5b8784f-b5b0-4083-b99a-268ddb5d8cf0', job_type=<JobType.workflow_train: '/workflow/train'>, status=<JobStatus.SUCCESS: 'SUCCESS'>, created_date=datetime.datetime(2024, 11, 20, 8, 47, 54, 740375), start_date=datetime.datetime(2024, 11, 20, 8, 54, 4, 258277), end_date=datetime.datetime(2024, 11, 20, 9, 12, 46, 715364), prerequisite_job_id='0de826f5-7202-4ac8-a947-c232bc3a0556', progress_message=None, progress_counter=None, sequence_length=None, traingraph=None),
TrainJob(job_id='5d4de46c-293e-446e-bad7-297db00b7744', job_type=<JobType.workflow_train: '/workflow/train'>, status=<JobStatus.SUCCESS: 'SUCCESS'>, created_date=datetime.datetime(2024, 10, 3, 4, 4, 24, 879680), start_date=datetime.datetime(2024, 10, 3, 4, 10, 53, 104694), end_date=datetime.datetime(2024, 10, 3, 4, 12, 21, 940962), prerequisite_job_id='18acb4e2-b474-458c-8110-df4002e8fcbf', progress_message=None, progress_counter=None, sequence_length=None, traingraph=None),
TrainJob(job_id='ca3e5915-14a3-4930-b999-66443158fac5', job_type=<JobType.workflow_train: '/workflow/train'>, status=<JobStatus.SUCCESS: 'SUCCESS'>, created_date=datetime.datetime(2024, 8, 26, 4, 49, 4, 112065), start_date=datetime.datetime(2024, 8, 26, 4, 49, 4, 164811), end_date=datetime.datetime(2024, 8, 26, 4, 49, 35, 757422), prerequisite_job_id='11f70942-2fc5-4528-9f51-847ec9b4c068', progress_message=None, progress_counter=None, sequence_length=None, traingraph=None)]
Note that with list, you will still have to pass the job id through load_job to load the model itself. The model object is a more complex object (it has a predict method) and is thus not loaded when calling list.
[ ]:
train = session.load_job(train_jobs[0].job_id)
train
TrainJob(job_id='6f8dd9a8-0370-4f3b-9ff6-43281a6f3428', job_type=<JobType.workflow_train: '/workflow/train'>, status=<JobStatus.SUCCESS: 'SUCCESS'>, created_date=datetime.datetime(2024, 12, 2, 8, 42, 54, 93414), start_date=datetime.datetime(2024, 12, 2, 8, 42, 54, 160768), end_date=datetime.datetime(2024, 12, 2, 8, 46, 40, 984249), prerequisite_job_id='0de826f5-7202-4ac8-a947-c232bc3a0556', progress_message=None, progress_counter=None, sequence_length=None, traingraph=None)
Finally, you can also list models associated with an assay. These models are associated with a train job, which can be loaded as before.
[ ]:
train = session.load_job(assay.list_models()[0]["job_id"])
train
TrainJob(job_id='5d4de46c-293e-446e-bad7-297db00b7744', job_type=<JobType.workflow_train: '/workflow/train'>, status=<JobStatus.SUCCESS: 'SUCCESS'>, created_date=datetime.datetime(2024, 10, 3, 4, 4, 24, 879680), start_date=datetime.datetime(2024, 10, 3, 4, 10, 53, 104694), end_date=datetime.datetime(2024, 10, 3, 4, 12, 21, 940962), prerequisite_job_id='18acb4e2-b474-458c-8110-df4002e8fcbf', progress_message=None, progress_counter=None, sequence_length=None, traingraph=None)
Next steps#
Our Predictor API page contains more information about training your models.
You can use your trained model to perform a single site analysis or design sequences. See Using single site analysis and Designing sequences for details.