Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
import py3Dmol | |
import io | |
import numpy as np | |
import os | |
import traceback | |
from esm.sdk import client | |
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig, ESMProteinError | |
from esm.utils.structure.protein_chain import ProteinChain | |
from Bio.Data import PDBData | |
import biotite.structure as bs | |
from biotite.structure.io import pdb | |
from esm.utils import residue_constants as RC | |
import requests | |
from dotenv import load_dotenv | |
import torch | |
import json | |
import time | |
# Pull variables from a local .env file (if any) into the process environment.
load_dotenv()

API_URL = "https://forge.evolutionaryscale.ai/api/v1"
MODEL = "esm3-open-2024-03"

# The API token must come from the environment; fail fast when it is missing.
API_TOKEN = os.environ.get("ESM_API_TOKEN")
if not API_TOKEN:
    raise ValueError("ESM_API_TOKEN environment variable is not set")

# Authenticate with the validated environment token rather than a literal
# embedded in source, so credentials never live in the repository.
model = client(
    model=MODEL,
    url=API_URL,
    token=API_TOKEN,
)
# Three-letter -> one-letter codes for the 20 standard amino acids.
# The two sequences below are position-aligned: the n-th three-letter code
# maps to the n-th character of the one-letter string.
amino3to1 = dict(zip(
    (
        "ALA CYS ASP GLU PHE "
        "GLY HIS ILE LYS LEU "
        "MET ASN PRO GLN ARG "
        "SER THR VAL TRP TYR"
    ).split(),
    "ACDEFGHIKLMNPQRSTVWY",
))
def read_pdb_io(pdb_file):
    """Load PDB text from an upload object and return it with a fresh stream.

    Args:
        pdb_file: One of
            * an ``io.StringIO`` already holding the PDB text,
            * a filesystem path given as ``str`` or ``os.PathLike``, or
            * any object exposing the path through a ``name`` attribute
              (for example a temporary-file wrapper).

    Returns:
        A ``(pdb_io, pdb_content)`` tuple: a new ``io.StringIO`` positioned at
        the start of the text, and the text itself.

    Raises:
        ValueError: If the input type is not supported, or the content is
            empty or whitespace only.
    """
    if isinstance(pdb_file, io.StringIO):
        pdb_content = pdb_file.getvalue()
    else:
        if isinstance(pdb_file, (str, os.PathLike)):
            # Upload widgets may hand over the path itself instead of a
            # file-like wrapper.
            path = pdb_file
        elif hasattr(pdb_file, 'name'):
            path = pdb_file.name
        else:
            raise ValueError("Unsupported file type")
        with open(path, 'r') as f:
            pdb_content = f.read()

    if not pdb_content.strip():
        raise ValueError("The PDB file is empty.")

    return io.StringIO(pdb_content), pdb_content
def get_protein(pdb_file) -> ESMProtein:
    """Parse an uploaded PDB file into an ``ESMProtein``.

    Only the first model in the file is used, and only residues whose
    three-letter name appears in ``amino3to1`` are kept. All kept residues are
    placed in a single chain labelled ``"A"``; atoms whose names are not in
    ``RC.atom_order`` are ignored.

    Args:
        pdb_file: Anything accepted by ``read_pdb_io``.

    Returns:
        The protein built through ``ESMProtein.from_protein_chain``.

    Raises:
        ValueError: For any failure while reading or converting the file; the
            original exception is chained as the cause.
    """
    try:
        # read_pdb_io already rejects empty / whitespace-only content.
        pdb_io, _ = read_pdb_io(pdb_file)

        # Ask biotite for a single model. Without ``model`` it returns a
        # multi-model stack, and iterating a residue of a stack walks over
        # models instead of atoms, so only the first atom of every residue
        # would end up in the atom37 arrays.
        structure = pdb.PDBFile.read(pdb_io).get_structure(model=1)
        if structure.array_length() == 0:
            raise ValueError("The PDB file does not contain any valid atoms")

        # Residue-level annotations repeat for every atom of a residue, so the
        # first atom's value stands for the whole residue.
        valid_residues = [
            res for res in bs.residue_iter(structure)
            if res.res_name[0] in amino3to1
        ]
        if not valid_residues:
            raise ValueError("No valid amino acid residues found in the PDB file")

        sequence = ''.join(amino3to1[res.res_name[0]] for res in valid_residues)
        residue_indices = [res.res_id[0] for res in valid_residues]
        n_res = len(valid_residues)

        # Missing atoms stay NaN in the positions and False in the mask.
        atom37_positions = np.full((n_res, 37, 3), np.nan)
        atom37_mask = np.zeros((n_res, 37), dtype=bool)
        for i, res in enumerate(valid_residues):
            for atom in res:
                idx = RC.atom_order.get(atom.atom_name)
                if idx is None:
                    continue
                atom37_positions[i, idx] = atom.coord
                atom37_mask[i, idx] = True

        protein_chain = ProteinChain(
            id="test",
            sequence=sequence,
            chain_id="A",
            entity_id=None,
            residue_index=np.array(residue_indices, dtype=int),
            insertion_code=np.full(n_res, "", dtype="<U4"),
            atom37_positions=atom37_positions,
            atom37_mask=atom37_mask,
            confidence=np.ones(n_res, dtype=np.float32),
        )
        return ESMProtein.from_protein_chain(protein_chain)
    except Exception as e:
        print(f"Error processing PDB file: {str(e)}")
        raise ValueError(f"Unable to process the PDB file: {str(e)}") from e
