cartnet-demo / main_local.py
Àlex Solé
fixed bug multiple structures csd
fa790e2
raw
history blame
6.24 kB
import streamlit as st
import os
from ase.io import read
from CifFile import ReadCif
from torch_geometric.data import Data, Batch
import torch
from models.master import create_model
from process import process_data
from utils import radius_graph_pbc
from io import BytesIO, StringIO
import gc
# Mean and standard deviation of the training-set temperatures.
# main() uses them to standardise the temperature read from the CIF block:
# (temperature - MEAN_TEMP) / STD_TEMP.
# NOTE(review): presumably kelvin, the unit of the CIF temperature fields - confirm.
MEAN_TEMP = torch.tensor(192.1785) #training temp mean
STD_TEMP = torch.tensor(81.2135) #training temp std
def _read_temperature(cif_data):
    """Return the temperature stored in a CIF block as a float, or None.

    ``_diffrn_ambient_temperature`` is tried first and
    ``_cell_measurement_temperature`` second. Anything from the first ``(``
    onwards (e.g. the uncertainty in ``"100(2)"``) is dropped before the
    conversion. A field that is present but cannot be converted to a number
    (for example the CIF placeholders ``?`` and ``.``) is skipped so that
    the other field still gets a chance.
    """
    for field in ("_diffrn_ambient_temperature", "_cell_measurement_temperature"):
        if field not in cif_data.keys():
            continue
        try:
            return float(cif_data[field].split("(")[0])
        except (AttributeError, ValueError):
            # Placeholder or non-string value: try the next candidate field.
            continue
    return None


@torch.no_grad()
def main():
    """Render the Streamlit page and predict ADPs for an uploaded CIF file.

    The page shows a short description of CartNet and a CIF uploader. Every
    data block of the uploaded file is handled independently: the structure
    is read with ASE, the temperature is taken from the block, a graph is
    built with ``radius_graph_pbc`` and the result of ``process_data`` is
    offered through a download button. Blocks that cannot be read, or that
    carry no usable temperature, are reported with ``st.error`` and skipped.
    Any other failure is reported once and stops the processing of the file.
    The whole function runs with gradient tracking disabled.
    """
    model = create_model()

    st.title("CartNet Thermal Ellipsoid Prediction")
    st.image('fig/pipeline.png')

    st.markdown("""
CartNet is a graph neural network specifically designed for predicting Anisotropic Displacement Parameters (ADPs) in crystal structures. The model has been trained on over 220,000 molecular crystal structures from the Cambridge Structural Database (CSD), making it highly accurate and robust for ADP prediction tasks. CartNet addresses the computational challenges of traditional methods by encoding the full 3D geometry of atomic structures into a Cartesian reference frame, bypassing the need for unit cell encoding. The model incorporates innovative features, including a neighbour equalization technique to enhance interaction detection and a Cholesky-based output layer to ensure valid ADP predictions. Additionally, it introduces a rotational SO(3) data augmentation technique to improve generalization across different crystal structure orientations, making the model highly efficient and accurate in predicting ADPs while significantly reducing computational costs.
""")

    uploaded_file = st.file_uploader("Upload a CIF file", type=["cif"], accept_multiple_files=False)

    if uploaded_file is not None:
        try:
            filename = str(uploaded_file.name)
            file = BytesIO(uploaded_file.getbuffer())
            cif = ReadCif(file)
            block_names = list(cif.keys())

            if len(block_names) > 1:
                st.warning(f"⚠️ **Warning**: Found {len(block_names)} blocks in the CIF file. We will process all of them and export as separate CIF files.")

            st.markdown(f"### CIF file: {filename}")

            for key in block_names:
                st.markdown(f"### Block: {key}")

                # Step 1: read the structure and the temperature of this block.
                try:
                    cif_data = cif[key]
                    # ASE is given a single-block CIF text rebuilt from the block.
                    block = "data_" + str(key) + "\n" + cif_data.printsection()
                    atoms = read(StringIO(block), format="cif")

                    temperature = _read_temperature(cif_data)
                    if temperature is None:
                        st.error(
                            "Temperature not found in the CIF file. "
                            "Please provide a temperature in the field _diffrn_ambient_temperature "
                            "or in the field _cell_measurement_temperature from the CIF file."
                        )
                        continue
                    st.success("CIF file successfully read.")
                except Exception as e:
                    st.error(f"Error: {e}")
                    st.error(f"We couldn't find any structure for the block {key}. Please make sure the cif is compatible with ASE. If the error message is a blank line, it means ASE didn't found any coordinates.")
                    continue

                # Step 2: describe the structure as a torch_geometric Data object.
                data = Data()
                data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
                data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
                data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
                # Standardise the temperature with the training statistics.
                data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
                data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
                data.pbc = torch.tensor([True, True, True])
                data.natoms = len(atoms)

                # Intermediates are dropped as soon as they are no longer needed.
                del atoms
                gc.collect()

                # Step 3: build the edges; radius_graph_pbc takes a Batch.
                batch = Batch.from_data_list([data])
                edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
                del batch
                gc.collect()

                data.cart_dist = torch.norm(edge_attr, dim=-1)
                data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
                data.edge_index = edge_index
                # Atomic number 1 is hydrogen.
                data.non_H_mask = data.x != 1
                # pbc and natoms were set above for the graph construction
                # and are removed before the final batch is assembled.
                delattr(data, "pbc")
                delattr(data, "natoms")

                batch = Batch.from_data_list([data])
                del data, edge_index, edge_attr
                gc.collect()

                st.success("Graph successfully created.")

                # Step 4: predict and offer the resulting CIF for download.
                cif_file = process_data(batch, model)
                st.success("ADPs successfully predicted.")

                cif_file = BytesIO(cif_file.getvalue().encode())
                st.download_button(
                    label="Download processed CIF file",
                    data=cif_file,
                    file_name=f"output_{key}.cif",
                    mime="text/plain",
                    # One widget key per block so several buttons can coexist.
                    key=f"download_button_{key}"
                )

                gc.collect()
            gc.collect()
        except Exception as e:
            st.error(f"An error occurred while reading the CIF file: {e}")

    st.markdown("""
📌 The official implementation of the paper with all experiments can be found at [CartNet GitHub Repository](https://github.com/imatge-upc/CartNet).
""")

    st.warning("""
⚠️ **Warning**: We use [ASE library](https://wiki.fysik.dtu.dk/ase/) for reading the cif files, please make sure it is compatible.
""")

    st.markdown("""
### How to cite
If you use CartNet in your research, please cite our paper:
```bibtex
@article{your_paper_citation,
title={Title of the Paper},
author={Author1 and Author2 and Author3},
journal={Journal Name},
year={2023},
volume={XX},
number={YY},
pages={ZZZ}
}
```
""")
# Call main() only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()