"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import gradio as gr\n",
"from gradio_molecule3d import Molecule3D\n",
"\n",
"\n",
"# Sample value bundled with the component; not used further in this cell.\n",
"example = Molecule3D().example_value()\n",
"\n",
"\n",
"# Default representations handed to both viewers below:\n",
"#   - model 0 drawn as a cartoon with color \"whiteCarbon\";\n",
"#   - chain A histidines (resname \"HIS\") drawn as red sticks.\n",
"reps = [\n",
"    dict(model=0, style=\"cartoon\", color=\"whiteCarbon\", residue_range=\"\", around=0, byres=False),\n",
"    dict(model=0, chain=\"A\", resname=\"HIS\", style=\"stick\", color=\"red\"),\n",
"]\n",
"\n",
"\n",
"\n",
"def predict(x):\n",
"    \"\"\"Echo the uploaded molecule back unchanged, printing what was received.\n",
"\n",
"    NOTE(review): `x` is assumed to be None when nothing was uploaded (the usual\n",
"    Gradio file-component convention); confirm against Molecule3D.preprocess.\n",
"    \"\"\"\n",
"    print(\"predict function\", x)\n",
"    if x is None:\n",
"        # `x.name` below would raise AttributeError on an empty input.\n",
"        return None\n",
"    print(x.name)\n",
"    return x\n",
"\n",
"\n",
"with gr.Blocks() as demo:\n",
"    gr.Markdown(\"# Molecule3D\")\n",
"    inp = Molecule3D(label=\"Molecule3D\", reps=reps)\n",
"    out = Molecule3D(label=\"Output\", reps=reps)\n",
"\n",
"    btn = gr.Button(\"Predict\")\n",
"    # Fenced so the example renders as a code block rather than one run-on paragraph.\n",
"    gr.Markdown(\"\"\"\n",
"    You can configure the default rendering of the molecule by adding a list of representations:\n",
"\n",
"    ```python\n",
"    reps = [\n",
"        {\n",
"            \"model\": 0,\n",
"            \"style\": \"cartoon\",\n",
"            \"color\": \"whiteCarbon\",\n",
"            \"residue_range\": \"\",\n",
"            \"around\": 0,\n",
"            \"byres\": False,\n",
"        },\n",
"        {\n",
"            \"model\": 0,\n",
"            \"chain\": \"A\",\n",
"            \"resname\": \"HIS\",\n",
"            \"style\": \"stick\",\n",
"            \"color\": \"red\"\n",
"        }\n",
"    ]\n",
"    ```\n",
"    \"\"\")\n",
"    btn.click(predict, inputs=inp, outputs=out)\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
"    demo.launch()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d27cc368-26a0-42c2-a68a-8833de7bb4a0",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 8,
"id": "cdf7fd26-0464-40d9-9107-71c29dbcaef8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"* Running on local URL: http://127.0.0.1:7867\n",
"\n",
"To create a public link, set `share=True` in `launch()`.\n"
]
},
{
"data": {
"text/html": [
""
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": []
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/tm/ym2tckv54b96ws82y3b7cqhh0000gn/T/ipykernel_11794/4072855226.py:39: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.\n",
" colors = [cm.get_cmap('coolwarm')(score)[:3] for score in normalized_scores]\n",
"Traceback (most recent call last):\n",
" File \"/Users/thorben_froehlking/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/queueing.py\", line 622, in process_events\n",
" response = await route_utils.call_process_api(\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/Users/thorben_froehlking/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/route_utils.py\", line 323, in call_process_api\n",
" output = await app.get_blocks().process_api(\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/Users/thorben_froehlking/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/blocks.py\", line 2024, in process_api\n",
" data = await self.postprocess_data(block_fn, result[\"prediction\"], state)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/Users/thorben_froehlking/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/blocks.py\", line 1830, in postprocess_data\n",
" prediction_value = block.postprocess(prediction_value)\n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" File \"/Users/thorben_froehlking/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio_molecule3d/molecule3d.py\", line 210, in postprocess\n",
" orig_name=Path(file).name,\n",
" ^^^^^^^^^^\n",
" File \"/Users/thorben_froehlking/anaconda3/envs/LLM/lib/python3.12/pathlib.py\", line 1162, in __init__\n",
" super().__init__(*args)\n",
" File \"/Users/thorben_froehlking/anaconda3/envs/LLM/lib/python3.12/pathlib.py\", line 373, in __init__\n",
" raise TypeError(\n",
"TypeError: argument should be a str or an os.PathLike object where __fspath__ returns a str, not 'dict'\n"
]
}
],
"source": [
"import gradio as gr\n",
"import requests\n",
"from Bio.PDB import PDBParser\n",
"from gradio_molecule3d import Molecule3D\n",
"import numpy as np\n",
"from matplotlib import cm\n",
"\n",
"# Function to fetch a PDB file from RCSB PDB\n",
"def fetch_pdb(pdb_id, timeout=30):\n",
"    \"\"\"Download `{pdb_id}.pdb` from RCSB into the current working directory.\n",
"\n",
"    Returns the local file path, or None when the ID is empty or not plain\n",
"    ASCII letters/digits, the request raises, or RCSB does not answer 200.\n",
"    `timeout` (seconds) is forwarded to requests.get.\n",
"    \"\"\"\n",
"    pdb_id = (pdb_id or '').strip()\n",
"    # The ID is user-typed and becomes a file name: allowing only ASCII letters\n",
"    # and digits rules out path traversal such as '../x'.\n",
"    if not (pdb_id.isascii() and pdb_id.isalnum()):\n",
"        return None\n",
"    pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
"    pdb_path = f'{pdb_id}.pdb'\n",
"    try:\n",
"        # Without a timeout a stalled connection can block the Gradio worker indefinitely.\n",
"        response = requests.get(pdb_url, timeout=timeout)\n",
"    except requests.RequestException:\n",
"        return None\n",
"    if response.status_code != 200:\n",
"        return None\n",
"    with open(pdb_path, 'wb') as f:\n",
"        f.write(response.content)\n",
"    return pdb_path\n",
"\n",
"# Function to process the PDB file and return random predictions\n",
"def process_pdb(pdb_id, segment):\n",
"    \"\"\"Fetch a structure, give each standard residue of one chain a random score, and build the UI outputs.\n",
"\n",
"    Returns a 3-tuple, one value per output component\n",
"    (predictions_output, molecule_output, download_output):\n",
"      - text with one \"RESNAME RESNUM SCORE\" line per standard residue, or an error message;\n",
"      - a gr.update carrying the downloaded PDB path and per-residue colour reps (None on error);\n",
"      - path of the saved predictions file (None on error).\n",
"    \"\"\"\n",
"    pdb_path = fetch_pdb(pdb_id)\n",
"    if not pdb_path:\n",
"        # Exactly one value per output component (three).\n",
"        return \"Failed to fetch PDB file\", None, None\n",
"\n",
"    parser = PDBParser(QUIET=True)\n",
"    structure = parser.get_structure('protein', pdb_path)\n",
"\n",
"    try:\n",
"        chain = structure[0][segment]\n",
"    except KeyError:\n",
"        return \"Invalid Chain ID\", None, None\n",
"\n",
"    # Keep only residues with a blank hetero flag and drive names, numbers,\n",
"    # scores and colours from this one list so they stay aligned.\n",
"    residues = [residue for residue in chain if residue.id[0] == ' ']\n",
"    if not residues:\n",
"        return \"No standard residues in chain\", None, None\n",
"    random_scores = np.random.rand(len(residues))\n",
"\n",
"    # Normalize scores for coloring (0 = blue, 1 = red); a zero range\n",
"    # (e.g. a single residue) would otherwise divide by zero.\n",
"    score_min = random_scores.min()\n",
"    score_range = random_scores.max() - score_min\n",
"    if score_range > 0:\n",
"        normalized_scores = (random_scores - score_min) / score_range\n",
"    else:\n",
"        normalized_scores = np.zeros_like(random_scores)\n",
"    # cm.coolwarm replaces the deprecated cm.get_cmap('coolwarm'); called on the\n",
"    # whole array it returns one RGBA row per score, components in [0, 1].\n",
"    rgba_rows = cm.coolwarm(normalized_scores)\n",
"    hex_colors = [f'#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}' for r, g, b, _alpha in rgba_rows]\n",
"\n",
"    # Result string: \"RESNAME RESNUM SCORE\" per residue\n",
"    result_str = \"\\n\".join(\n",
"        f\"{res.get_resname()} {res.id[1]} {score:.2f}\"\n",
"        for res, score in zip(residues, random_scores)\n",
"    )\n",
"\n",
"    # A white-carbon cartoon for the whole model, then one coloured cartoon per\n",
"    # residue selected by chain + residue number.\n",
"    # NOTE(review): selection keys follow the gradio_molecule3d reps schema\n",
"    # (\"chain\", \"residue_range\" as \"start-end\"); confirm against the installed version.\n",
"    reps = [\n",
"        {\n",
"            \"model\": 0,\n",
"            \"style\": \"cartoon\",\n",
"            \"color\": \"whiteCarbon\"\n",
"        }\n",
"    ] + [\n",
"        {\n",
"            \"model\": 0,\n",
"            \"chain\": segment,\n",
"            \"style\": \"cartoon\",\n",
"            \"residue_range\": f\"{res.id[1]}-{res.id[1]}\",\n",
"            \"color\": color\n",
"        }\n",
"        for res, color in zip(residues, hex_colors)\n",
"    ]\n",
"\n",
"    # Save the predictions to a file\n",
"    prediction_file = f\"{pdb_id}_predictions.txt\"\n",
"    with open(prediction_file, \"w\") as f:\n",
"        f.write(result_str)\n",
"\n",
"    # The viewer's value must be the structure file path; returning the reps list\n",
"    # as the value is what raised \"TypeError: ... not 'dict'\" in Molecule3D.postprocess.\n",
"    return result_str, gr.update(value=pdb_path, reps=reps), prediction_file\n",
"\n",
"# Gradio UI\n",
"# Base viewer style defined in this cell so it does not depend on a `reps`\n",
"# global left over from an earlier cell.\n",
"default_reps = [\n",
"    {\n",
"        \"model\": 0,\n",
"        \"style\": \"cartoon\",\n",
"        \"color\": \"whiteCarbon\"\n",
"    }\n",
"]\n",
"\n",
"\n",
"def visualize_pdb(pdb_id):\n",
"    \"\"\"Load the structure into the viewer without scoring, with reps set back to default_reps.\"\"\"\n",
"    return gr.update(value=fetch_pdb(pdb_id), reps=default_reps)\n",
"\n",
"\n",
"with gr.Blocks() as demo:\n",
"    gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n",
"\n",
"    with gr.Row():\n",
"        pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
"        segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
"        visualize_btn = gr.Button(\"Visualize Structure\")\n",
"        prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
"\n",
"    molecule_output = Molecule3D(label=\"Protein Structure\", reps=default_reps)\n",
"    predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
"    download_output = gr.File(label=\"Download Predictions\")\n",
"\n",
"    # Show the plain structure without running the scorer.\n",
"    visualize_btn.click(\n",
"        fn=visualize_pdb,\n",
"        inputs=pdb_input,\n",
"        outputs=molecule_output\n",
"    )\n",
"\n",
"    prediction_btn.click(\n",
"        fn=process_pdb,\n",
"        inputs=[pdb_input, segment_input],\n",
"        outputs=[predictions_output, molecule_output, download_output]\n",
"    )\n",
"\n",
"    gr.Markdown(\"## Examples\")\n",
"    gr.Examples(\n",
"        examples=[\n",
"            [\"2IWI\", \"A\"],\n",
"            [\"7RPZ\", \"B\"],\n",
"            [\"3TJN\", \"C\"]\n",
"        ],\n",
"        inputs=[pdb_input, segment_input],\n",
"        outputs=[predictions_output, molecule_output, download_output]\n",
"    )\n",
"\n",
"demo.launch()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ee215c16-a1fb-450f-bb93-37aaee6fb3f1",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (LLM)",
"language": "python",
"name": "llm"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}