def add_noise_to_coordinates(protein: ESMProtein, noise_level: float) -> ESMProtein:
    """Return a copy of ``protein`` with Gaussian noise added to its coordinates.

    Args:
        protein: Source protein; its ``coordinates`` must be set.
        noise_level: Standard deviation of the zero-mean noise, in the same
            units as the coordinates.

    Returns:
        A new ``ESMProtein`` carrying only the original sequence and the
        perturbed coordinates. NaN entries remain NaN.

    Raises:
        ValueError: If ``protein.coordinates`` is ``None``.
    """
    coordinates = protein.coordinates
    if coordinates is None:
        raise ValueError("Protein has no coordinates to add noise to")

    noise = np.random.randn(*coordinates.shape) * noise_level
    if torch.is_tensor(coordinates):
        # Match dtype and device so the sum is a plain tensor addition instead
        # of relying on implicit numpy/torch mixing.
        noise = torch.as_tensor(noise, dtype=coordinates.dtype, device=coordinates.device)

    noisy_coordinates = coordinates + noise
    return ESMProtein(sequence=protein.sequence, coordinates=noisy_coordinates)
def run_structure_prediction(
    protein: ESMProtein,
    num_steps: int = 10,
    temperature: float = 0.7,
) -> "ESMProtein | None":
    """Run one structure-track generation with the module-level ``model``.

    Args:
        protein: Input protein passed to ``model.generate``.
        num_steps: Number of generation steps for the ``GenerationConfig``.
        temperature: Sampling temperature for the ``GenerationConfig``.

    Returns:
        The generated ``ESMProtein``, or ``None`` when the model returns an
        ``ESMProteinError``, returns an unexpected type, or raises. Failures
        are printed rather than propagated so a batch of runs can continue.
    """
    structure_prediction_config = GenerationConfig(
        track="structure",
        num_steps=num_steps,
        temperature=temperature,
    )
    try:
        response = model.generate(protein, structure_prediction_config)
    except Exception as e:
        print(f"Error during structure prediction: {str(e)}")
        return None

    if isinstance(response, ESMProtein):
        return response
    if isinstance(response, ESMProteinError):
        print(f"ESMProteinError during structure prediction: {response.error_msg}")
        return None
    print(f"Error during structure prediction: Unexpected response type: {type(response)}")
    return None
def align_after_prediction(protein: ESMProtein, structure_prediction: ESMProtein) -> tuple[ESMProtein, float]:
    """Superimpose a predicted structure onto the reference and score it.

    Both proteins are converted to protein chains; the first
    ``min(len(a), len(b))`` residues of each are used for the alignment and
    for the RMSD.

    Args:
        protein: Reference protein (alignment target).
        structure_prediction: Predicted protein (mobile), or ``None``.

    Returns:
        ``(aligned_protein, crmsd)`` on success, or ``(None, inf)`` when the
        prediction is ``None`` or any step fails (the failure is printed).
    """
    failure = (None, float('inf'))
    if structure_prediction is None:
        return failure

    try:
        predicted_chain = structure_prediction.to_protein_chain()
        reference_chain = protein.to_protein_chain()

        # Restrict both chains to their common leading stretch.
        shared_len = min(len(predicted_chain.sequence), len(reference_chain.sequence))
        shared_inds = np.arange(0, shared_len)
        index_kwargs = {"mobile_inds": shared_inds, "target_inds": shared_inds}

        aligned_chain = predicted_chain.align(reference_chain, **index_kwargs)
        crmsd = predicted_chain.rmsd(reference_chain, **index_kwargs)

        return ESMProtein.from_protein_chain(aligned_chain), crmsd
    except AttributeError as e:
        # Extra diagnostics: usually the prediction object is not what we expect.
        print(f"Error during alignment: {str(e)}")
        print(f"Structure prediction type: {type(structure_prediction)}")
        print(f"Structure prediction attributes: {dir(structure_prediction)}")
        return failure
    except Exception as e:
        print(f"Unexpected error during alignment: {str(e)}")
        return failure
