|
import gradio as gr |
|
import py3Dmol |
|
import io |
|
import numpy as np |
|
import os |
|
import traceback |
|
from esm.sdk import client |
|
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig, ESMProteinError |
|
from esm.utils.structure.protein_chain import ProteinChain |
|
from Bio.Data import PDBData |
|
import biotite.structure as bs |
|
from biotite.structure.io import pdb |
|
from esm.utils import residue_constants as RC |
|
import requests |
|
from dotenv import load_dotenv |
|
import torch |
|
|
|
# Pull variables from a local .env file (if any) into the process environment
# before the token is read below.
load_dotenv()

# Forge endpoint and model identifier used for every prediction request.
API_URL = "https://forge.evolutionaryscale.ai/api/v1"
MODEL = "esm3-open-2024-03"

# The API token must come from the environment; fail fast at import time if
# it is missing so the app never starts in a half-configured state.
API_TOKEN = os.environ.get("ESM_API_TOKEN")
if not API_TOKEN:
    raise ValueError("ESM_API_TOKEN environment variable is not set")

# Shared inference client. The token is taken from ESM_API_TOKEN rather than
# being embedded in source, so credentials are never committed and the
# validation above actually governs what is sent.
model = client(
    model=MODEL,
    url=API_URL,
    token=API_TOKEN,
)
|
|
|
# Three-letter residue names mapped to their one-letter codes for the 20
# standard amino acids. Residues whose name is not a key here are skipped
# when a PDB file is parsed.
amino3to1 = dict(
    ALA="A", CYS="C", ASP="D", GLU="E", PHE="F",
    GLY="G", HIS="H", ILE="I", LYS="K", LEU="L",
    MET="M", ASN="N", PRO="P", GLN="Q", ARG="R",
    SER="S", THR="T", VAL="V", TRP="W", TYR="Y",
)
|
|
|
|
|
def read_pdb_io(pdb_file):
    """Read PDB text from several kinds of input and return it twice over.

    Accepted inputs:

    * an ``io.StringIO`` holding the PDB text;
    * a filesystem path given as ``str`` or ``os.PathLike``;
    * any other object exposing a ``name`` attribute that is a path to the
      file (for example an uploaded-file wrapper).

    Args:
        pdb_file: One of the input kinds listed above.

    Returns:
        A ``(pdb_io, pdb_content)`` tuple: a fresh ``io.StringIO`` positioned
        at the start of the text, and the text itself.

    Raises:
        ValueError: If the input kind is not supported or the content is
            empty / whitespace only.
        OSError: If the referenced file cannot be opened.
    """
    if isinstance(pdb_file, io.StringIO):
        pdb_content = pdb_file.getvalue()
    elif isinstance(pdb_file, (str, os.PathLike)):
        # Checked before the ``name`` branch on purpose: ``pathlib.Path.name``
        # is only the final path component, so opening it would hit the
        # wrong file.
        with open(os.fspath(pdb_file), "r") as handle:
            pdb_content = handle.read()
    elif hasattr(pdb_file, "name"):
        with open(pdb_file.name, "r") as handle:
            pdb_content = handle.read()
    else:
        raise ValueError("Unsupported file type")

    if not pdb_content.strip():
        raise ValueError("The PDB file is empty.")

    return io.StringIO(pdb_content), pdb_content
