Àlex Solé committed on
Commit
f2b6066
·
1 Parent(s): cd23593

updated code, readme and image

README_streamlit.md → README_offline.md RENAMED
@@ -1,19 +1,33 @@
  # CartNet Streamlit Web App

- ![Pipeline](fig/pipeline.png)

- ### CartNet online demo available at: [CartNet Web App](https://cartnet-adp-estimation.streamlit.app)


  CartNet is a graph neural network specifically designed to predict Anisotropic Displacement Parameters (ADPs) in crystal structures. The model has been trained on over 220,000 molecular crystal structures from the Cambridge Structural Database (CSD), making it highly accurate and robust for ADP prediction tasks. CartNet addresses the computational challenges of traditional methods by encoding the full 3D geometry of atomic structures into a Cartesian reference frame, bypassing the need for unit cell encoding. The model incorporates innovative features, including a neighbour equalization technique to enhance interaction detection and a Cholesky-based output layer to ensure valid ADP predictions. Additionally, it introduces a rotational SO(3) data augmentation technique to improve generalization across different crystal structure orientations, making the model highly efficient and accurate in predicting ADPs while significantly reducing computational costs.

  This repository contains a web application based on the official implementation of CartNet, which can be found at [imatge-upc/CartNet](https://github.com/imatge-upc/CartNet).

- ⚠️ **Warning**: The online web application can only process systems with less than 300 atoms in the unit cell. For large systems, please use the local application.

  ## Local Application
  ### Installation of the local application

  To set up the local application, you need to install the dependencies listed in `requirements.txt`. You can do this by running the following command:

  ```bash
@@ -37,7 +51,7 @@ python predict.py input.cif output.cif
  Or, if you prefer, you can use the browser app on your local machine without the atom number limitation by running:

  ```bash
- streamlit run main_local.py
  ```

  ## How to cite
 
  # CartNet Streamlit Web App

+ ![Pipeline](fig/frontpage.png)

+ <h3 align="center">
+ 🌐 <a href="https://imatge-upc.github.io/CartNet/" target="_blank">Project</a> |
+ 🐙 <a href="https://github.com/imatge-upc/CartNet" target="_blank">GitHub</a> |
+ <!-- 📃 <a href="https://imatge-upc.github.io/CartNet/static/pdfs/CartNet.pdf" target="_blank">Paper</a> | -->
+ 🤗 <a href="https://huggingface.co/spaces/alexsoleg/cartnet-demo" target="_blank">Demo</a>
+ </h3>


  CartNet is a graph neural network specifically designed to predict Anisotropic Displacement Parameters (ADPs) in crystal structures. The model has been trained on over 220,000 molecular crystal structures from the Cambridge Structural Database (CSD), making it highly accurate and robust for ADP prediction tasks. CartNet addresses the computational challenges of traditional methods by encoding the full 3D geometry of atomic structures into a Cartesian reference frame, bypassing the need for unit cell encoding. The model incorporates innovative features, including a neighbour equalization technique to enhance interaction detection and a Cholesky-based output layer to ensure valid ADP predictions. Additionally, it introduces a rotational SO(3) data augmentation technique to improve generalization across different crystal structure orientations, making the model highly efficient and accurate in predicting ADPs while significantly reducing computational costs.
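
The Cholesky-based output layer mentioned in this paragraph can be illustrated with a small sketch. The exact parameterization CartNet uses is not shown in this commit, so the code below is an assumption. It takes six unconstrained numbers, builds a lower-triangular matrix `L` with a strictly positive diagonal (via `exp`), and returns `U = L Lᵀ`, which is symmetric positive definite by construction. The names `adp_from_cholesky` and `is_positive_definite` are hypothetical.

```python
import math


def adp_from_cholesky(params):
    """Map six unconstrained floats to a symmetric positive-definite 3x3 matrix.

    Hypothetical helper for illustration; not CartNet's actual output layer.
    """
    d0, d1, d2, l10, l20, l21 = params
    # exp() keeps the diagonal strictly positive, so L is invertible.
    L = [
        [math.exp(d0), 0.0, 0.0],
        [l10, math.exp(d1), 0.0],
        [l20, l21, math.exp(d2)],
    ]
    # U = L @ L^T, written out with plain lists.
    return [
        [sum(L[i][k] * L[j][k] for k in range(3)) for j in range(3)]
        for i in range(3)
    ]


def is_positive_definite(U):
    """Sylvester's criterion for a symmetric 3x3 matrix: all leading minors > 0."""
    m1 = U[0][0]
    m2 = U[0][0] * U[1][1] - U[0][1] * U[1][0]
    m3 = (
        U[0][0] * (U[1][1] * U[2][2] - U[1][2] * U[2][1])
        - U[0][1] * (U[1][0] * U[2][2] - U[1][2] * U[2][0])
        + U[0][2] * (U[1][0] * U[2][1] - U[1][1] * U[2][0])
    )
    return m1 > 0 and m2 > 0 and m3 > 0


# Any real inputs give a valid (positive-definite) tensor.
U = adp_from_cholesky([0.0, 0.0, 0.0, 1.0, 2.0, 3.0])
print(is_positive_definite(U))
```

The point of the construction is that validity does not depend on the network's raw outputs: whatever six numbers are produced, the resulting tensor describes a physically meaningful ellipsoid.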

  This repository contains a web application based on the official implementation of CartNet, which can be found at [imatge-upc/CartNet](https://github.com/imatge-upc/CartNet).

+ ⚠️ **Warning**: The online web application can only process systems with less than 1000 atoms in the unit cell. For larger systems, please use the local application.

  ## Local Application
  ### Installation of the local application

+ Clone the following repository:
+
+ ```bash
+ # Make sure you have git-lfs installed (https://git-lfs.com)
+ git lfs install
+
+ git clone https://huggingface.co/spaces/alexsoleg/cartnet-demo
+ ```
+
  To set up the local application, you need to install the dependencies listed in `requirements.txt`. You can do this by running the following command:

  ```bash
@@ -37,7 +51,7 @@ python predict.py input.cif output.cif
  Or, if you prefer, you can use the browser app on your local machine without the atom number limitation by running:

  ```bash
+ streamlit run main.py -- --local
  ```

  ## How to cite
fig/{pipeline.png → frontpage.png} RENAMED
File without changes
main.py CHANGED
@@ -2,23 +2,22 @@ import streamlit as st
  import os
  from ase.io import read
  from CifFile import ReadCif
- from torch_geometric.data import Data, Batch
  import torch
  from models.master import create_model
- from process import process_data
- from utils import radius_graph_pbc
  import gc
  from io import BytesIO, StringIO

- MEAN_TEMP = torch.tensor(192.1785) #training temp mean
- STD_TEMP = torch.tensor(81.2135) #training temp std


- @torch.no_grad()
- def main():
      model = create_model()
      st.title("CartNet Thermal Ellipsoid Prediction")
-     st.image('fig/pipeline.png')

      st.markdown("""
  CartNet is a graph neural network specifically designed for predicting Anisotropic Displacement Parameters (ADPs) in crystal structures. The model has been trained on over 220,000 molecular crystal structures from the Cambridge Structural Database (CSD), making it highly accurate and robust for ADP prediction tasks. CartNet addresses the computational challenges of traditional methods by encoding the full 3D geometry of atomic structures into a Cartesian reference frame, bypassing the need for unit cell encoding. The model incorporates innovative features, including a neighbour equalization technique to enhance interaction detection and a Cholesky-based output layer to ensure valid ADP predictions. Additionally, it introduces a rotational SO(3) data augmentation technique to improve generalization across different crystal structure orientations, making the model highly efficient and accurate in predicting ADPs while significantly reducing computational costs.
@@ -46,7 +45,7 @@ def main():
                      block = "data_"+str(key)+"\n"+ cif[key].printsection()
                      atoms = read(StringIO(block), format="cif")

-                     if len(atoms.positions) > 1000:
                          st.error("""
                          ⚠️ **Warning**: The structure is too large. Please upload a smaller one or use the [local implementation of CartNet Web App](https://github.com/alexsoleg/cartnet-streamlit/).
                          """)
@@ -68,39 +67,7 @@ def main():

                      continue

-                 data = Data()
-                 data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
-
-                 data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
-                 data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
-                 data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
-                 data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
-
-                 data.pbc = torch.tensor([True, True, True])
-                 data.natoms = len(atoms)
-
-                 del atoms
-                 gc.collect()
-                 batch = Batch.from_data_list([data])
-
-
-                 edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
-                 del batch
-                 gc.collect()
-                 data.cart_dist = torch.norm(edge_attr, dim=-1)
-                 data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
-                 data.edge_index = edge_index
-                 data.non_H_mask = data.x != 1
-                 delattr(data, "pbc")
-                 delattr(data, "natoms")
-                 batch = Batch.from_data_list([data])
-                 del data, edge_index, edge_attr
-                 gc.collect()
-
-                 st.success("Graph successfully created.")
-
-                 cif_file = process_data(batch, model)
-                 st.success("ADPs successfully predicted.")

                  cif_file = BytesIO(cif_file.getvalue().encode())
                  st.download_button(
@@ -118,10 +85,11 @@ def main():
              gc.collect()
          except Exception as e:
              st.error(f"An error occurred while reading the CIF file: {e}")
-
-     st.warning("""
-     ⚠️ **Warning**: This online web application is designed for structures with up to 1000 atoms in the unit cell. For larger structures, please use the [local implementation of CartNet Web App](https://github.com/alexsoleg/cartnet-streamlit/).
-     """)

      st.warning("""
      ⚠️ **Warning**: We use [ASE library](https://wiki.fysik.dtu.dk/ase/) for reading the CIF files, please make sure your CIF file is compatible.
@@ -150,6 +118,9 @@ def main():
      """)

  if __name__ == "__main__":
-     main()
 
  import os
  from ase.io import read
  from CifFile import ReadCif
  import torch
  from models.master import create_model
+ from process import process_ase
+
  import gc
  from io import BytesIO, StringIO
+ import argparse
+

+ torch.set_grad_enabled(False)


+ def main(local):
      model = create_model()
      st.title("CartNet Thermal Ellipsoid Prediction")
+     st.image('fig/frontpage.png')

      st.markdown("""
  CartNet is a graph neural network specifically designed for predicting Anisotropic Displacement Parameters (ADPs) in crystal structures. The model has been trained on over 220,000 molecular crystal structures from the Cambridge Structural Database (CSD), making it highly accurate and robust for ADP prediction tasks. CartNet addresses the computational challenges of traditional methods by encoding the full 3D geometry of atomic structures into a Cartesian reference frame, bypassing the need for unit cell encoding. The model incorporates innovative features, including a neighbour equalization technique to enhance interaction detection and a Cholesky-based output layer to ensure valid ADP predictions. Additionally, it introduces a rotational SO(3) data augmentation technique to improve generalization across different crystal structure orientations, making the model highly efficient and accurate in predicting ADPs while significantly reducing computational costs.
 
@@ -46,7 +45,7 @@ def main():
                      block = "data_"+str(key)+"\n"+ cif[key].printsection()
                      atoms = read(StringIO(block), format="cif")

+                     if len(atoms.positions) > 1000 and not local:
                          st.error("""
                          ⚠️ **Warning**: The structure is too large. Please upload a smaller one or use the [local implementation of CartNet Web App](https://github.com/alexsoleg/cartnet-streamlit/).
                          """)
@@ -68,39 +67,7 @@ def main():

                      continue

+                 cif_file = process_ase(atoms, temperature, model)

                  cif_file = BytesIO(cif_file.getvalue().encode())
                  st.download_button(
@@ -118,10 +85,11 @@ def main():
              gc.collect()
          except Exception as e:
              st.error(f"An error occurred while reading the CIF file: {e}")
+
+     if not local:
+         st.warning("""
+         ⚠️ **Warning**: This online web application is designed for structures with up to 1000 atoms in the unit cell. For larger structures, please use the [local implementation of CartNet Web App](https://github.com/alexsoleg/cartnet-streamlit/).
+         """)

      st.warning("""
      ⚠️ **Warning**: We use [ASE library](https://wiki.fysik.dtu.dk/ase/) for reading the CIF files, please make sure your CIF file is compatible.
@@ -150,6 +118,9 @@ def main():
      """)

  if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--local', action='store_true')
+     args = parser.parse_args()
+     main(args.local)
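
The `--local` switch and the size check it controls can be exercised outside Streamlit. The sketch below mirrors the argparse setup from this diff and wraps the `len(atoms.positions) > 1000 and not local` condition in a helper. `parse_args`, `is_rejected` and `MAX_ATOMS_ONLINE` are hypothetical names introduced for illustration. When launching through Streamlit, arguments placed after a bare `--` are handed to the script rather than to Streamlit itself, which is why the README command reads `streamlit run main.py -- --local`.

```python
import argparse

MAX_ATOMS_ONLINE = 1000  # limit taken from the diff; the constant name is hypothetical


def parse_args(argv=None):
    """Parse the script's own arguments (those after `--` on the Streamlit command line)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--local', action='store_true')
    return parser.parse_args(argv)


def is_rejected(n_atoms, local):
    """Return True when the structure exceeds the online limit and we are not running locally."""
    return n_atoms > MAX_ATOMS_ONLINE and not local


online = parse_args([])            # flag omitted -> local is False
offline = parse_args(['--local'])  # flag present -> local is True

print(is_rejected(1500, online.local))   # True
print(is_rejected(1500, offline.local))  # False
print(is_rejected(1000, online.local))   # False (the check is strictly greater than 1000)
```

Folding both entry points into one script with a flag removes the duplicated `main_local.py`, at the cost of making the default (no flag) the restricted online behaviour.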
main_local.py DELETED
@@ -1,139 +0,0 @@
- import streamlit as st
- import os
- from ase.io import read
- from CifFile import ReadCif
- from torch_geometric.data import Data, Batch
- import torch
- from models.master import create_model
- from process import process_data
- from utils import radius_graph_pbc
- from io import BytesIO, StringIO
- import gc
-
- MEAN_TEMP = torch.tensor(192.1785) #training temp mean
- STD_TEMP = torch.tensor(81.2135) #training temp std
-
-
- @torch.no_grad()
- def main():
-     model = create_model()
-     st.title("CartNet Thermal Ellipsoid Prediction")
-     st.image('fig/pipeline.png')
-
-     st.markdown("""
- CartNet is a graph neural network specifically designed for predicting Anisotropic Displacement Parameters (ADPs) in crystal structures. The model has been trained on over 220,000 molecular crystal structures from the Cambridge Structural Database (CSD), making it highly accurate and robust for ADP prediction tasks. CartNet addresses the computational challenges of traditional methods by encoding the full 3D geometry of atomic structures into a Cartesian reference frame, bypassing the need for unit cell encoding. The model incorporates innovative features, including a neighbour equalization technique to enhance interaction detection and a Cholesky-based output layer to ensure valid ADP predictions. Additionally, it introduces a rotational SO(3) data augmentation technique to improve generalization across different crystal structure orientations, making the model highly efficient and accurate in predicting ADPs while significantly reducing computational costs.
-     """)
-
-     uploaded_file = st.file_uploader("Upload a CIF file", type=["cif"], accept_multiple_files=False)
-
-
-     if uploaded_file is not None:
-         try:
-             filename = str(uploaded_file.name)
-             file = BytesIO(uploaded_file.getbuffer())
-             cif = ReadCif(file)
-
-             if len(cif.keys())>1:
-                 st.warning("⚠️ **Warning**: Found " + str(len(cif.keys())) + " blocks in the CIF file. We will process all of them and export as separate CIF files.")
-
-             st.markdown(f"### CIF file: {filename}")
-             for key in cif.keys():
-                 st.markdown(f"### Structure: {key}")
-                 try:
-                     block = "data_"+str(key)+"\n"+ cif[key].printsection()
-                     atoms = read(StringIO(block), format="cif")
-
-                     cif_data = cif[key]
-                     if "_diffrn_ambient_temperature" in cif_data.keys():
-                         temperature = float(cif_data["_diffrn_ambient_temperature"].split("(")[0])
-                     elif "_cell_measurement_temperature" in cif_data.keys():
-                         temperature = float(cif_data["_cell_measurement_temperature"].split("(")[0])
-                     else:
-                         st.error("Temperature not found in the CIF file. \
-                         Please provide a temperature in the field _diffrn_ambient_temperature o in the field _cell_measurement_temperature from the CIF file.")
-                         continue
-                     st.success("CIF file successfully read.")
-                 except Exception as e:
-                     st.error(f"Error: {e}")
-                     st.error(f"We couldn't find any structure for the block {key}. Please make sure the CIF is compatible with ASE. If the error message is a blank line, it means ASE didn't found any coordinates.")
-
-                     continue
-
-                 data = Data()
-                 data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
-
-                 data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
-                 data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
-                 data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
-                 data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
-
-                 data.pbc = torch.tensor([True, True, True])
-                 data.natoms = len(atoms)
-
-                 del atoms
-                 gc.collect()
-                 batch = Batch.from_data_list([data])
-
-
-                 edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
-                 del batch
-                 gc.collect()
-                 data.cart_dist = torch.norm(edge_attr, dim=-1)
-                 data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
-                 data.edge_index = edge_index
-                 data.non_H_mask = data.x != 1
-                 delattr(data, "pbc")
-                 delattr(data, "natoms")
-                 batch = Batch.from_data_list([data])
-                 del data, edge_index, edge_attr
-                 gc.collect()
-
-                 st.success("Graph successfully created.")
-
-                 cif_file = process_data(batch, model)
-                 st.success("ADPs successfully predicted.")
-
-                 cif_file = BytesIO(cif_file.getvalue().encode())
-                 st.download_button(
-                     label="Download processed CIF file",
-                     data=cif_file,
-                     file_name=f"output_{key}.cif",
-                     mime="text/plain",
-                     key=f"download_button_{key}"
-                 )
-
-                 gc.collect()
-
-             gc.collect()
-         except Exception as e:
-             st.error(f"An error occurred while reading the CIF file: {e}")
-
-
-     st.markdown("""
-     📌 The official implementation of the paper with all experiments can be found at [CartNet GitHub Repository](https://github.com/imatge-upc/CartNet).
-     """)
-
-     st.warning("""
-     ⚠️ **Warning**: We use [ASE library](https://wiki.fysik.dtu.dk/ase/) for reading the CIF files, please make sure your CIF file is compatible.
-     """)
-
-     st.markdown("""
-     ### How to cite
-
-     If you use CartNet in your research, please cite our paper:
-
-     ```bibtex
-     @article{your_paper_citation,
-     title={Title of the Paper},
-     author={Author1 and Author2 and Author3},
-     journal={Journal Name},
-     year={2023},
-     volume={XX},
-     number={YY},
-     pages={ZZZ}
-     }
-     ```
-     """)
-
- if __name__ == "__main__":
-     main()
process.py CHANGED
@@ -1,10 +1,51 @@
 
  import torch
  from ase.io import write
  from ase import Atoms
  import gc
  from io import BytesIO, StringIO

- @torch.no_grad()
  def process_data(batch, model, output_file="output.cif"):
      atoms = batch.x.numpy().astype(int) # Atomic numbers
      positions = batch.pos.numpy() # Atomic positions
@@ -72,7 +113,7 @@ def process_data(batch, model, output_file="output.cif"):
              element_count[element] = 0
          element_count[element] += 1
          label = f"{element}{element_count[element]}"
-         u_iso = torch.trace(adps[indices[i]]).mean() if element != 'H' else 0.01
          type = "Uani" if element != 'H' else "Uiso"
          cif_file.write(f"{label} {element} {frac_pos[0]} {frac_pos[1]} {frac_pos[2]} {u_iso} {type}\n")
 
+ import streamlit as st
  import torch
+ from torch_geometric.data import Data, Batch
  from ase.io import write
  from ase import Atoms
  import gc
  from io import BytesIO, StringIO
+ from utils import radius_graph_pbc
+
+ MEAN_TEMP = torch.tensor(192.1785) #training temp mean
+ STD_TEMP = torch.tensor(81.2135) #training temp std
+
+ def process_ase(atoms, temperature, model):
+     data = Data()
+     data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
+
+     data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
+     data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
+     data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
+     data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
+
+     data.pbc = torch.tensor([True, True, True])
+     data.natoms = len(atoms)
+
+     del atoms
+     gc.collect()
+     batch = Batch.from_data_list([data])
+
+
+     edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
+     del batch
+     gc.collect()
+     data.cart_dist = torch.norm(edge_attr, dim=-1)
+     data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
+     data.edge_index = edge_index
+     data.non_H_mask = data.x != 1
+     delattr(data, "pbc")
+     delattr(data, "natoms")
+     batch = Batch.from_data_list([data])
+     del data, edge_index, edge_attr
+     gc.collect()
+
+     st.success("Graph successfully created.")
+
+     cif_file = process_data(batch, model)
+     st.success("ADPs successfully predicted.")
+     return cif_file
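
To make the feature construction in `process_ase` concrete, the sketch below reproduces two of its steps with the standard library only: the temperature standardization against the training mean and standard deviation, and the split of an edge vector into a length (`cart_dist`) and a unit direction (`cart_dir`). It is an illustration rather than the project's code. The helper names `standardize_temperature` and `split_edge` are hypothetical, and the real function operates on torch tensors produced by `radius_graph_pbc`.

```python
import math

MEAN_TEMP = 192.1785  # training temperature mean, value from the diff
STD_TEMP = 81.2135    # training temperature std, value from the diff


def standardize_temperature(temperature):
    """Z-score the measurement temperature, as done for data.temperature."""
    return (temperature - MEAN_TEMP) / STD_TEMP


def split_edge(edge_vec):
    """Split a Cartesian edge vector into its length and unit direction.

    Unlike torch.nn.functional.normalize, which clamps the norm with a small
    eps, this plain-Python version raises ZeroDivisionError for a zero vector.
    """
    dist = math.sqrt(sum(c * c for c in edge_vec))
    direction = [c / dist for c in edge_vec]
    return dist, direction


t = standardize_temperature(100.0)             # below the training mean -> negative
dist, direction = split_edge([3.0, 0.0, 4.0])  # a 3-4-5 triangle
print(t < 0, dist, direction)
```

Keeping the distance and the direction as separate features lets the model treat "how far" and "which way" independently, which is consistent with the Cartesian encoding described in the README.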

  def process_data(batch, model, output_file="output.cif"):
      atoms = batch.x.numpy().astype(int) # Atomic numbers
      positions = batch.pos.numpy() # Atomic positions
@@ -72,7 +113,7 @@ def process_data(batch, model, output_file="output.cif"):
              element_count[element] = 0
          element_count[element] += 1
          label = f"{element}{element_count[element]}"
+         u_iso = torch.trace(adps[indices[i]]).mean() if element != 'H' else 0.0001
          type = "Uani" if element != 'H' else "Uiso"
          cif_file.write(f"{label} {element} {frac_pos[0]} {frac_pos[1]} {frac_pos[2]} {u_iso} {type}\n")
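
One detail of the `u_iso` line is worth checking by hand. `torch.trace` returns a scalar, so the trailing `.mean()` leaves it unchanged, and the value written for non-hydrogen atoms is the full trace of the ADP tensor. The conventional equivalent isotropic displacement parameter for a tensor expressed in Cartesian axes is one third of the trace. The sketch below assumes that `adps` holds Cartesian 3×3 tensors, which this diff does not confirm, and uses a hypothetical helper name to show the difference on a diagonal example.

```python
def u_equivalent(U):
    """U_eq = trace(U) / 3 for a 3x3 displacement tensor in Cartesian axes."""
    return (U[0][0] + U[1][1] + U[2][2]) / 3.0


# Diagonal example tensor (illustrative values, not model output).
U = [
    [0.03, 0.00, 0.00],
    [0.00, 0.06, 0.00],
    [0.00, 0.00, 0.09],
]

trace = U[0][0] + U[1][1] + U[2][2]
print(round(trace, 6))            # 0.18
print(round(u_equivalent(U), 6))  # 0.06
```

If the factor of three is handled elsewhere in `process_data` (outside the lines shown in this hunk), the written value may already be correct; the hunk alone does not settle it.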