{ "cells": [
 {
  "cell_type": "code",
  "execution_count": null,
  "id": "f3b7f6b0-6685-4a5c-9529-45e0ca905a3b",
  "metadata": {},
  "outputs": [],
  "source": [
   "import gradio as gr\n",
   "import requests\n",
   "from Bio.PDB import PDBParser\n",
   "import numpy as np\n",
   "import os\n",
   "import html\n",
   "import json\n",
   "from gradio_molecule3d import Molecule3D\n",
   "\n",
   "def read_mol(pdb_path):\n",
   "    \"\"\"Read PDB file and return its content as a string\"\"\"\n",
   "    with open(pdb_path, 'r') as f:\n",
   "        return f.read()\n",
   "\n",
   "def fetch_pdb(pdb_id):\n",
   "    \"\"\"Download `{pdb_id}.pdb` from RCSB into the working directory.\n",
   "\n",
   "    Returns the local path on HTTP 200, otherwise None (including network errors).\n",
   "    \"\"\"\n",
   "    pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
   "    pdb_path = f'{pdb_id}.pdb'\n",
   "    try:\n",
   "        # Without a timeout a stalled connection would block the Gradio handler forever.\n",
   "        response = requests.get(pdb_url, timeout=30)\n",
   "    except requests.RequestException:\n",
   "        return None\n",
   "    if response.status_code != 200:\n",
   "        return None\n",
   "    with open(pdb_path, 'wb') as f:\n",
   "        f.write(response.content)\n",
   "    return pdb_path\n",
   "\n",
   "def process_pdb(pdb_id, segment):\n",
   "    \"\"\"Assign a random score in [0, 1) to every amino-acid residue of chain `segment`.\n",
   "\n",
   "    Returns (prediction text, viewer HTML, predictions file path), or\n",
   "    (error message, None, None) if the download or chain lookup fails.\n",
   "    \"\"\"\n",
   "    pdb_path = fetch_pdb(pdb_id)\n",
   "    if not pdb_path:\n",
   "        return \"Failed to fetch PDB file\", None, None\n",
   "\n",
   "    parser = PDBParser(QUIET=1)\n",
   "    structure = parser.get_structure('protein', pdb_path)\n",
   "\n",
   "    try:\n",
   "        chain = structure[0][segment]\n",
   "    except KeyError:\n",
   "        return \"Invalid Chain ID\", None, None\n",
   "\n",
   "    # Comprehensive amino acid mapping\n",
   "    aa_dict = {\n",
   "        'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n",
   "        'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n",
   "        'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',\n",
   "        'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',\n",
   "        'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'CSO': 'C', 'PTR': 'Y', 'HYP': 'P'\n",
   "    }\n",
   "\n",
   "    # Exclude non-amino acid residues\n",
   "    sequence = [\n",
   "        residue for residue in chain\n",
   "        if residue.get_resname().strip() in aa_dict\n",
   "    ]\n",
   "\n",
   "    random_scores = np.random.rand(len(sequence))\n",
   "    result_str = \"\\n\".join(\n",
   "        f\"{aa_dict[res.get_resname().strip()]} {res.id[1]} {score:.2f}\"\n",
   "        for res, score in zip(sequence, random_scores)\n",
   "    )\n",
   "\n",
   "    # Save the predictions to a file\n",
   "    prediction_file = f\"{pdb_id}_predictions.txt\"\n",
   "    with open(prediction_file, \"w\") as f:\n",
   "        f.write(result_str)\n",
   "\n",
   "    # Highlight by PDB residue number (res.id[1]); numbering rarely starts at 1.\n",
   "    residue_ids = [res.id[1] for res in sequence]\n",
   "    return result_str, molecule(pdb_path, random_scores, segment, residue_ids), prediction_file\n",
   "\n",
   "def molecule(input_pdb, scores=None, segment='A', residue_ids=None):\n",
   "    \"\"\"Return an iframe (for gr.HTML) that renders `input_pdb` with 3Dmol.js.\n",
   "\n",
   "    If `scores` is given, only chain `segment` is drawn as a cartoon and residues\n",
   "    scoring > 0.8 become red sticks. `residue_ids[i]` is the PDB residue number of\n",
   "    `scores[i]` (defaults to positions 1..len(scores)).\n",
   "    \"\"\"\n",
   "    mol = read_mol(input_pdb)  # Read PDB file content\n",
   "\n",
   "    # Prepare high-scoring residues script if scores are provided\n",
   "    high_score_script = \"\"\n",
   "    if scores is not None:\n",
   "        residue_ids = list(range(1, len(scores) + 1)) if residue_ids is None else list(residue_ids)\n",
   "        high_residues = \", \".join(\n",
   "            str(resi) for resi, score in zip(residue_ids, scores) if score > 0.8\n",
   "        )\n",
   "        # json.dumps quotes the user-supplied chain ID safely for JavaScript.\n",
   "        chain_js = json.dumps(segment)\n",
   "        high_score_script = \"\"\"\n",
   "        // Reset all styles first\n",
   "        viewer.getModel(0).setStyle({}, {});\n",
   "\n",
   "        // Show only the selected chain\n",
   "        viewer.getModel(0).setStyle(\n",
   "            {\"chain\": %s},\n",
   "            { cartoon: {colorscheme:\"whiteCarbon\"} }\n",
   "        );\n",
   "\n",
   "        // Highlight high-scoring residues only for the selected chain\n",
   "        let highScoreResidues = [%s];\n",
   "        viewer.getModel(0).setStyle(\n",
   "            {\"chain\": %s, \"resi\": highScoreResidues},\n",
   "            {\"stick\": {\"color\": \"red\"}}\n",
   "        );\n",
   "        \"\"\" % (chain_js, high_residues, chain_js)\n",
   "\n",
   "    html_content = f\"\"\"\n",
   "    <!DOCTYPE html>\n",
   "    <html>\n",
   "    <head>\n",
   "        <meta charset=\"UTF-8\">\n",
   "        <style>\n",
   "            .mol-container {{ width: 100%; height: 600px; position: relative; }}\n",
   "        </style>\n",
   "        <script src=\"https://3Dmol.org/build/3Dmol-min.js\"></script>\n",
   "    </head>\n",
   "    <body>\n",
   "        <div id=\"container\" class=\"mol-container\"></div>\n",
   "        <script>\n",
   "            // json.dumps yields a valid JS string literal for the raw PDB text.\n",
   "            let pdb = {json.dumps(mol)};\n",
   "            let viewer = $3Dmol.createViewer(document.getElementById(\"container\"), {{ backgroundColor: \"white\" }});\n",
   "            viewer.addModel(pdb, \"pdb\");\n",
   "            viewer.getModel(0).setStyle({{}}, {{ cartoon: {{ colorscheme: \"whiteCarbon\" }} }});\n",
   "            {high_score_script}\n",
   "            viewer.zoomTo();\n",
   "            viewer.render();\n",
   "        </script>\n",
   "    </body>\n",
   "    </html>\n",
   "    \"\"\"\n",
   "\n",
   "    # Return the HTML content within an iframe safely encoded for special characters\n",
   "    return (\n",
   "        '<iframe style=\"width: 100%; height: 600px; border: none;\" sandbox=\"allow-scripts\" '\n",
   "        f'srcdoc=\"{html.escape(html_content, quote=True)}\"></iframe>'\n",
   "    )\n",
   "\n",
   "reps = [\n",
   "    {\n",
   "        \"model\": 0,\n",
   "        \"style\": \"cartoon\",\n",
   "        \"color\": \"whiteCarbon\",\n",
   "        \"residue_range\": \"\",\n",
   "        \"around\": 0,\n",
   "        \"byres\": False,\n",
   "    }\n",
   "]\n",
   "\n",
   "# Gradio UI\n",
   "with gr.Blocks() as demo:\n",
   "    gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n",
   "    with gr.Row():\n",
   "        # Own textbox for the viewer; re-binding `pdb_input` below used to orphan this one.\n",
   "        viz_pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
   "        visualize_btn = gr.Button(\"Visualize Structure\")\n",
   "\n",
   "    molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
   "\n",
   "    with gr.Row():\n",
   "        pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
   "        segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
   "        prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
   "\n",
   "    molecule_output = gr.HTML(label=\"Protein Structure\")\n",
   "    predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
   "    download_output = gr.File(label=\"Download Predictions\")\n",
   "\n",
   "    visualize_btn.click(fetch_pdb, inputs=[viz_pdb_input], outputs=molecule_output2)\n",
   "\n",
   "    prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n",
   "\n",
   "    gr.Markdown(\"## Examples\")\n",
   "    gr.Examples(\n",
   "        examples=[\n",
   "            [\"2IWI\", \"A\"],\n",
   "            [\"7RPZ\", \"B\"],\n",
   "            [\"3TJN\", \"C\"]\n",
   "        ],\n",
   "        inputs=[pdb_input, segment_input],\n",
   "        outputs=[predictions_output, molecule_output, download_output]\n",
   "    )\n",
   "\n",
   "demo.launch()"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "id": "28f8f28c-48d3-4e35-9766-3de9882179b5",
  "metadata": {},
  "outputs": [],
  "source": [
   "import gradio as gr\n",
   "import requests\n",
   "from Bio.PDB import PDBParser\n",
   "import numpy as np\n",
   "import os\n",
   "import html\n",
   "import json\n",
   "from gradio_molecule3d import Molecule3D\n",
   "\n",
   "def read_mol(pdb_path):\n",
   "    \"\"\"Read PDB file and return its content as a string\"\"\"\n",
   "    with open(pdb_path, 'r') as f:\n",
   "        return f.read()\n",
   "\n",
   "def fetch_pdb(pdb_id):\n",
   "    \"\"\"Download `{pdb_id}.pdb` from RCSB into the working directory.\n",
   "\n",
   "    Returns the local path on HTTP 200, otherwise None (including network errors).\n",
   "    \"\"\"\n",
   "    pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
   "    pdb_path = f'{pdb_id}.pdb'\n",
   "    try:\n",
   "        # Without a timeout a stalled connection would block the Gradio handler forever.\n",
   "        response = requests.get(pdb_url, timeout=30)\n",
   "    except requests.RequestException:\n",
   "        return None\n",
   "    if response.status_code != 200:\n",
   "        return None\n",
   "    with open(pdb_path, 'wb') as f:\n",
   "        f.write(response.content)\n",
   "    return pdb_path\n",
   "\n",
   "def process_pdb(pdb_id, segment):\n",
   "    \"\"\"Assign a random score in [0, 1) to every amino-acid residue of chain `segment`.\n",
   "\n",
   "    Returns (prediction text, viewer HTML, predictions file path), or\n",
   "    (error message, None, None) if the download or chain lookup fails.\n",
   "    \"\"\"\n",
   "    pdb_path = fetch_pdb(pdb_id)\n",
   "    if not pdb_path:\n",
   "        return \"Failed to fetch PDB file\", None, None\n",
   "\n",
   "    parser = PDBParser(QUIET=1)\n",
   "    structure = parser.get_structure('protein', pdb_path)\n",
   "\n",
   "    try:\n",
   "        chain = structure[0][segment]\n",
   "    except KeyError:\n",
   "        return \"Invalid Chain ID\", None, None\n",
   "\n",
   "    # Comprehensive amino acid mapping\n",
   "    aa_dict = {\n",
   "        'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n",
   "        'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n",
   "        'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',\n",
   "        'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',\n",
   "        'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'CSO': 'C', 'PTR': 'Y', 'HYP': 'P'\n",
   "    }\n",
   "\n",
   "    # Exclude non-amino acid residues\n",
   "    sequence = [\n",
   "        residue for residue in chain\n",
   "        if residue.get_resname().strip() in aa_dict\n",
   "    ]\n",
   "\n",
   "    random_scores = np.random.rand(len(sequence))\n",
   "    result_str = \"\\n\".join(\n",
   "        f\"{aa_dict[res.get_resname().strip()]} {res.id[1]} {score:.2f}\"\n",
   "        for res, score in zip(sequence, random_scores)\n",
   "    )\n",
   "\n",
   "    # Save the predictions to a file\n",
   "    prediction_file = f\"{pdb_id}_predictions.txt\"\n",
   "    with open(prediction_file, \"w\") as f:\n",
   "        f.write(result_str)\n",
   "\n",
   "    # Highlight by PDB residue number (res.id[1]); numbering rarely starts at 1.\n",
   "    residue_ids = [res.id[1] for res in sequence]\n",
   "    return result_str, molecule(pdb_path, random_scores, segment, residue_ids), prediction_file\n",
   "\n",
   "def molecule(input_pdb, scores=None, segment='A', residue_ids=None):\n",
   "    \"\"\"Return an iframe (for gr.HTML) that renders `input_pdb` with 3Dmol.js.\n",
   "\n",
   "    If `scores` is given, only chain `segment` is drawn as a cartoon and residues\n",
   "    scoring > 0.8 become red sticks. `residue_ids[i]` is the PDB residue number of\n",
   "    `scores[i]` (defaults to positions 1..len(scores)).\n",
   "    \"\"\"\n",
   "    mol = read_mol(input_pdb)  # Read PDB file content\n",
   "\n",
   "    # Prepare high-scoring residues script if scores are provided\n",
   "    high_score_script = \"\"\n",
   "    if scores is not None:\n",
   "        residue_ids = list(range(1, len(scores) + 1)) if residue_ids is None else list(residue_ids)\n",
   "        high_residues = \", \".join(\n",
   "            str(resi) for resi, score in zip(residue_ids, scores) if score > 0.8\n",
   "        )\n",
   "        # json.dumps quotes the user-supplied chain ID safely for JavaScript.\n",
   "        chain_js = json.dumps(segment)\n",
   "        high_score_script = \"\"\"\n",
   "        // Reset all styles first\n",
   "        viewer.getModel(0).setStyle({}, {});\n",
   "\n",
   "        // Show only the selected chain\n",
   "        viewer.getModel(0).setStyle(\n",
   "            {\"chain\": %s},\n",
   "            { cartoon: {colorscheme:\"whiteCarbon\"} }\n",
   "        );\n",
   "\n",
   "        // Highlight high-scoring residues only for the selected chain\n",
   "        let highScoreResidues = [%s];\n",
   "        viewer.getModel(0).setStyle(\n",
   "            {\"chain\": %s, \"resi\": highScoreResidues},\n",
   "            {\"stick\": {\"color\": \"red\"}}\n",
   "        );\n",
   "        \"\"\" % (chain_js, high_residues, chain_js)\n",
   "\n",
   "    html_content = f\"\"\"\n",
   "    <!DOCTYPE html>\n",
   "    <html>\n",
   "    <head>\n",
   "        <meta charset=\"UTF-8\">\n",
   "        <style>\n",
   "            .mol-container {{ width: 100%; height: 600px; position: relative; }}\n",
   "        </style>\n",
   "        <script src=\"https://3Dmol.org/build/3Dmol-min.js\"></script>\n",
   "    </head>\n",
   "    <body>\n",
   "        <div id=\"container\" class=\"mol-container\"></div>\n",
   "        <script>\n",
   "            // json.dumps yields a valid JS string literal for the raw PDB text.\n",
   "            let pdb = {json.dumps(mol)};\n",
   "            let viewer = $3Dmol.createViewer(document.getElementById(\"container\"), {{ backgroundColor: \"white\" }});\n",
   "            viewer.addModel(pdb, \"pdb\");\n",
   "            viewer.getModel(0).setStyle({{}}, {{ cartoon: {{ colorscheme: \"whiteCarbon\" }} }});\n",
   "            {high_score_script}\n",
   "            viewer.zoomTo();\n",
   "            viewer.render();\n",
   "        </script>\n",
   "    </body>\n",
   "    </html>\n",
   "    \"\"\"\n",
   "\n",
   "    # Return the HTML content within an iframe safely encoded for special characters\n",
   "    return (\n",
   "        '<iframe style=\"width: 100%; height: 600px; border: none;\" sandbox=\"allow-scripts\" '\n",
   "        f'srcdoc=\"{html.escape(html_content, quote=True)}\"></iframe>'\n",
   "    )\n",
   "\n",
   "reps = [\n",
   "    {\n",
   "        \"model\": 0,\n",
   "        \"style\": \"cartoon\",\n",
   "        \"color\": \"whiteCarbon\",\n",
   "        \"residue_range\": \"\",\n",
   "        \"around\": 0,\n",
   "        \"byres\": False,\n",
   "    }\n",
   "]\n",
   "\n",
   "# Gradio UI\n",
   "with gr.Blocks() as demo:\n",
   "    gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n",
   "    with gr.Row():\n",
   "        # Own textbox for the viewer; re-binding `pdb_input` below used to orphan this one.\n",
   "        viz_pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
   "        visualize_btn = gr.Button(\"Visualize Structure\")\n",
   "\n",
   "    molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
   "\n",
   "    with gr.Row():\n",
   "        pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
   "        segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
   "        prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
   "\n",
   "    molecule_output = gr.HTML(label=\"Protein Structure\")\n",
   "    predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
   "    download_output = gr.File(label=\"Download Predictions\")\n",
   "\n",
   "    visualize_btn.click(fetch_pdb, inputs=[viz_pdb_input], outputs=molecule_output2)\n",
   "\n",
   "    prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n",
   "\n",
   "    gr.Markdown(\"## Examples\")\n",
   "    gr.Examples(\n",
   "        examples=[\n",
   "            [\"2IWI\", \"A\"],\n",
   "            [\"7RPZ\", \"B\"],\n",
   "            [\"3TJN\", \"C\"]\n",
   "        ],\n",
   "        inputs=[pdb_input, segment_input],\n",
   "        outputs=[predictions_output, molecule_output, download_output]\n",
   "    )\n",
   "\n",
   "demo.launch()"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "id": "517a2fe7-419f-4d0b-a9ed-62a22c1c1284",
  "metadata": {},
  "outputs": [],
  "source": []
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "id": "d62be1b5-762e-4b69-aed4-e4ba2a44482f",
  "metadata": {},
  "outputs": [],
  "source": [
   "import gradio as gr\n",
   "import requests\n",
   "from Bio.PDB import PDBParser\n",
   "import numpy as np\n",
   "import os\n",
   "import html\n",
   "import json\n",
   "from gradio_molecule3d import Molecule3D\n",
   "\n",
   "def read_mol(pdb_path):\n",
   "    \"\"\"Read PDB file and return its content as a string\"\"\"\n",
   "    with open(pdb_path, 'r') as f:\n",
   "        return f.read()\n",
   "\n",
   "def fetch_pdb(pdb_id):\n",
   "    \"\"\"Download `{pdb_id}.pdb` from RCSB into the working directory.\n",
   "\n",
   "    Returns the local path on HTTP 200, otherwise None (including network errors).\n",
   "    \"\"\"\n",
   "    pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
   "    pdb_path = f'{pdb_id}.pdb'\n",
   "    try:\n",
   "        # Without a timeout a stalled connection would block the Gradio handler forever.\n",
   "        response = requests.get(pdb_url, timeout=30)\n",
   "    except requests.RequestException:\n",
   "        return None\n",
   "    if response.status_code != 200:\n",
   "        return None\n",
   "    with open(pdb_path, 'wb') as f:\n",
   "        f.write(response.content)\n",
   "    return pdb_path\n",
   "\n",
   "def process_pdb(pdb_id, segment):\n",
   "    \"\"\"Assign a random score in [0, 1) to every amino-acid residue of chain `segment`.\n",
   "\n",
   "    Returns (prediction text, viewer HTML, predictions file path), or\n",
   "    (error message, None, None) if the download or chain lookup fails.\n",
   "    \"\"\"\n",
   "    pdb_path = fetch_pdb(pdb_id)\n",
   "    if not pdb_path:\n",
   "        return \"Failed to fetch PDB file\", None, None\n",
   "\n",
   "    parser = PDBParser(QUIET=1)\n",
   "    structure = parser.get_structure('protein', pdb_path)\n",
   "\n",
   "    try:\n",
   "        chain = structure[0][segment]\n",
   "    except KeyError:\n",
   "        return \"Invalid Chain ID\", None, None\n",
   "\n",
   "    # Comprehensive amino acid mapping\n",
   "    aa_dict = {\n",
   "        'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n",
   "        'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n",
   "        'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',\n",
   "        'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',\n",
   "        'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'CSO': 'C', 'PTR': 'Y', 'HYP': 'P'\n",
   "    }\n",
   "\n",
   "    # Exclude non-amino acid residues\n",
   "    sequence = [\n",
   "        residue for residue in chain\n",
   "        if residue.get_resname().strip() in aa_dict\n",
   "    ]\n",
   "\n",
   "    random_scores = np.random.rand(len(sequence))\n",
   "    result_str = \"\\n\".join(\n",
   "        f\"{aa_dict[res.get_resname().strip()]} {res.id[1]} {score:.2f}\"\n",
   "        for res, score in zip(sequence, random_scores)\n",
   "    )\n",
   "\n",
   "    # Save the predictions to a file\n",
   "    prediction_file = f\"{pdb_id}_predictions.txt\"\n",
   "    with open(prediction_file, \"w\") as f:\n",
   "        f.write(result_str)\n",
   "\n",
   "    # Highlight by PDB residue number (res.id[1]); numbering rarely starts at 1.\n",
   "    residue_ids = [res.id[1] for res in sequence]\n",
   "    return result_str, molecule(pdb_path, random_scores, segment, residue_ids), prediction_file\n",
   "\n",
   "def molecule(input_pdb, scores=None, segment='A', residue_ids=None):\n",
   "    \"\"\"Return an iframe (for gr.HTML) that renders `input_pdb` with 3Dmol.js.\n",
   "\n",
   "    If `scores` is given, only chain `segment` is drawn as a cartoon; residues scoring\n",
   "    > 0.8 become red sticks and 0.5 < score <= 0.8 orange sticks. `residue_ids[i]` is\n",
   "    the PDB residue number of `scores[i]` (defaults to positions 1..len(scores)).\n",
   "    \"\"\"\n",
   "    mol = read_mol(input_pdb)  # Read PDB file content\n",
   "\n",
   "    # Prepare high-scoring residues script if scores are provided\n",
   "    high_score_script = \"\"\n",
   "    if scores is not None:\n",
   "        residue_ids = list(range(1, len(scores) + 1)) if residue_ids is None else list(residue_ids)\n",
   "        high_residues = \", \".join(\n",
   "            str(resi) for resi, score in zip(residue_ids, scores) if score > 0.8\n",
   "        )\n",
   "        # Closed upper bound so a score of exactly 0.8 is not dropped from both tiers.\n",
   "        medium_residues = \", \".join(\n",
   "            str(resi) for resi, score in zip(residue_ids, scores) if 0.5 < score <= 0.8\n",
   "        )\n",
   "        # json.dumps quotes the user-supplied chain ID safely for JavaScript.\n",
   "        chain_js = json.dumps(segment)\n",
   "        high_score_script = \"\"\"\n",
   "        // Reset all styles first\n",
   "        viewer.getModel(0).setStyle({}, {});\n",
   "\n",
   "        // Show only the selected chain\n",
   "        viewer.getModel(0).setStyle(\n",
   "            {\"chain\": %s},\n",
   "            { cartoon: {colorscheme:\"whiteCarbon\"} }\n",
   "        );\n",
   "\n",
   "        // Highlight high-scoring residues only for the selected chain\n",
   "        let highScoreResidues = [%s];\n",
   "        viewer.getModel(0).setStyle(\n",
   "            {\"chain\": %s, \"resi\": highScoreResidues},\n",
   "            {\"stick\": {\"color\": \"red\"}}\n",
   "        );\n",
   "\n",
   "        // Highlight medium-scoring residues only for the selected chain\n",
   "        let highScoreResidues2 = [%s];\n",
   "        viewer.getModel(0).setStyle(\n",
   "            {\"chain\": %s, \"resi\": highScoreResidues2},\n",
   "            {\"stick\": {\"color\": \"orange\"}}\n",
   "        );\n",
   "        \"\"\" % (chain_js, high_residues, chain_js, medium_residues, chain_js)\n",
   "\n",
   "    html_content = f\"\"\"\n",
   "    <!DOCTYPE html>\n",
   "    <html>\n",
   "    <head>\n",
   "        <meta charset=\"UTF-8\">\n",
   "        <style>\n",
   "            .mol-container {{ width: 100%; height: 600px; position: relative; }}\n",
   "        </style>\n",
   "        <script src=\"https://3Dmol.org/build/3Dmol-min.js\"></script>\n",
   "    </head>\n",
   "    <body>\n",
   "        <div id=\"container\" class=\"mol-container\"></div>\n",
   "        <script>\n",
   "            // json.dumps yields a valid JS string literal for the raw PDB text.\n",
   "            let pdb = {json.dumps(mol)};\n",
   "            let viewer = $3Dmol.createViewer(document.getElementById(\"container\"), {{ backgroundColor: \"white\" }});\n",
   "            viewer.addModel(pdb, \"pdb\");\n",
   "            viewer.getModel(0).setStyle({{}}, {{ cartoon: {{ colorscheme: \"whiteCarbon\" }} }});\n",
   "            {high_score_script}\n",
   "            viewer.zoomTo();\n",
   "            viewer.render();\n",
   "        </script>\n",
   "    </body>\n",
   "    </html>\n",
   "    \"\"\"\n",
   "\n",
   "    # Return the HTML content within an iframe safely encoded for special characters\n",
   "    return (\n",
   "        '<iframe style=\"width: 100%; height: 600px; border: none;\" sandbox=\"allow-scripts\" '\n",
   "        f'srcdoc=\"{html.escape(html_content, quote=True)}\"></iframe>'\n",
   "    )\n",
   "\n",
   "reps = [\n",
   "    {\n",
   "        \"model\": 0,\n",
   "        \"style\": \"cartoon\",\n",
   "        \"color\": \"whiteCarbon\",\n",
   "        \"residue_range\": \"\",\n",
   "        \"around\": 0,\n",
   "        \"byres\": False,\n",
   "    }\n",
   "]\n",
   "\n",
   "# Gradio UI\n",
   "with gr.Blocks() as demo:\n",
   "    gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n",
   "    with gr.Row():\n",
   "        # Own textbox for the viewer; re-binding `pdb_input` below used to orphan this one.\n",
   "        viz_pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
   "        visualize_btn = gr.Button(\"Visualize Structure\")\n",
   "\n",
   "    molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
   "\n",
   "    with gr.Row():\n",
   "        pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
   "        segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
   "        prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
   "\n",
   "    molecule_output = gr.HTML(label=\"Protein Structure\")\n",
   "    predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
   "    download_output = gr.File(label=\"Download Predictions\")\n",
   "\n",
   "    visualize_btn.click(fetch_pdb, inputs=[viz_pdb_input], outputs=molecule_output2)\n",
   "\n",
   "    prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n",
   "\n",
   "    gr.Markdown(\"## Examples\")\n",
   "    gr.Examples(\n",
   "        examples=[\n",
   "            [\"2IWI\", \"A\"],\n",
   "            [\"7RPZ\", \"B\"],\n",
   "            [\"3TJN\", \"C\"]\n",
   "        ],\n",
   "        inputs=[pdb_input, segment_input],\n",
   "        outputs=[predictions_output, molecule_output, download_output]\n",
   "    )\n",
   "\n",
   "demo.launch()"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "id": "30f35243-852f-4771-9a4b-5cdd198552b5",
  "metadata": {},
  "outputs": [],
  "source": []
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "id":
"5eca6754-4aa1-463f-881a-25d2a0d6bb5b",
  "metadata": {},
  "outputs": [],
  "source": [
   "import gradio as gr\n",
   "import requests\n",
   "from Bio.PDB import PDBParser\n",
   "import numpy as np\n",
   "import os\n",
   "import html\n",
   "import json\n",
   "from gradio_molecule3d import Molecule3D\n",
   "\n",
   "\n",
   "from model_loader import load_model\n",
   "\n",
   "import torch\n",
   "import torch.nn as nn\n",
   "import torch.nn.functional as F\n",
   "from torch.utils.data import DataLoader\n",
   "\n",
   "import re\n",
   "import pandas as pd\n",
   "import copy\n",
   "\n",
   "import transformers, datasets\n",
   "from transformers import AutoTokenizer\n",
   "from transformers import DataCollatorForTokenClassification\n",
   "\n",
   "from datasets import Dataset\n",
   "\n",
   "from scipy.special import expit\n",
   "\n",
   "# Load model and move to device\n",
   "checkpoint = 'ThorbenF/prot_t5_xl_uniref50'\n",
   "max_length = 1500\n",
   "model, tokenizer = load_model(checkpoint, max_length)\n",
   "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
   "model.to(device)\n",
   "model.eval()\n",
   "\n",
   "def normalize_scores(scores):\n",
   "    \"\"\"Min-max scale `scores` to [0, 1]; constant input is returned unchanged.\"\"\"\n",
   "    min_score = np.min(scores)\n",
   "    max_score = np.max(scores)\n",
   "    return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n",
   "\n",
   "def read_mol(pdb_path):\n",
   "    \"\"\"Read PDB file and return its content as a string\"\"\"\n",
   "    with open(pdb_path, 'r') as f:\n",
   "        return f.read()\n",
   "\n",
   "def fetch_pdb(pdb_id):\n",
   "    \"\"\"Download `{pdb_id}.pdb` from RCSB into the working directory.\n",
   "\n",
   "    Returns the local path on HTTP 200, otherwise None (including network errors).\n",
   "    \"\"\"\n",
   "    pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
   "    pdb_path = f'{pdb_id}.pdb'\n",
   "    try:\n",
   "        # Without a timeout a stalled connection would block the Gradio handler forever.\n",
   "        response = requests.get(pdb_url, timeout=30)\n",
   "    except requests.RequestException:\n",
   "        return None\n",
   "    if response.status_code != 200:\n",
   "        return None\n",
   "    with open(pdb_path, 'wb') as f:\n",
   "        f.write(response.content)\n",
   "    return pdb_path\n",
   "\n",
   "def process_pdb(pdb_id, segment):\n",
   "    \"\"\"Predict per-residue binding-site scores for chain `segment` of `pdb_id`.\n",
   "\n",
   "    Returns (prediction text, viewer HTML, predictions file path), or\n",
   "    (error message, None, None) if the download, chain lookup or sequence is invalid.\n",
   "    \"\"\"\n",
   "    pdb_path = fetch_pdb(pdb_id)\n",
   "    if not pdb_path:\n",
   "        return \"Failed to fetch PDB file\", None, None\n",
   "\n",
   "    parser = PDBParser(QUIET=1)\n",
   "    structure = parser.get_structure('protein', pdb_path)\n",
   "\n",
   "    try:\n",
   "        chain = structure[0][segment]\n",
   "    except KeyError:\n",
   "        return \"Invalid Chain ID\", None, None\n",
   "\n",
   "    # Comprehensive amino acid mapping\n",
   "    aa_dict = {\n",
   "        'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n",
   "        'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n",
   "        'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',\n",
   "        'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',\n",
   "        'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'CSO': 'C', 'PTR': 'Y', 'HYP': 'P'\n",
   "    }\n",
   "\n",
   "    # Exclude non-amino acid residues\n",
   "    residues = [\n",
   "        residue for residue in chain\n",
   "        if residue.get_resname().strip() in aa_dict\n",
   "    ]\n",
   "    if not residues:\n",
   "        return \"No amino acid residues found in chain\", None, None\n",
   "    # One-letter string for the tokenizer (joining Residue objects would raise TypeError).\n",
   "    sequence = \"\".join(aa_dict[res.get_resname().strip()] for res in residues)\n",
   "\n",
   "    # Prepare input for model prediction (space-separated residues)\n",
   "    input_ids = tokenizer(\" \".join(sequence), return_tensors=\"pt\").input_ids.to(device)\n",
   "    with torch.no_grad():\n",
   "        outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()\n",
   "\n",
   "    # The tokenizer may append special tokens (e.g. an end-of-sequence token); keep one\n",
   "    # row of logits per residue so those positions neither misalign nor skew normalization.\n",
   "    outputs = outputs[:len(residues)]\n",
   "\n",
   "    # Calculate scores and normalize them\n",
   "    scores = expit(outputs[:, 1] - outputs[:, 0])\n",
   "    normalized_scores = normalize_scores(scores)\n",
   "\n",
   "    result_str = \"\\n\".join(\n",
   "        f\"{aa_dict[res.get_resname().strip()]} {res.id[1]} {score:.2f}\"\n",
   "        for res, score in zip(residues, normalized_scores)\n",
   "    )\n",
   "\n",
   "    # Save the predictions to a file\n",
   "    prediction_file = f\"{pdb_id}_predictions.txt\"\n",
   "    with open(prediction_file, \"w\") as f:\n",
   "        f.write(result_str)\n",
   "\n",
   "    # Highlight by PDB residue number (res.id[1]); numbering rarely starts at 1.\n",
   "    residue_ids = [res.id[1] for res in residues]\n",
   "    return result_str, molecule(pdb_path, normalized_scores, segment, residue_ids), prediction_file\n",
   "\n",
   "def molecule(input_pdb, scores=None, segment='A', residue_ids=None):\n",
   "    \"\"\"Return an iframe (for gr.HTML) that renders `input_pdb` with 3Dmol.js.\n",
   "\n",
   "    If `scores` is given, only chain `segment` is drawn as a cartoon; residues scoring\n",
   "    > 0.8 become red sticks and 0.5 < score <= 0.8 orange sticks. `residue_ids[i]` is\n",
   "    the PDB residue number of `scores[i]` (defaults to positions 1..len(scores)).\n",
   "    \"\"\"\n",
   "    mol = read_mol(input_pdb)  # Read PDB file content\n",
   "\n",
   "    # Prepare high-scoring residues script if scores are provided\n",
   "    high_score_script = \"\"\n",
   "    if scores is not None:\n",
   "        residue_ids = list(range(1, len(scores) + 1)) if residue_ids is None else list(residue_ids)\n",
   "        high_residues = \", \".join(\n",
   "            str(resi) for resi, score in zip(residue_ids, scores) if score > 0.8\n",
   "        )\n",
   "        # Closed upper bound so a score of exactly 0.8 is not dropped from both tiers.\n",
   "        medium_residues = \", \".join(\n",
   "            str(resi) for resi, score in zip(residue_ids, scores) if 0.5 < score <= 0.8\n",
   "        )\n",
   "        # json.dumps quotes the user-supplied chain ID safely for JavaScript.\n",
   "        chain_js = json.dumps(segment)\n",
   "        high_score_script = \"\"\"\n",
   "        // Reset all styles first\n",
   "        viewer.getModel(0).setStyle({}, {});\n",
   "\n",
   "        // Show only the selected chain\n",
   "        viewer.getModel(0).setStyle(\n",
   "            {\"chain\": %s},\n",
   "            { cartoon: {colorscheme:\"whiteCarbon\"} }\n",
   "        );\n",
   "\n",
   "        // Highlight high-scoring residues only for the selected chain\n",
   "        let highScoreResidues = [%s];\n",
   "        viewer.getModel(0).setStyle(\n",
   "            {\"chain\": %s, \"resi\": highScoreResidues},\n",
   "            {\"stick\": {\"color\": \"red\"}}\n",
   "        );\n",
   "\n",
   "        // Highlight medium-scoring residues only for the selected chain\n",
   "        let highScoreResidues2 = [%s];\n",
   "        viewer.getModel(0).setStyle(\n",
   "            {\"chain\": %s, \"resi\": highScoreResidues2},\n",
   "            {\"stick\": {\"color\": \"orange\"}}\n",
   "        );\n",
   "        \"\"\" % (chain_js, high_residues, chain_js, medium_residues, chain_js)\n",
   "\n",
   "    html_content = f\"\"\"\n",
   "    <!DOCTYPE html>\n",
   "    <html>\n",
   "    <head>\n",
   "        <meta charset=\"UTF-8\">\n",
   "        <style>\n",
   "            .mol-container {{ width: 100%; height: 600px; position: relative; }}\n",
   "        </style>\n",
   "        <script src=\"https://3Dmol.org/build/3Dmol-min.js\"></script>\n",
   "    </head>\n",
   "    <body>\n",
   "        <div id=\"container\" class=\"mol-container\"></div>\n",
   "        <script>\n",
   "            // json.dumps yields a valid JS string literal for the raw PDB text.\n",
   "            let pdb = {json.dumps(mol)};\n",
   "            let viewer = $3Dmol.createViewer(document.getElementById(\"container\"), {{ backgroundColor: \"white\" }});\n",
   "            viewer.addModel(pdb, \"pdb\");\n",
   "            viewer.getModel(0).setStyle({{}}, {{ cartoon: {{ colorscheme: \"whiteCarbon\" }} }});\n",
   "            {high_score_script}\n",
   "            viewer.zoomTo();\n",
   "            viewer.render();\n",
   "        </script>\n",
   "    </body>\n",
   "    </html>\n",
   "    \"\"\"\n",
   "\n",
   "    # Return the HTML content within an iframe safely encoded for special characters\n",
   "    return (\n",
   "        '<iframe style=\"width: 100%; height: 600px; border: none;\" sandbox=\"allow-scripts\" '\n",
   "        f'srcdoc=\"{html.escape(html_content, quote=True)}\"></iframe>'\n",
   "    )\n",
   "\n",
   "reps = [\n",
   "    {\n",
   "        \"model\": 0,\n",
   "        \"style\": \"cartoon\",\n",
   "        \"color\": \"whiteCarbon\",\n",
   "        \"residue_range\": \"\",\n",
   "        \"around\": 0,\n",
   "        \"byres\": False,\n",
   "    }\n",
   "]\n",
   "\n",
   "# Gradio UI\n",
   "with gr.Blocks() as demo:\n",
   "    gr.Markdown(\"# Protein Binding Site Prediction\")\n",
   "    with gr.Row():\n",
   "        # Own textbox for the viewer; re-binding `pdb_input` below used to orphan this one.\n",
   "        viz_pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
   "        visualize_btn = gr.Button(\"Visualize Structure\")\n",
   "\n",
   "    molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
   "\n",
   "    with gr.Row():\n",
   "        pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
   "        segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
   "        prediction_btn = gr.Button(\"Predict Binding Site Scores\")\n",
   "\n",
   "    molecule_output = gr.HTML(label=\"Protein Structure\")\n",
   "    predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
   "    download_output = gr.File(label=\"Download Predictions\")\n",
   "\n",
   "    visualize_btn.click(fetch_pdb, inputs=[viz_pdb_input], outputs=molecule_output2)\n",
   "\n",
   "    prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n",
   "\n",
   "    gr.Markdown(\"## Examples\")\n",
   "    gr.Examples(\n",
   "        examples=[\n",
   "            [\"2IWI\", \"A\"],\n",
   "            [\"7RPZ\", \"B\"],\n",
   "            [\"3TJN\", \"C\"]\n",
   "        ],\n",
   "        inputs=[pdb_input, segment_input],\n",
   "        outputs=[predictions_output, molecule_output, download_output]\n",
   "    )\n",
   "\n",
   "demo.launch()"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "id": "95046d1c-ec7c-4e3e-8a98-1802cb09a25b",
  "metadata": {},
  "outputs": [],
  "source": []
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "id":
"a37cbe6f-d57f-41e5-8ae1-38258da39d47",
  "metadata": {},
  "outputs": [],
  "source": [
   "import gradio as gr\n",
   "from model_loader import load_model\n",
   "\n",
   "import torch\n",
   "import torch.nn as nn\n",
   "import torch.nn.functional as F\n",
   "from torch.utils.data import DataLoader\n",
   "\n",
   "import re\n",
   "import numpy as np\n",
   "import os\n",
   "import pandas as pd\n",
   "import copy\n",
   "\n",
   "import transformers, datasets\n",
   "from transformers import AutoTokenizer\n",
   "from transformers import DataCollatorForTokenClassification\n",
   "\n",
   "from datasets import Dataset\n",
   "\n",
   "from scipy.special import expit\n",
   "\n",
   "import requests\n",
   "\n",
   "from gradio_molecule3d import Molecule3D\n",
   "\n",
   "# Biopython imports\n",
   "from Bio.PDB import PDBParser, Select, PDBIO\n",
   "from Bio.PDB.DSSP import DSSP\n",
   "from Bio.PDB import PDBList\n",
   "\n",
   "from matplotlib import cm  # For color mapping\n",
   "from matplotlib.colors import Normalize\n",
   "\n",
   "# Load model and move to device\n",
   "checkpoint = 'ThorbenF/prot_t5_xl_uniref50'\n",
   "max_length = 1500\n",
   "model, tokenizer = load_model(checkpoint, max_length)\n",
   "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
   "model.to(device)\n",
   "model.eval()\n",
   "\n",
   "# Function to fetch a PDB file\n",
   "def fetch_pdb(pdb_id):\n",
   "    \"\"\"Download `{pdb_id}.pdb` from RCSB into `pdb_files/`.\n",
   "\n",
   "    Returns the local path on HTTP 200, otherwise None (including network errors).\n",
   "    \"\"\"\n",
   "    pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
   "    pdb_path = f'pdb_files/{pdb_id}.pdb'\n",
   "    os.makedirs('pdb_files', exist_ok=True)\n",
   "    try:\n",
   "        # Without a timeout a stalled connection would block the Gradio handler forever.\n",
   "        response = requests.get(pdb_url, timeout=30)\n",
   "    except requests.RequestException:\n",
   "        return None\n",
   "    if response.status_code == 200:\n",
   "        with open(pdb_path, 'wb') as f:\n",
   "            f.write(response.content)\n",
   "        return pdb_path\n",
   "    return None\n",
   "\n",
   "\n",
   "def normalize_scores(scores):\n",
   "    \"\"\"Min-max scale `scores` to [0, 1]; constant input is returned unchanged.\"\"\"\n",
   "    min_score = np.min(scores)\n",
   "    max_score = np.max(scores)\n",
   "    return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n",
   "\n",
   "def process_pdb(pdb_id, segment):\n",
   "    \"\"\"Predict per-residue binding-site scores for chain `segment` of `pdb_id`.\n",
   "\n",
   "    Returns (prediction text, local PDB path for the viewer, predictions file path),\n",
   "    or (error message, None, None) if the download, chain lookup or sequence is invalid.\n",
   "    \"\"\"\n",
   "    pdb_path = fetch_pdb(pdb_id)\n",
   "    if not pdb_path:\n",
   "        return \"Failed to fetch PDB file\", None, None\n",
   "\n",
   "    parser = PDBParser(QUIET=1)\n",
   "    structure = parser.get_structure('protein', pdb_path)\n",
   "    try:\n",
   "        chain = structure[0][segment]\n",
   "    except KeyError:\n",
   "        return \"Invalid Chain ID\", None, None\n",
   "\n",
   "    # Comprehensive amino acid mapping\n",
   "    aa_dict = {\n",
   "        'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n",
   "        'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n",
   "        'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',\n",
   "        'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',\n",
   "        'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'CSO': 'C', 'PTR': 'Y', 'HYP': 'P'\n",
   "    }\n",
   "\n",
   "    # Keep only amino-acid residues. Everything below is indexed over this one list, so\n",
   "    # residues[i], sequence[i] and normalized_scores[i] always describe the same residue\n",
   "    # (enumerating `chain` would also count waters/ligands and misalign them).\n",
   "    residues = [\n",
   "        residue for residue in chain\n",
   "        if residue.get_resname().strip() in aa_dict\n",
   "    ]\n",
   "    if not residues:\n",
   "        return \"No amino acid residues found in chain\", None, None\n",
   "    sequence = \"\".join(aa_dict[residue.get_resname().strip()] for residue in residues)\n",
   "\n",
   "    # Prepare input for model prediction (space-separated residues)\n",
   "    input_ids = tokenizer(\" \".join(sequence), return_tensors=\"pt\").input_ids.to(device)\n",
   "    with torch.no_grad():\n",
   "        outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()\n",
   "\n",
   "    # The tokenizer may append special tokens (e.g. an end-of-sequence token); keep one\n",
   "    # row of logits per residue so those positions don't distort the normalization.\n",
   "    outputs = outputs[:len(sequence)]\n",
   "\n",
   "    # Calculate scores and normalize them\n",
   "    scores = expit(outputs[:, 1] - outputs[:, 0])\n",
   "    normalized_scores = normalize_scores(scores)\n",
   "\n",
   "    # Prepare the result string, including only amino acid residues\n",
   "    result_str = \"\\n\".join(\n",
   "        f\"{res.get_resname()} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}\"\n",
   "        for i, res in enumerate(residues)\n",
   "    )\n",
   "\n",
   "    # Save predictions to file\n",
   "    prediction_file = f\"{pdb_id}_predictions.txt\"\n",
   "    with open(prediction_file, \"w\") as f:\n",
   "        f.write(result_str)\n",
   "\n",
   "    return result_str, pdb_path, prediction_file\n",
   "\n",
   "reps = [{\"model\": 0, \"style\": \"cartoon\", \"color\": \"spectrum\"}]\n",
   "\n",
   "# Gradio UI\n",
   "with gr.Blocks() as demo:\n",
   "    gr.Markdown(\"# Protein Binding Site Prediction\")\n",
   "\n",
   "    with gr.Row():\n",
   "        pdb_input = gr.Textbox(value=\"2IWI\",\n",
   "                               label=\"PDB ID\",\n",
   "                               placeholder=\"Enter PDB ID here...\")\n",
   "        segment_input = gr.Textbox(value=\"A\",\n",
   "                                   label=\"Chain ID (Segment)\",\n",
   "                                   placeholder=\"Enter Chain ID here...\")\n",
   "        visualize_btn = gr.Button(\"Visualize Structure\")\n",
   "        prediction_btn = gr.Button(\"Predict Ligand Binding Site\")\n",
   "\n",
   "    molecule_output = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
   "    predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
   "    download_output = gr.File(label=\"Download Predictions\")\n",
   "\n",
   "    visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output)\n",
   "    prediction_btn.click(\n",
   "        process_pdb,\n",
   "        inputs=[pdb_input, segment_input],\n",
   "        outputs=[predictions_output, molecule_output, download_output]\n",
   "    )\n",
   "\n",
   "    gr.Markdown(\"## Examples\")\n",
   "    gr.Examples(\n",
   "        # One value per input component: [PDB ID, chain ID]\n",
   "        examples=[\n",
   "            [\"2IWI\", \"A\"],\n",
   "            [\"7RPZ\", \"B\"],\n",
   "            [\"3TJN\", \"C\"]\n",
   "        ],\n",
   "        inputs=[pdb_input, segment_input],\n",
   "        outputs=[predictions_output, molecule_output, download_output]\n",
   "    )\n",
   "\n",
   "demo.launch(share=True)"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "id": "4c61bac4-4f2e-4f4a-aa1f-30dca209747c",
  "metadata": {},
  "outputs": [],
  "source": []
 }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (LLM)",
   "language": "python",
   "name": "llm"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}