|
|
|
def get_protein(pdb_file) -> ESMProtein:
    """Parse a PDB input into an ``ESMProtein`` with atom37 coordinates.

    The first model of the file is read with biotite. Residues whose
    three-letter name is not a key of ``amino3to1`` are dropped. All kept
    residues are placed into a single chain labelled ``"A"``, in file order.

    Args:
        pdb_file: Anything accepted by ``read_pdb_io``.

    Returns:
        The protein built via ``ESMProtein.from_protein_chain``.

    Raises:
        ValueError: For any failure (empty file, no atoms, no recognised
            residues, parse errors). The original exception is attached as
            ``__cause__``.
    """
    try:
        # read_pdb_io already rejects empty / unsupported input.
        pdb_io, _ = read_pdb_io(pdb_file)

        # Request model 1 explicitly. Without ``model`` biotite returns an
        # AtomArrayStack; iterating a residue of a stack yields whole models
        # rather than atoms, so only the first atom of each residue would
        # ever receive coordinates.
        structure = pdb.PDBFile.read(pdb_io).get_structure(model=1)

        if structure.array_length() == 0:
            raise ValueError("The PDB file does not contain any valid atoms")

        # Per-atom annotations are arrays; within one residue they are all
        # equal, so element 0 identifies the residue.
        valid_residues = [
            res for res in bs.residue_iter(structure)
            if res.res_name[0] in amino3to1
        ]
        if not valid_residues:
            raise ValueError("No valid amino acid residues found in the PDB file")

        sequence = "".join(amino3to1[res.res_name[0]] for res in valid_residues)
        residue_indices = [res.res_id[0] for res in valid_residues]
        n_res = len(sequence)

        # Start with every atom slot missing (NaN / masked out) and fill in
        # the atoms that the file actually provides.
        atom37_positions = np.full((n_res, 37, 3), np.nan)
        atom37_mask = np.zeros((n_res, 37), dtype=bool)
        for i, res in enumerate(valid_residues):
            for atom_name, coord in zip(res.atom_name, res.coord):
                if atom_name in RC.atom_order:
                    idx = RC.atom_order[atom_name]
                    atom37_positions[i, idx] = coord
                    atom37_mask[i, idx] = True

        protein_chain = ProteinChain(
            id="test",
            sequence=sequence,
            chain_id="A",
            entity_id=None,
            residue_index=np.array(residue_indices, dtype=int),
            insertion_code=np.full(n_res, "", dtype="<U4"),
            atom37_positions=atom37_positions,
            atom37_mask=atom37_mask,
            confidence=np.ones(n_res, dtype=np.float32),
        )

        return ESMProtein.from_protein_chain(protein_chain)
    except Exception as e:
        print(f"Error processing PDB file: {str(e)}")
        # Callers rely on a single ValueError type; keep the cause chained
        # so the real origin is still visible in tracebacks.
        raise ValueError(f"Unable to process the PDB file: {str(e)}") from e
|
|
|
def add_noise_to_coordinates(protein: ESMProtein, noise_level: float) -> ESMProtein:
    """Return a copy of ``protein`` with Gaussian noise added to its coordinates.

    Independent zero-mean noise with standard deviation ``noise_level`` is
    drawn from NumPy's global RNG for every coordinate value. NaN entries
    stay NaN. The input protein is not modified.

    Args:
        protein: Protein whose ``coordinates`` are perturbed.
        noise_level: Standard deviation of the noise, in the same units as
            the coordinates (presumably angstroms - confirm against the
            data source).

    Returns:
        A new ``ESMProtein`` carrying the original sequence and the noisy
        coordinates as a torch tensor with the input's dtype and device.

    Raises:
        ValueError: If ``protein.coordinates`` is ``None``.
    """
    if protein.coordinates is None:
        raise ValueError("Cannot add noise: the protein has no coordinates")

    coordinates = torch.as_tensor(protein.coordinates)
    # Convert the NumPy sample to a tensor matching the coordinates. Mixing
    # a float64 ndarray directly into tensor arithmetic can hand back a
    # non-tensor (or dtype-changed) result, which downstream code that
    # expects a torch tensor cannot use.
    noise = torch.as_tensor(np.random.randn(*coordinates.shape)).to(
        dtype=coordinates.dtype, device=coordinates.device
    )
    noisy_coordinates = coordinates + noise * noise_level
    return ESMProtein(sequence=protein.sequence, coordinates=noisy_coordinates)
|
|
|
def run_structure_prediction(protein: ESMProtein) -> ESMProtein:
    """Ask the shared ``model`` client for a structure-track generation.

    Returns the generated ``ESMProtein`` on success. Any failure (an
    ``ESMProteinError`` response, an unexpected response type, or an
    exception raised by the call) is printed and reported as ``None``.
    """
    config = GenerationConfig(track="structure", num_steps=10, temperature=0.7)

    try:
        result = model.generate(protein, config)

        if isinstance(result, ESMProtein):
            return result

        if isinstance(result, ESMProteinError):
            print(f"ESMProteinError during structure prediction: {result.error_msg}")
            return None

        # Anything else is treated as a failure and reported by the handler
        # below, like every other exception.
        raise ValueError(f"Unexpected response type: {type(result)}")
    except Exception as exc:
        print(f"Error during structure prediction: {exc}")
        return None
