Spaces:
Sleeping
Sleeping
Àlex Solé
commited on
Commit
·
f2b6066
1
Parent(s):
cd23593
updated code, readme and image
Browse files- README_streamlit.md → README_offline.md +18 -4
- fig/{pipeline.png → frontpage.png} +2 -2
- main.py +18 -47
- main_local.py +0 -139
- process.py +43 -2
README_streamlit.md → README_offline.md
RENAMED
@@ -1,19 +1,33 @@
|
|
1 |
# CartNet Streamlit Web App
|
2 |
|
3 |
-
 in crystal structures. The model has been trained on over 220,000 molecular crystal structures from the Cambridge Structural Database (CSD), making it highly accurate and robust for ADP prediction tasks. CartNet addresses the computational challenges of traditional methods by encoding the full 3D geometry of atomic structures into a Cartesian reference frame, bypassing the need for unit cell encoding. The model incorporates innovative features, including a neighbour equalization technique to enhance interaction detection and a Cholesky-based output layer to ensure valid ADP predictions. Additionally, it introduces a rotational SO(3) data augmentation technique to improve generalization across different crystal structure orientations, making the model highly efficient and accurate in predicting ADPs while significantly reducing computational costs.
|
9 |
|
10 |
This repository contains a web application based on the official implementation of CartNet, which can be found at [imatge-upc/CartNet](https://github.com/imatge-upc/CartNet).
|
11 |
|
12 |
-
⚠️ **Warning**: The online web application can only process systems with less than
|
13 |
|
14 |
## Local Application
|
15 |
### Installation of the local application
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
To set up the local application, you need to install the dependencies listed in `requirements.txt`. You can do this by running the following command:
|
18 |
|
19 |
```bash
|
@@ -37,7 +51,7 @@ python predict.py input.cif output.cif
|
|
37 |
Or, if you prefer, you can use the browser app on your local machine without the atom number limitation by running:
|
38 |
|
39 |
```bash
|
40 |
-
streamlit run
|
41 |
```
|
42 |
|
43 |
## How to cite
|
|
|
1 |
# CartNet Streamlit Web App
|
2 |
|
3 |
+

|
4 |
|
5 |
+
<h3 align="center">
|
6 |
+
🌐 <a href="https://imatge-upc.github.io/CartNet/" target="_blank">Project</a> |
|
7 |
+
🐙 <a href="https://github.com/imatge-upc/CartNet" target="_blank">GitHub</a> |
|
8 |
+
<!-- 📃 <a href="https://imatge-upc.github.io/CartNet/static/pdfs/CartNet.pdf" target="_blank">Paper</a> | -->
|
9 |
+
🤗 <a href="https://huggingface.co/spaces/alexsoleg/cartnet-demo" target="_blank">Demo</a>
|
10 |
+
</h3>
|
11 |
|
12 |
|
13 |
CartNet is a graph neural network specifically designed to predict Anisotropic Displacement Parameters (ADPs) in crystal structures. The model has been trained on over 220,000 molecular crystal structures from the Cambridge Structural Database (CSD), making it highly accurate and robust for ADP prediction tasks. CartNet addresses the computational challenges of traditional methods by encoding the full 3D geometry of atomic structures into a Cartesian reference frame, bypassing the need for unit cell encoding. The model incorporates innovative features, including a neighbour equalization technique to enhance interaction detection and a Cholesky-based output layer to ensure valid ADP predictions. Additionally, it introduces a rotational SO(3) data augmentation technique to improve generalization across different crystal structure orientations, making the model highly efficient and accurate in predicting ADPs while significantly reducing computational costs.
|
14 |
|
15 |
This repository contains a web application based on the official implementation of CartNet, which can be found at [imatge-upc/CartNet](https://github.com/imatge-upc/CartNet).
|
16 |
|
17 |
+
⚠️ **Warning**: The online web application can only process systems with less than 1000 atoms in the unit cell. For larger systems, please use the local application.
|
18 |
|
19 |
## Local Application
|
20 |
### Installation of the local application
|
21 |
|
22 |
+
Clone the following repository:
|
23 |
+
|
24 |
+
```bash
|
25 |
+
# Make sure you have git-lfs installed (https://git-lfs.com)
|
26 |
+
git lfs install
|
27 |
+
|
28 |
+
git clone https://huggingface.co/spaces/alexsoleg/cartnet-demo
|
29 |
+
```
|
30 |
+
|
31 |
To set up the local application, you need to install the dependencies listed in `requirements.txt`. You can do this by running the following command:
|
32 |
|
33 |
```bash
|
|
|
51 |
Or, if you prefer, you can use the browser app on your local machine without the atom number limitation by running:
|
52 |
|
53 |
```bash
|
54 |
+
streamlit run main.py -- --local
|
55 |
```
|
56 |
|
57 |
## How to cite
|
fig/{pipeline.png → frontpage.png}
RENAMED
File without changes
|
main.py
CHANGED
@@ -2,23 +2,22 @@ import streamlit as st
|
|
2 |
import os
|
3 |
from ase.io import read
|
4 |
from CifFile import ReadCif
|
5 |
-
from torch_geometric.data import Data, Batch
|
6 |
import torch
|
7 |
from models.master import create_model
|
8 |
-
from process import
|
9 |
-
|
10 |
import gc
|
11 |
from io import BytesIO, StringIO
|
|
|
|
|
12 |
|
13 |
-
|
14 |
-
STD_TEMP = torch.tensor(81.2135) #training temp std
|
15 |
|
16 |
|
17 |
-
|
18 |
-
def main():
|
19 |
model = create_model()
|
20 |
st.title("CartNet Thermal Ellipsoid Prediction")
|
21 |
-
st.image('fig/
|
22 |
|
23 |
st.markdown("""
|
24 |
CartNet is a graph neural network specifically designed for predicting Anisotropic Displacement Parameters (ADPs) in crystal structures. The model has been trained on over 220,000 molecular crystal structures from the Cambridge Structural Database (CSD), making it highly accurate and robust for ADP prediction tasks. CartNet addresses the computational challenges of traditional methods by encoding the full 3D geometry of atomic structures into a Cartesian reference frame, bypassing the need for unit cell encoding. The model incorporates innovative features, including a neighbour equalization technique to enhance interaction detection and a Cholesky-based output layer to ensure valid ADP predictions. Additionally, it introduces a rotational SO(3) data augmentation technique to improve generalization across different crystal structure orientations, making the model highly efficient and accurate in predicting ADPs while significantly reducing computational costs.
|
@@ -46,7 +45,7 @@ def main():
|
|
46 |
block = "data_"+str(key)+"\n"+ cif[key].printsection()
|
47 |
atoms = read(StringIO(block), format="cif")
|
48 |
|
49 |
-
if len(atoms.positions) > 1000:
|
50 |
st.error("""
|
51 |
⚠️ **Warning**: The structure is too large. Please upload a smaller one or use the [local implementation of CartNet Web App](https://github.com/alexsoleg/cartnet-streamlit/).
|
52 |
""")
|
@@ -68,39 +67,7 @@ def main():
|
|
68 |
|
69 |
continue
|
70 |
|
71 |
-
|
72 |
-
data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
|
73 |
-
|
74 |
-
data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
|
75 |
-
data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
|
76 |
-
data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
|
77 |
-
data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
|
78 |
-
|
79 |
-
data.pbc = torch.tensor([True, True, True])
|
80 |
-
data.natoms = len(atoms)
|
81 |
-
|
82 |
-
del atoms
|
83 |
-
gc.collect()
|
84 |
-
batch = Batch.from_data_list([data])
|
85 |
-
|
86 |
-
|
87 |
-
edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
|
88 |
-
del batch
|
89 |
-
gc.collect()
|
90 |
-
data.cart_dist = torch.norm(edge_attr, dim=-1)
|
91 |
-
data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
|
92 |
-
data.edge_index = edge_index
|
93 |
-
data.non_H_mask = data.x != 1
|
94 |
-
delattr(data, "pbc")
|
95 |
-
delattr(data, "natoms")
|
96 |
-
batch = Batch.from_data_list([data])
|
97 |
-
del data, edge_index, edge_attr
|
98 |
-
gc.collect()
|
99 |
-
|
100 |
-
st.success("Graph successfully created.")
|
101 |
-
|
102 |
-
cif_file = process_data(batch, model)
|
103 |
-
st.success("ADPs successfully predicted.")
|
104 |
|
105 |
cif_file = BytesIO(cif_file.getvalue().encode())
|
106 |
st.download_button(
|
@@ -118,10 +85,11 @@ def main():
|
|
118 |
gc.collect()
|
119 |
except Exception as e:
|
120 |
st.error(f"An error occurred while reading the CIF file: {e}")
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
|
|
125 |
|
126 |
st.warning("""
|
127 |
⚠️ **Warning**: We use [ASE library](https://wiki.fysik.dtu.dk/ase/) for reading the CIF files, please make sure your CIF file is compatible.
|
@@ -150,6 +118,9 @@ def main():
|
|
150 |
""")
|
151 |
|
152 |
if __name__ == "__main__":
|
153 |
-
|
|
|
|
|
|
|
154 |
|
155 |
|
|
|
2 |
import os
|
3 |
from ase.io import read
|
4 |
from CifFile import ReadCif
|
|
|
5 |
import torch
|
6 |
from models.master import create_model
|
7 |
+
from process import process_ase
|
8 |
+
|
9 |
import gc
|
10 |
from io import BytesIO, StringIO
|
11 |
+
import argparse
|
12 |
+
|
13 |
|
14 |
+
torch.set_grad_enabled(False)
|
|
|
15 |
|
16 |
|
17 |
+
def main(local):
|
|
|
18 |
model = create_model()
|
19 |
st.title("CartNet Thermal Ellipsoid Prediction")
|
20 |
+
st.image('fig/frontpage.png')
|
21 |
|
22 |
st.markdown("""
|
23 |
CartNet is a graph neural network specifically designed for predicting Anisotropic Displacement Parameters (ADPs) in crystal structures. The model has been trained on over 220,000 molecular crystal structures from the Cambridge Structural Database (CSD), making it highly accurate and robust for ADP prediction tasks. CartNet addresses the computational challenges of traditional methods by encoding the full 3D geometry of atomic structures into a Cartesian reference frame, bypassing the need for unit cell encoding. The model incorporates innovative features, including a neighbour equalization technique to enhance interaction detection and a Cholesky-based output layer to ensure valid ADP predictions. Additionally, it introduces a rotational SO(3) data augmentation technique to improve generalization across different crystal structure orientations, making the model highly efficient and accurate in predicting ADPs while significantly reducing computational costs.
|
|
|
45 |
block = "data_"+str(key)+"\n"+ cif[key].printsection()
|
46 |
atoms = read(StringIO(block), format="cif")
|
47 |
|
48 |
+
if len(atoms.positions) > 1000 and not local:
|
49 |
st.error("""
|
50 |
⚠️ **Warning**: The structure is too large. Please upload a smaller one or use the [local implementation of CartNet Web App](https://github.com/alexsoleg/cartnet-streamlit/).
|
51 |
""")
|
|
|
67 |
|
68 |
continue
|
69 |
|
70 |
+
cif_file = process_ase(atoms, temperature, model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
cif_file = BytesIO(cif_file.getvalue().encode())
|
73 |
st.download_button(
|
|
|
85 |
gc.collect()
|
86 |
except Exception as e:
|
87 |
st.error(f"An error occurred while reading the CIF file: {e}")
|
88 |
+
|
89 |
+
if not local:
|
90 |
+
st.warning("""
|
91 |
+
⚠️ **Warning**: This online web application is designed for structures with up to 1000 atoms in the unit cell. For larger structures, please use the [local implementation of CartNet Web App](https://github.com/alexsoleg/cartnet-streamlit/).
|
92 |
+
""")
|
93 |
|
94 |
st.warning("""
|
95 |
⚠️ **Warning**: We use [ASE library](https://wiki.fysik.dtu.dk/ase/) for reading the CIF files, please make sure your CIF file is compatible.
|
|
|
118 |
""")
|
119 |
|
120 |
if __name__ == "__main__":
|
121 |
+
parser = argparse.ArgumentParser()
|
122 |
+
parser.add_argument('--local', action='store_true')
|
123 |
+
args = parser.parse_args()
|
124 |
+
main(args.local)
|
125 |
|
126 |
|
main_local.py
DELETED
@@ -1,139 +0,0 @@
|
|
1 |
-
import streamlit as st
|
2 |
-
import os
|
3 |
-
from ase.io import read
|
4 |
-
from CifFile import ReadCif
|
5 |
-
from torch_geometric.data import Data, Batch
|
6 |
-
import torch
|
7 |
-
from models.master import create_model
|
8 |
-
from process import process_data
|
9 |
-
from utils import radius_graph_pbc
|
10 |
-
from io import BytesIO, StringIO
|
11 |
-
import gc
|
12 |
-
|
13 |
-
MEAN_TEMP = torch.tensor(192.1785) #training temp mean
|
14 |
-
STD_TEMP = torch.tensor(81.2135) #training temp std
|
15 |
-
|
16 |
-
|
17 |
-
@torch.no_grad()
|
18 |
-
def main():
|
19 |
-
model = create_model()
|
20 |
-
st.title("CartNet Thermal Ellipsoid Prediction")
|
21 |
-
st.image('fig/pipeline.png')
|
22 |
-
|
23 |
-
st.markdown("""
|
24 |
-
CartNet is a graph neural network specifically designed for predicting Anisotropic Displacement Parameters (ADPs) in crystal structures. The model has been trained on over 220,000 molecular crystal structures from the Cambridge Structural Database (CSD), making it highly accurate and robust for ADP prediction tasks. CartNet addresses the computational challenges of traditional methods by encoding the full 3D geometry of atomic structures into a Cartesian reference frame, bypassing the need for unit cell encoding. The model incorporates innovative features, including a neighbour equalization technique to enhance interaction detection and a Cholesky-based output layer to ensure valid ADP predictions. Additionally, it introduces a rotational SO(3) data augmentation technique to improve generalization across different crystal structure orientations, making the model highly efficient and accurate in predicting ADPs while significantly reducing computational costs.
|
25 |
-
""")
|
26 |
-
|
27 |
-
uploaded_file = st.file_uploader("Upload a CIF file", type=["cif"], accept_multiple_files=False)
|
28 |
-
|
29 |
-
|
30 |
-
if uploaded_file is not None:
|
31 |
-
try:
|
32 |
-
filename = str(uploaded_file.name)
|
33 |
-
file = BytesIO(uploaded_file.getbuffer())
|
34 |
-
cif = ReadCif(file)
|
35 |
-
|
36 |
-
if len(cif.keys())>1:
|
37 |
-
st.warning("⚠️ **Warning**: Found " + str(len(cif.keys())) + " blocks in the CIF file. We will process all of them and export as separate CIF files.")
|
38 |
-
|
39 |
-
st.markdown(f"### CIF file: {filename}")
|
40 |
-
for key in cif.keys():
|
41 |
-
st.markdown(f"### Structure: {key}")
|
42 |
-
try:
|
43 |
-
block = "data_"+str(key)+"\n"+ cif[key].printsection()
|
44 |
-
atoms = read(StringIO(block), format="cif")
|
45 |
-
|
46 |
-
cif_data = cif[key]
|
47 |
-
if "_diffrn_ambient_temperature" in cif_data.keys():
|
48 |
-
temperature = float(cif_data["_diffrn_ambient_temperature"].split("(")[0])
|
49 |
-
elif "_cell_measurement_temperature" in cif_data.keys():
|
50 |
-
temperature = float(cif_data["_cell_measurement_temperature"].split("(")[0])
|
51 |
-
else:
|
52 |
-
st.error("Temperature not found in the CIF file. \
|
53 |
-
Please provide a temperature in the field _diffrn_ambient_temperature o in the field _cell_measurement_temperature from the CIF file.")
|
54 |
-
continue
|
55 |
-
st.success("CIF file successfully read.")
|
56 |
-
except Exception as e:
|
57 |
-
st.error(f"Error: {e}")
|
58 |
-
st.error(f"We couldn't find any structure for the block {key}. Please make sure the CIF is compatible with ASE. If the error message is a blank line, it means ASE didn't found any coordinates.")
|
59 |
-
|
60 |
-
continue
|
61 |
-
|
62 |
-
data = Data()
|
63 |
-
data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
|
64 |
-
|
65 |
-
data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
|
66 |
-
data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
|
67 |
-
data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
|
68 |
-
data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
|
69 |
-
|
70 |
-
data.pbc = torch.tensor([True, True, True])
|
71 |
-
data.natoms = len(atoms)
|
72 |
-
|
73 |
-
del atoms
|
74 |
-
gc.collect()
|
75 |
-
batch = Batch.from_data_list([data])
|
76 |
-
|
77 |
-
|
78 |
-
edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
|
79 |
-
del batch
|
80 |
-
gc.collect()
|
81 |
-
data.cart_dist = torch.norm(edge_attr, dim=-1)
|
82 |
-
data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
|
83 |
-
data.edge_index = edge_index
|
84 |
-
data.non_H_mask = data.x != 1
|
85 |
-
delattr(data, "pbc")
|
86 |
-
delattr(data, "natoms")
|
87 |
-
batch = Batch.from_data_list([data])
|
88 |
-
del data, edge_index, edge_attr
|
89 |
-
gc.collect()
|
90 |
-
|
91 |
-
st.success("Graph successfully created.")
|
92 |
-
|
93 |
-
cif_file = process_data(batch, model)
|
94 |
-
st.success("ADPs successfully predicted.")
|
95 |
-
|
96 |
-
cif_file = BytesIO(cif_file.getvalue().encode())
|
97 |
-
st.download_button(
|
98 |
-
label="Download processed CIF file",
|
99 |
-
data=cif_file,
|
100 |
-
file_name=f"output_{key}.cif",
|
101 |
-
mime="text/plain",
|
102 |
-
key=f"download_button_{key}"
|
103 |
-
)
|
104 |
-
|
105 |
-
gc.collect()
|
106 |
-
|
107 |
-
gc.collect()
|
108 |
-
except Exception as e:
|
109 |
-
st.error(f"An error occurred while reading the CIF file: {e}")
|
110 |
-
|
111 |
-
|
112 |
-
st.markdown("""
|
113 |
-
📌 The official implementation of the paper with all experiments can be found at [CartNet GitHub Repository](https://github.com/imatge-upc/CartNet).
|
114 |
-
""")
|
115 |
-
|
116 |
-
st.warning("""
|
117 |
-
⚠️ **Warning**: We use [ASE library](https://wiki.fysik.dtu.dk/ase/) for reading the CIF files, please make sure your CIF file is compatible.
|
118 |
-
""")
|
119 |
-
|
120 |
-
st.markdown("""
|
121 |
-
### How to cite
|
122 |
-
|
123 |
-
If you use CartNet in your research, please cite our paper:
|
124 |
-
|
125 |
-
```bibtex
|
126 |
-
@article{your_paper_citation,
|
127 |
-
title={Title of the Paper},
|
128 |
-
author={Author1 and Author2 and Author3},
|
129 |
-
journal={Journal Name},
|
130 |
-
year={2023},
|
131 |
-
volume={XX},
|
132 |
-
number={YY},
|
133 |
-
pages={ZZZ}
|
134 |
-
}
|
135 |
-
```
|
136 |
-
""")
|
137 |
-
|
138 |
-
if __name__ == "__main__":
|
139 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
process.py
CHANGED
@@ -1,10 +1,51 @@
|
|
|
|
1 |
import torch
|
|
|
2 |
from ase.io import write
|
3 |
from ase import Atoms
|
4 |
import gc
|
5 |
from io import BytesIO, StringIO
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
-
@torch.no_grad()
|
8 |
def process_data(batch, model, output_file="output.cif"):
|
9 |
atoms = batch.x.numpy().astype(int) # Atomic numbers
|
10 |
positions = batch.pos.numpy() # Atomic positions
|
@@ -72,7 +113,7 @@ def process_data(batch, model, output_file="output.cif"):
|
|
72 |
element_count[element] = 0
|
73 |
element_count[element] += 1
|
74 |
label = f"{element}{element_count[element]}"
|
75 |
-
u_iso = torch.trace(adps[indices[i]]).mean() if element != 'H' else 0.
|
76 |
type = "Uani" if element != 'H' else "Uiso"
|
77 |
cif_file.write(f"{label} {element} {frac_pos[0]} {frac_pos[1]} {frac_pos[2]} {u_iso} {type}\n")
|
78 |
|
|
|
1 |
+
import streamlit as st
|
2 |
import torch
|
3 |
+
from torch_geometric.data import Data, Batch
|
4 |
from ase.io import write
|
5 |
from ase import Atoms
|
6 |
import gc
|
7 |
from io import BytesIO, StringIO
|
8 |
+
from utils import radius_graph_pbc
|
9 |
+
|
10 |
+
MEAN_TEMP = torch.tensor(192.1785) #training temp mean
|
11 |
+
STD_TEMP = torch.tensor(81.2135) #training temp std
|
12 |
+
|
13 |
+
def process_ase(atoms, temperature, model):
|
14 |
+
data = Data()
|
15 |
+
data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
|
16 |
+
|
17 |
+
data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
|
18 |
+
data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
|
19 |
+
data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
|
20 |
+
data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
|
21 |
+
|
22 |
+
data.pbc = torch.tensor([True, True, True])
|
23 |
+
data.natoms = len(atoms)
|
24 |
+
|
25 |
+
del atoms
|
26 |
+
gc.collect()
|
27 |
+
batch = Batch.from_data_list([data])
|
28 |
+
|
29 |
+
|
30 |
+
edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
|
31 |
+
del batch
|
32 |
+
gc.collect()
|
33 |
+
data.cart_dist = torch.norm(edge_attr, dim=-1)
|
34 |
+
data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
|
35 |
+
data.edge_index = edge_index
|
36 |
+
data.non_H_mask = data.x != 1
|
37 |
+
delattr(data, "pbc")
|
38 |
+
delattr(data, "natoms")
|
39 |
+
batch = Batch.from_data_list([data])
|
40 |
+
del data, edge_index, edge_attr
|
41 |
+
gc.collect()
|
42 |
+
|
43 |
+
st.success("Graph successfully created.")
|
44 |
+
|
45 |
+
cif_file = process_data(batch, model)
|
46 |
+
st.success("ADPs successfully predicted.")
|
47 |
+
return cif_file
|
48 |
|
|
|
49 |
def process_data(batch, model, output_file="output.cif"):
|
50 |
atoms = batch.x.numpy().astype(int) # Atomic numbers
|
51 |
positions = batch.pos.numpy() # Atomic positions
|
|
|
113 |
element_count[element] = 0
|
114 |
element_count[element] += 1
|
115 |
label = f"{element}{element_count[element]}"
|
116 |
+
u_iso = torch.trace(adps[indices[i]]).mean() if element != 'H' else 0.0001
|
117 |
type = "Uani" if element != 'H' else "Uiso"
|
118 |
cif_file.write(f"{label} {element} {frac_pos[0]} {frac_pos[1]} {frac_pos[2]} {u_iso} {type}\n")
|
119 |
|