Source code for openprotein.fold.rosettafold3

"""Community-based RosettaFold-3 model for complex structure prediction of proteins and ligands (DNA/RNA chains are not supported and are dropped with a warning)."""

import warnings
from typing import Sequence

from openprotein.align import MSAFuture
from openprotein.base import APISession
from openprotein.common import ModelMetadata
from openprotein.fold.common import (
    msa_future_to_complex,
    normalize_inputs,
    serialize_input,
)
from openprotein.fold.future import FoldResultFuture
from openprotein.molecules import DNA, RNA, Complex, Protein

from . import api
from .models import FoldModel


class RosettaFold3Model(FoldModel):
    """
    Class providing inference endpoints for RosettaFold-3 structure prediction model.

    DNA and RNA chains are not supported by this model; `fold` removes them
    (emitting a warning) before the request is submitted.
    """

    model_id: str = "rosettafold-3"

    def __init__(
        self,
        session: APISession,
        model_id: str,
        metadata: ModelMetadata | None = None,
    ):
        super().__init__(session, model_id, metadata)

    def fold(
        self,
        sequences: Sequence[Complex | Protein | str | bytes] | MSAFuture,
        diffusion_samples: int = 1,
        num_recycles: int = 10,
        num_steps: int = 50,
        **kwargs,
    ) -> FoldResultFuture:
        """
        Request structure prediction with RosettaFold-3 model.

        Parameters
        ----------
        sequences : list[Complex | Protein | str | bytes] | MSAFuture
            Inputs (complexes, proteins, or raw sequences) to include in the
            folded output. `Protein` objects must be tagged with an `msa`,
            which can be `Protein.single_sequence_mode` for single sequence
            mode. Alternatively, supply an `MSAFuture` to use all query
            sequences as a multimer.
        diffusion_samples : int
            Number of diffusion samples to use.
        num_recycles : int
            Number of recycling steps to use.
        num_steps : int
            Number of sampling steps to use.
        **kwargs
            Additional keyword arguments forwarded unchanged to
            `api.fold_models_post`.

        Returns
        -------
        FoldResultFuture
            Future for the folding results.

        Raises
        ------
        ValueError
            If serializing the inputs yields nothing to submit.

        Warns
        -----
        UserWarning
            Once per DNA/RNA chain found. Such chains are deleted from the
            normalized complexes because RosettaFold-3 does not support them.
        """
        # Build the normalized complexes, either from an MSA job or from raw inputs.
        if isinstance(sequences, MSAFuture):
            normalized_complexes = [msa_future_to_complex(self.session, sequences)]
        else:
            normalized_complexes = normalize_inputs(sequences)

        for cplx in normalized_complexes:
            # Snapshot the chains first: deleting from a dict while iterating
            # over it (or a live view of it) raises RuntimeError.
            for chain_id, chain in list(cplx.get_chains().items()):
                if not isinstance(chain, (DNA, RNA)):
                    continue
                with warnings.catch_warnings():
                    warnings.simplefilter("always")  # Force warning to always show
                    warnings.warn(
                        "RosettaFold-3 does not support DNA/RNA input. "
                        "These extra chains will be ignored in the output.",
                        # Attribute the warning to the caller of fold().
                        stacklevel=2,
                    )
                # NOTE(review): this mutates the Complex in place; if
                # normalize_inputs returns caller-owned objects, the caller's
                # input loses these chains too — confirm this is intended.
                del cplx._chains[chain_id]

        serialized = serialize_input(self.session, normalized_complexes, needs_msa=True)
        if not serialized:
            raise ValueError("Expected proteins or ligands")

        return FoldResultFuture(
            session=self.session,
            job=api.fold_models_post(
                session=self.session,
                model_id=self.model_id,
                sequences=serialized,
                diffusion_samples=diffusion_samples,
                num_recycles=num_recycles,
                num_steps=num_steps,
                **kwargs,
            ),
            complexes=normalized_complexes,
        )