|
|
|
def align_after_prediction(protein: ESMProtein, structure_prediction: ESMProtein) -> tuple[ESMProtein, float]:
    """Superpose a prediction onto the reference protein and score it.

    Both proteins are converted to protein chains. The first
    ``min(len(pred), len(ref))`` positions are used as corresponding indices
    for both ``align`` and ``rmsd``.

    Returns:
        ``(aligned_protein, crmsd)`` on success. ``(None, inf)`` when
        ``structure_prediction`` is ``None`` or any step raises; the error
        is printed in that case.
    """
    failure = (None, float('inf'))

    if structure_prediction is None:
        return failure

    try:
        predicted_chain = structure_prediction.to_protein_chain()
        reference_chain = protein.to_protein_chain()

        # Only the prefix common to both chains is compared.
        shared_len = min(len(predicted_chain.sequence), len(reference_chain.sequence))
        shared_inds = np.arange(shared_len)

        superposed = predicted_chain.align(
            reference_chain,
            mobile_inds=shared_inds,
            target_inds=shared_inds,
        )
        crmsd = predicted_chain.rmsd(
            reference_chain,
            mobile_inds=shared_inds,
            target_inds=shared_inds,
        )

        return ESMProtein.from_protein_chain(superposed), crmsd
    except AttributeError as exc:
        # Extra diagnostics: this usually means the prediction object is not
        # of the expected shape.
        print(f"Error during alignment: {exc}")
        print(f"Structure prediction type: {type(structure_prediction)}")
        print(f"Structure prediction attributes: {dir(structure_prediction)}")
        return failure
    except Exception as exc:
        print(f"Unexpected error during alignment: {exc}")
        return failure
|
|
|
def visualize_after_pred(protein: ESMProtein, aligned: ESMProtein):
    """Build an 800x600 py3Dmol view of the reference and aligned structures.

    When ``aligned`` is ``None`` an empty view is returned. Otherwise the
    reference is added first with a light-grey cartoon style and the aligned
    structure second (model index 1) with a light-green cartoon style.
    """
    viewer = py3Dmol.view(width=800, height=600)
    if aligned is None:
        return viewer

    reference_pdb = protein_to_pdb(protein)
    viewer.addModel(reference_pdb, "pdb")
    viewer.setStyle({"cartoon": {"color": "lightgrey"}})

    aligned_pdb = protein_to_pdb(aligned)
    viewer.addModel(aligned_pdb, "pdb")
    viewer.setStyle({"model": 1}, {"cartoon": {"color": "lightgreen"}})

    viewer.zoomTo()
    return viewer
|
|
|
def protein_to_pdb(protein: ESMProtein):
    """Serialise a protein's atom coordinates as PDB ``ATOM`` records.

    One record is written per atom whose coordinates contain no NaN. Records
    follow the fixed-column PDB layout (serial in columns 7-11, atom name in
    13-16, residue name in 18-20, chain in 22, residue number in 23-26,
    x/y/z in 31-54) so that column-based parsers read them correctly.
    Residue names are the three-letter codes from ``amino3to1``; sequence
    characters without a mapping are written as ``UNK``. Every atom is put
    on chain ``A`` and residues are numbered from 1.

    Args:
        protein: Protein providing ``sequence`` and per-residue
            ``coordinates`` (one row per entry of ``RC.atom_types``).

    Returns:
        The concatenated records, each terminated by a newline. Empty string
        if no atom has coordinates.

    Raises:
        ValueError: If ``protein.coordinates`` is ``None``.
    """
    if protein.coordinates is None:
        raise ValueError("Cannot write PDB: the protein has no coordinates")

    one_to_three = {one: three for three, one in amino3to1.items()}
    coordinates = torch.as_tensor(protein.coordinates)

    records = []
    serial = 0
    for res_num, (aa, residue_coords) in enumerate(zip(protein.sequence, coordinates), start=1):
        res_name = one_to_three.get(aa, "UNK")
        for atom_name, xyz in zip(RC.atom_types, residue_coords):
            # Skip atoms that are missing in any coordinate, not just x.
            if torch.isnan(xyz).any():
                continue
            serial += 1
            x, y, z = xyz.tolist()
            records.append(
                # The serial field is 5 columns wide; wrap instead of
                # overflowing into the neighbouring column.
                f"ATOM  {serial % 100000:5d}  {atom_name:<3s} {res_name:>3s} A{res_num:4d}    "
                f"{x:8.3f}{y:8.3f}{z:8.3f}{1.0:6.2f}{0.0:6.2f}          {atom_name[0]:>2s}\n"
            )
    return "".join(records)