def visualize_after_pred(protein: ESMProtein, aligned: ESMProtein):
    """Build a py3Dmol scene overlaying the reference and the aligned prediction.

    The reference is added first and drawn as a light-grey cartoon; the
    aligned prediction is added second and restyled light green via the
    ``{"model": -1}`` selection.

    Args:
        protein: Reference protein.
        aligned: Aligned prediction, or ``None``.

    Returns:
        ``None`` when ``aligned`` is ``None``; otherwise the result of
        ``render()`` on the 800x600 view.
        NOTE(review): presumably that is the view object itself - confirm
        against py3Dmol.
    """
    if aligned is not None:
        scene = py3Dmol.view(width=800, height=600)

        reference_pdb = protein_to_pdb(protein)
        scene.addModel(reference_pdb, "pdb")
        grey_cartoon = {"cartoon": {"color": "lightgrey"}}
        scene.setStyle(grey_cartoon)

        prediction_pdb = protein_to_pdb(aligned)
        scene.addModel(prediction_pdb, "pdb")
        last_model = {"model": -1}
        green_cartoon = {"cartoon": {"color": "lightgreen"}}
        scene.setStyle(last_model, green_cartoon)

        scene.zoomTo()
        return scene.render()
    return None
def protein_to_pdb(protein: ESMProtein):
    """Serialise a protein's sequence and coordinates as PDB ``ATOM`` records.

    Each residue contributes one record per entry of ``RC.atom_types`` whose
    coordinates are not NaN. Records use the fixed-column PDB layout, chain
    ``A``, residue numbers starting at 1 and consecutive atom serials.
    One-letter codes are translated through ``amino3to1``; anything not found
    there is written as ``UNK``.

    Args:
        protein: Object exposing ``sequence`` (one-letter string) and
            ``coordinates`` indexed as ``[residue][atom] -> (x, y, z)``.

    Returns:
        The concatenated, newline-terminated records (empty string when no
        atom has coordinates).

    Raises:
        ValueError: If the sequence or the coordinates are missing.
    """
    if protein.sequence is None or protein.coordinates is None:
        raise ValueError("Protein needs both a sequence and coordinates to be written as PDB")

    # PDB residue names are three letters; the sequence stores one-letter codes.
    one_to_three = {one: three for three, one in amino3to1.items()}

    records = []
    serial = 0
    for res_number, (aa, coords) in enumerate(zip(protein.sequence, protein.coordinates), start=1):
        res_name = one_to_three.get(aa, "UNK")
        for atom, xyz in zip(RC.atom_types, coords):
            x, y, z = xyz.tolist()
            if np.isnan([x, y, z]).any():
                continue  # atom absent for this residue
            serial += 1
            # Names shorter than four characters start in column 14.
            name_field = atom if len(atom) >= 4 else f" {atom:<3s}"
            # Every atom37 name begins with its one-letter element symbol.
            element = atom[0]
            records.append(
                f"ATOM  {serial:5d} {name_field} {res_name:>3s} A{res_number:4d}    "
                f"{x:8.3f}{y:8.3f}{z:8.3f}  1.00  0.00          {element:>2s}\n"
            )
    return "".join(records)
def prediction_visualization(pdb_file, num_runs: int, noise_level: float, num_frames: int, progress=gr.Progress()):
    """Run noisy structure predictions and visualise the best one.

    For each of ``num_frames`` frames a freshly noised copy of the input is
    created, and ``num_runs`` predictions are run on that copy. Every
    prediction that can be aligned to the original is kept, and the one with
    the lowest cRMSD is rendered.

    Args:
        pdb_file: Uploaded PDB file, as accepted by ``get_protein``.
        num_runs: Predictions per frame (coerced to ``int``).
        noise_level: Standard deviation of the coordinate noise.
        num_frames: Number of noised frames (coerced to ``int``).
        progress: Gradio progress tracker.

    Returns:
        ``(view_data, message)``; ``view_data`` is ``None`` when no prediction
        succeeded.
    """
    # UI sliders may deliver whole numbers as floats, which range() rejects.
    num_runs = int(num_runs)
    num_frames = int(num_frames)

    protein = get_protein(pdb_file)
    runs = []
    total_iterations = num_frames * num_runs
    progress(0, desc="Starting predictions")

    for frame in progress.tqdm(range(num_frames), desc="Processing frames"):
        noisy_protein = add_noise_to_coordinates(protein, noise_level)
        for i in range(num_runs):
            progress((frame * num_runs + i + 1) / total_iterations, desc=f"Frame {frame+1}, Run {i+1}")
            structure_prediction = run_structure_prediction(noisy_protein)
            if structure_prediction is not None:
                aligned, crmsd = align_after_prediction(protein, structure_prediction)
                if aligned is not None:
                    runs.append((crmsd, aligned))
            time.sleep(0.1)  # Small delay to allow for UI updates

    if not runs:
        return None, "No successful predictions"

    # Lowest cRMSD wins; ties resolve to the earliest run.
    best_crmsd, best_aligned = min(runs, key=lambda run: run[0])
    view_data = visualize_after_pred(protein, best_aligned)
    return view_data, f"Best cRMSD: {best_crmsd:.4f}"
def run_prediction(pdb_file, num_runs, noise_level, num_frames, progress=gr.Progress()):
    """Gradio click handler: run the pipeline and return HTML plus a status line.

    Args:
        pdb_file: Uploaded file from the UI, or ``None`` when nothing was uploaded.
        num_runs: Predictions per frame.
        noise_level: Standard deviation of the coordinate noise.
        num_frames: Number of noised frames.
        progress: Gradio progress tracker.

    Returns:
        ``(html_or_message, status_text)``. The first item is the embedded 3D
        viewer on success, a plain message when there is nothing to show, or
        an HTML error report (message and stack trace) when an exception
        occurred. Exceptions are never propagated to the UI.
    """
    try:
        if pdb_file is None:
            return "Please upload a PDB file.", "No file uploaded"

        progress(0, desc="Starting prediction")
        view_data, crmsd_text = prediction_visualization(pdb_file, num_runs, noise_level, num_frames, progress)
        if view_data is None:
            return "No successful predictions were made. Try adjusting the parameters or check the PDB file.", crmsd_text

        progress(0.9, desc="Rendering visualization")
        # view_data is the py3Dmol view built by visualize_after_pred, not a
        # dict, so let py3Dmol generate the page itself.
        # NOTE(review): _make_html is py3Dmol's internal HTML generator -
        # confirm it exists in the pinned py3Dmol version.
        viewer_page = view_data._make_html()
        # Scripts placed directly inside an HTML component are not executed;
        # an iframe with srcdoc gives the viewer its own document to run in.
        html_content = (
            '<iframe style="width: 100%; height: 600px; border: 0;" '
            'sandbox="allow-scripts allow-same-origin" '
            f'srcdoc="{html.escape(viewer_page, quote=True)}"></iframe>'
        )
        progress(1.0, desc="Completed")
        return html_content, crmsd_text
    except Exception as e:
        # Escape both strings: tracebacks contain text such as "<module>"
        # that would otherwise be parsed as markup.
        error_message = html.escape(str(e))
        stack_trace = html.escape(traceback.format_exc())
        error_html = (
            "<div style='color: red;'>"
            "<h3>Error:</h3>"
            f"<p>{error_message}</p>"
            "<h4>Stack Trace:</h4>"
            f"<pre>{stack_trace}</pre>"
            "</div>"
        )
        return error_html, "Error occurred"
def create_demo():
    """Assemble the Gradio Blocks UI and return it (not yet launched).

    Layout: a title, then a row with the input controls on the left (file
    upload, three sliders, run button) and the outputs on the right (HTML
    viewer, result textbox), followed by usage notes. The button is wired to
    ``run_prediction``.
    """
    usage_notes = """
## How to use
1. Upload a PDB file using the file uploader.
2. Adjust the number of prediction runs per frame using the slider.
3. Set the noise level to add random perturbations to the structure.
4. Choose the number of MD frames to simulate.
5. Click the "Run Prediction" button to start the process.
6. The 3D visualization will show the original structure (grey) and the best predicted structure (green).
7. The alignment result will display the best cRMSD (lower is better).
## About
This demo uses the ESM3 model to predict protein structures from PDB files.
It runs multiple predictions with added noise and simulated MD frames, displaying the best result based on the lowest cRMSD.
"""

    with gr.Blocks() as app:
        gr.Markdown("# Protein Structure Prediction and Visualization with Noise and MD Frames")

        with gr.Row():
            # Left column: inputs.
            with gr.Column(scale=1):
                upload = gr.File(label="Upload PDB file")
                runs_slider = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of runs per frame")
                noise_slider = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.1, label="Noise level")
                frames_slider = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of MD frames")
                start = gr.Button("Run Prediction")
            # Right column: outputs.
            with gr.Column(scale=2):
                viewer_panel = gr.HTML(label="3D Visualization")
                result_box = gr.Textbox(label="Alignment Result")

        handler_inputs = [upload, runs_slider, noise_slider, frames_slider]
        handler_outputs = [viewer_panel, result_box]
        start.click(fn=run_prediction, inputs=handler_inputs, outputs=handler_outputs)

        gr.Markdown(usage_notes)

    return app
if __name__ == "__main__":
    # Build the UI, enable request queuing, then start the server.
    demo = create_demo()
    demo.queue().launch()