|
|
|
def prediction_visualization(pdb_file, num_runs: int, noise_level: float, num_frames: int):
    """Run noisy structure predictions and visualise the best-scoring one.

    For each of ``num_frames`` frames a freshly perturbed copy of the
    reference protein is created, and ``num_runs`` predictions are run on
    it. Every prediction that both succeeds and aligns is kept with its
    cRMSD.

    Returns:
        ``(view, text)`` where ``view`` is the py3Dmol view for the lowest
        cRMSD result and ``text`` reports that score, or
        ``(None, "No successful predictions")`` when nothing succeeded.
    """
    reference = get_protein(pdb_file)
    scored = []

    for _frame in range(num_frames):
        perturbed = add_noise_to_coordinates(reference, noise_level)

        for _run in range(num_runs):
            predicted = run_structure_prediction(perturbed)
            if predicted is None:
                continue

            superposed, score = align_after_prediction(reference, predicted)
            if superposed is None:
                continue

            scored.append((score, superposed))

    if len(scored) == 0:
        return None, "No successful predictions"

    # Rank by cRMSD only; ties keep their insertion order.
    ranked = sorted(scored, key=lambda entry: entry[0])
    best_score, best_structure = ranked[0]
    view = visualize_after_pred(reference, best_structure)
    return view, f"Best cRMSD: {best_score:.4f}"
|
|
|
def run_prediction(pdb_file, num_runs, noise_level, num_frames):
    """Gradio callback: run the prediction pipeline and return UI outputs.

    Args:
        pdb_file: Uploaded file value from the UI, or ``None``.
        num_runs: Predictions per frame.
        noise_level: Noise standard deviation passed to the pipeline.
        num_frames: Number of noisy frames to generate.

    Returns:
        ``(html, status)``: an HTML fragment for the visualisation pane (or
        a message / error report) and a short status text. Exceptions are
        never propagated; they are rendered into the HTML instead.
    """
    # Imported locally under an alias so the file's imports stay untouched.
    from html import escape as _escape

    try:
        if pdb_file is None:
            return "Please upload a PDB file.", "No file uploaded"

        view, crmsd_text = prediction_visualization(pdb_file, num_runs, noise_level, num_frames)
        if view is None:
            return "No successful predictions were made. Try adjusting the parameters or check the PDB file.", crmsd_text

        viewer_html = view._make_html()
        return f"""
        <div style="height: 600px;">
            {viewer_html}
        </div>
        """, crmsd_text
    except Exception as e:
        # Error text can contain '<', '>' and '&' (e.g. "<class '...'>") and
        # may echo content from the uploaded file. Escape it so it renders
        # literally and cannot inject markup into the page.
        error_message = _escape(str(e))
        stack_trace = _escape(traceback.format_exc())
        return f"""
        <div style='color: red;'>
            <h3>Error:</h3>
            <p>{error_message}</p>
            <h4>Stack Trace:</h4>
            <pre>{stack_trace}</pre>
        </div>
        """, "Error occurred"
|
|
|
def create_demo():
    """Assemble the Gradio Blocks UI and return it without launching.

    Layout: a title, a row with an input column (file upload, three sliders,
    run button) and a wider output column (HTML viewer, result textbox),
    followed by usage notes. The run button is wired to ``run_prediction``.
    """
    with gr.Blocks() as app:
        gr.Markdown("# Protein Structure Prediction and Visualization with Noise and MD Frames")

        with gr.Row():
            # Input controls.
            with gr.Column(scale=1):
                upload = gr.File(label="Upload PDB file")
                runs_slider = gr.Slider(label="Number of runs per frame", minimum=1, maximum=10, step=1, value=3)
                noise_slider = gr.Slider(label="Noise level", minimum=0, maximum=1, step=0.1, value=0.1)
                frames_slider = gr.Slider(label="Number of MD frames", minimum=1, maximum=10, step=1, value=1)
                trigger = gr.Button("Run Prediction")

            # Result widgets.
            with gr.Column(scale=2):
                viewer_pane = gr.HTML(label="3D Visualization")
                result_box = gr.Textbox(label="Alignment Result")

        trigger.click(
            fn=run_prediction,
            inputs=[upload, runs_slider, noise_slider, frames_slider],
            outputs=[viewer_pane, result_box],
        )

        gr.Markdown("""
        ## How to use
        1. Upload a PDB file using the file uploader.
        2. Adjust the number of prediction runs per frame using the slider.
        3. Set the noise level to add random perturbations to the structure.
        4. Choose the number of MD frames to simulate.
        5. Click the "Run Prediction" button to start the process.
        6. The 3D visualization will show the original structure (grey) and the best predicted structure (green).
        7. The alignment result will display the best cRMSD (lower is better).

        ## About
        This demo uses the ESM3 model to predict protein structures from PDB files.
        It runs multiple predictions with added noise and simulated MD frames, displaying the best result based on the lowest cRMSD.
        """)

    return app
|
|
|
if __name__ == "__main__":
    # Build the UI, turn on request queuing, then start the server.
    demo = create_demo()
    demo.queue().launch()