Spaces:
Sleeping
Sleeping
Àlex Solé
commited on
Commit
·
744c6a1
1
Parent(s):
3fa8a08
merged from streamlit
Browse files- .gitattributes +1 -0
- .gitignore +1 -0
- LICENSE +21 -0
- README.md +57 -13
- cpkt/cartnet_adp.ckpt +3 -0
- fig/pipeline.png +3 -0
- main.py +127 -0
- main_local.py +119 -0
- models/cartnet.py +289 -0
- models/master.py +14 -0
- models/utils.py +129 -0
- predict.py +73 -0
- process.py +100 -0
- requirements.txt +8 -0
- utils.py +323 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
.DS_Store
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2024 Àlex Solé Gómez
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,13 +1,57 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CartNet Streamlit Web App
|
2 |
+
|
3 |
+

|
4 |
+
|
5 |
+
### CartNet online demo available at: [CartNet Web App](https://cartnet-adp-estimation.streamlit.app)
|
6 |
+
|
7 |
+
|
8 |
+
CartNet is a graph neural network specifically designed to predict Anisotropic Displacement Parameters (ADPs) in crystal structures. The model has been trained on over 220,000 molecular crystal structures from the Cambridge Structural Database (CSD), making it highly accurate and robust for ADP prediction tasks. CartNet addresses the computational challenges of traditional methods by encoding the full 3D geometry of atomic structures into a Cartesian reference frame, bypassing the need for unit cell encoding. The model incorporates innovative features, including a neighbour equalization technique to enhance interaction detection and a Cholesky-based output layer to ensure valid ADP predictions. Additionally, it introduces a rotational SO(3) data augmentation technique to improve generalization across different crystal structure orientations, making the model highly efficient and accurate in predicting ADPs while significantly reducing computational costs.
|
9 |
+
|
10 |
+
This repository contains a web application based on the official implementation of CartNet, which can be found at [imatge-upc/CartNet](https://github.com/imatge-upc/CartNet).
|
11 |
+
|
12 |
+
⚠️ **Warning**: The online web application can only process systems with less than 300 atoms in the unit cell. For large systems, please use the local application.
|
13 |
+
|
14 |
+
## Local Application
|
15 |
+
### Installation of the local application
|
16 |
+
|
17 |
+
To set up the local application, you need to install the dependencies listed in `requirements.txt`. You can do this by running the following command:
|
18 |
+
|
19 |
+
```bash
|
20 |
+
pip install -r requirements.txt
|
21 |
+
```
|
22 |
+
|
23 |
+
### Usage
|
24 |
+
|
25 |
+
You can make predictions directly from Python using the `predict.py` script.
|
26 |
+
|
27 |
+
The script takes two arguments:
|
28 |
+
1. `input_file`: Path to the input CIF file
|
29 |
+
2. `output_file`: Path where you want to save the processed CIF file
|
30 |
+
|
31 |
+
Example usage:
|
32 |
+
|
33 |
+
```bash
|
34 |
+
python predict.py input.cif output.cif
|
35 |
+
```
|
36 |
+
|
37 |
+
Or, if you prefer, you can use the browser app on your local machine without the atom number limitation by running:
|
38 |
+
|
39 |
+
```bash
|
40 |
+
streamlit run main_local.py
|
41 |
+
```
|
42 |
+
|
43 |
+
## How to cite
|
44 |
+
|
45 |
+
If you use CartNet in your research, please cite our paper:
|
46 |
+
|
47 |
+
```bibtex
|
48 |
+
@article{your_paper_citation,
|
49 |
+
title={Title of the Paper},
|
50 |
+
author={Author1 and Author2 and Author3},
|
51 |
+
journal={Journal Name},
|
52 |
+
year={2023},
|
53 |
+
volume={XX},
|
54 |
+
number={YY},
|
55 |
+
pages={ZZZ}
|
56 |
+
}
|
57 |
+
```
|
cpkt/cartnet_adp.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0829f2e8a631b380c04e2beb5c75b36767e5635321deea17fc5e7bafee643332
|
3 |
+
size 30066357
|
fig/pipeline.png
ADDED
![]() |
Git LFS Details
|
main.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import os
|
3 |
+
from ase.io import read
|
4 |
+
from CifFile import ReadCif
|
5 |
+
from torch_geometric.data import Data, Batch
|
6 |
+
import torch
|
7 |
+
from models.master import create_model
|
8 |
+
from process import process_data
|
9 |
+
from utils import radius_graph_pbc
|
10 |
+
import gc
|
11 |
+
|
12 |
+
MEAN_TEMP = torch.tensor(192.1785) #training temp mean
|
13 |
+
STD_TEMP = torch.tensor(81.2135) #training temp std
|
14 |
+
|
15 |
+
|
16 |
+
@torch.no_grad()
|
17 |
+
def main():
|
18 |
+
model = create_model()
|
19 |
+
st.title("CartNet ADP Prediction")
|
20 |
+
st.image('fig/pipeline.png')
|
21 |
+
|
22 |
+
st.markdown("""
|
23 |
+
CartNet is a graph neural network specifically designed for predicting Anisotropic Displacement Parameters (ADPs) in crystal structures. The model has been trained on over 220,000 molecular crystal structures from the Cambridge Structural Database (CSD), making it highly accurate and robust for ADP prediction tasks. CartNet addresses the computational challenges of traditional methods by encoding the full 3D geometry of atomic structures into a Cartesian reference frame, bypassing the need for unit cell encoding. The model incorporates innovative features, including a neighbour equalization technique to enhance interaction detection and a Cholesky-based output layer to ensure valid ADP predictions. Additionally, it introduces a rotational SO(3) data augmentation technique to improve generalization across different crystal structure orientations, making the model highly efficient and accurate in predicting ADPs while significantly reducing computational costs.
|
24 |
+
""")
|
25 |
+
|
26 |
+
uploaded_file = st.file_uploader("Upload a CIF file", type=["cif"], accept_multiple_files=False)
|
27 |
+
# uploaded_file = "ABABEM.cif"
|
28 |
+
if uploaded_file is not None:
|
29 |
+
try:
|
30 |
+
with open(uploaded_file.name, "wb") as f:
|
31 |
+
f.write(uploaded_file.getbuffer())
|
32 |
+
filename = str(uploaded_file.name)
|
33 |
+
# Read the CIF file using ASE
|
34 |
+
atoms = read(filename, format="cif")
|
35 |
+
cif = ReadCif(filename)
|
36 |
+
cif_data = cif.first_block()
|
37 |
+
if "_diffrn_ambient_temperature" in cif_data.keys():
|
38 |
+
temperature = float(cif_data["_diffrn_ambient_temperature"])
|
39 |
+
else:
|
40 |
+
raise ValueError("Temperature not found in the CIF file. \
|
41 |
+
Please provide a temperature in the field _diffrn_ambient_temperature from the CIF file.")
|
42 |
+
st.success("CIF file successfully read.")
|
43 |
+
|
44 |
+
data = Data()
|
45 |
+
data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
|
46 |
+
|
47 |
+
if len(atoms.positions) > 300:
|
48 |
+
st.markdown("""
|
49 |
+
⚠️ **Warning**: The structure is too large. Please upload a smaller one or use the [local implementation of CartNet Web App](https://github.com/alexsoleg/cartnet-streamlit/).
|
50 |
+
""")
|
51 |
+
raise ValueError("Please provide a structure with less than 300 atoms in the unit cell.")
|
52 |
+
|
53 |
+
data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
|
54 |
+
data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
|
55 |
+
data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
|
56 |
+
data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
|
57 |
+
|
58 |
+
data.pbc = torch.tensor([True, True, True])
|
59 |
+
data.natoms = len(atoms)
|
60 |
+
|
61 |
+
del atoms
|
62 |
+
gc.collect()
|
63 |
+
batch = Batch.from_data_list([data])
|
64 |
+
|
65 |
+
|
66 |
+
edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
|
67 |
+
del batch
|
68 |
+
gc.collect()
|
69 |
+
data.cart_dist = torch.norm(edge_attr, dim=-1)
|
70 |
+
data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
|
71 |
+
data.edge_index = edge_index
|
72 |
+
data.non_H_mask = data.x != 1
|
73 |
+
delattr(data, "pbc")
|
74 |
+
delattr(data, "natoms")
|
75 |
+
batch = Batch.from_data_list([data])
|
76 |
+
del data, edge_index, edge_attr
|
77 |
+
gc.collect()
|
78 |
+
|
79 |
+
st.success("Graph successfully created.")
|
80 |
+
|
81 |
+
process_data(batch, model)
|
82 |
+
st.success("ADPs successfully predicted.")
|
83 |
+
|
84 |
+
# Create a download button for the processed CIF file
|
85 |
+
with open("output.cif", "r") as f:
|
86 |
+
cif_contents = f.read()
|
87 |
+
|
88 |
+
st.download_button(
|
89 |
+
label="Download processed CIF file",
|
90 |
+
data=cif_contents,
|
91 |
+
file_name="output.cif",
|
92 |
+
mime="text/plain"
|
93 |
+
)
|
94 |
+
|
95 |
+
os.remove("output.cif")
|
96 |
+
os.remove(filename)
|
97 |
+
gc.collect()
|
98 |
+
except Exception as e:
|
99 |
+
st.error(f"An error occurred while reading the CIF file: {e}")
|
100 |
+
st.markdown("""
|
101 |
+
⚠️ **Warning**: This online web application is designed for structures with up to 300 atoms in the unit cell. For larger structures, please use the [local implementation of CartNet Web App](https://github.com/alexsoleg/cartnet-streamlit/).
|
102 |
+
""")
|
103 |
+
|
104 |
+
st.markdown("""
|
105 |
+
📌 The official implementation of the paper with all experiments can be found at [CartNet GitHub Repository](https://github.com/imatge-upc/CartNet).
|
106 |
+
""")
|
107 |
+
|
108 |
+
st.markdown("""
|
109 |
+
### How to cite
|
110 |
+
|
111 |
+
If you use CartNet in your research, please cite our paper:
|
112 |
+
|
113 |
+
```bibtex
|
114 |
+
@article{your_paper_citation,
|
115 |
+
title={Title of the Paper},
|
116 |
+
author={Author1 and Author2 and Author3},
|
117 |
+
journal={Journal Name},
|
118 |
+
year={2023},
|
119 |
+
volume={XX},
|
120 |
+
number={YY},
|
121 |
+
pages={ZZZ}
|
122 |
+
}
|
123 |
+
```
|
124 |
+
""")
|
125 |
+
|
126 |
+
if __name__ == "__main__":
|
127 |
+
main()
|
main_local.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import os
|
3 |
+
from ase.io import read
|
4 |
+
from CifFile import ReadCif
|
5 |
+
from torch_geometric.data import Data, Batch
|
6 |
+
import torch
|
7 |
+
from models.master import create_model
|
8 |
+
from process import process_data
|
9 |
+
from utils import radius_graph_pbc
|
10 |
+
import gc
|
11 |
+
|
12 |
+
MEAN_TEMP = torch.tensor(192.1785) #training temp mean
|
13 |
+
STD_TEMP = torch.tensor(81.2135) #training temp std
|
14 |
+
|
15 |
+
|
16 |
+
@torch.no_grad()
|
17 |
+
def main():
|
18 |
+
model = create_model()
|
19 |
+
st.title("CartNet ADP Prediction")
|
20 |
+
st.image('fig/pipeline.png')
|
21 |
+
|
22 |
+
st.markdown("""
|
23 |
+
CartNet is a graph neural network specifically designed for predicting Anisotropic Displacement Parameters (ADPs) in crystal structures. The model has been trained on over 220,000 molecular crystal structures from the Cambridge Structural Database (CSD), making it highly accurate and robust for ADP prediction tasks. CartNet addresses the computational challenges of traditional methods by encoding the full 3D geometry of atomic structures into a Cartesian reference frame, bypassing the need for unit cell encoding. The model incorporates innovative features, including a neighbour equalization technique to enhance interaction detection and a Cholesky-based output layer to ensure valid ADP predictions. Additionally, it introduces a rotational SO(3) data augmentation technique to improve generalization across different crystal structure orientations, making the model highly efficient and accurate in predicting ADPs while significantly reducing computational costs.
|
24 |
+
""")
|
25 |
+
|
26 |
+
uploaded_file = st.file_uploader("Upload a CIF file", type=["cif"], accept_multiple_files=False)
|
27 |
+
# uploaded_file = "ABABEM.cif"
|
28 |
+
if uploaded_file is not None:
|
29 |
+
try:
|
30 |
+
with open(uploaded_file.name, "wb") as f:
|
31 |
+
f.write(uploaded_file.getbuffer())
|
32 |
+
filename = str(uploaded_file.name)
|
33 |
+
# Read the CIF file using ASE
|
34 |
+
atoms = read(filename, format="cif")
|
35 |
+
cif = ReadCif(filename)
|
36 |
+
cif_data = cif.first_block()
|
37 |
+
if "_diffrn_ambient_temperature" in cif_data.keys():
|
38 |
+
temperature = float(cif_data["_diffrn_ambient_temperature"])
|
39 |
+
else:
|
40 |
+
raise ValueError("Temperature not found in the CIF file. \
|
41 |
+
Please provide a temperature in the field _diffrn_ambient_temperature from the CIF file.")
|
42 |
+
st.success("CIF file successfully read.")
|
43 |
+
|
44 |
+
data = Data()
|
45 |
+
data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
|
46 |
+
|
47 |
+
data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
|
48 |
+
data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
|
49 |
+
data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
|
50 |
+
data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
|
51 |
+
|
52 |
+
data.pbc = torch.tensor([True, True, True])
|
53 |
+
data.natoms = len(atoms)
|
54 |
+
|
55 |
+
del atoms
|
56 |
+
gc.collect()
|
57 |
+
batch = Batch.from_data_list([data])
|
58 |
+
|
59 |
+
|
60 |
+
edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
|
61 |
+
del batch
|
62 |
+
gc.collect()
|
63 |
+
data.cart_dist = torch.norm(edge_attr, dim=-1)
|
64 |
+
data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
|
65 |
+
data.edge_index = edge_index
|
66 |
+
data.non_H_mask = data.x != 1
|
67 |
+
delattr(data, "pbc")
|
68 |
+
delattr(data, "natoms")
|
69 |
+
batch = Batch.from_data_list([data])
|
70 |
+
del data, edge_index, edge_attr
|
71 |
+
gc.collect()
|
72 |
+
|
73 |
+
st.success("Graph successfully created.")
|
74 |
+
|
75 |
+
process_data(batch, model)
|
76 |
+
st.success("ADPs successfully predicted.")
|
77 |
+
|
78 |
+
# Create a download button for the processed CIF file
|
79 |
+
with open("output.cif", "r") as f:
|
80 |
+
cif_contents = f.read()
|
81 |
+
|
82 |
+
st.download_button(
|
83 |
+
label="Download processed CIF file",
|
84 |
+
data=cif_contents,
|
85 |
+
file_name="output.cif",
|
86 |
+
mime="text/plain"
|
87 |
+
)
|
88 |
+
|
89 |
+
os.remove("output.cif")
|
90 |
+
os.remove(filename)
|
91 |
+
gc.collect()
|
92 |
+
except Exception as e:
|
93 |
+
st.error(f"An error occurred while reading the CIF file: {e}")
|
94 |
+
|
95 |
+
|
96 |
+
st.markdown("""
|
97 |
+
📌 The official implementation of the paper with all experiments can be found at [CartNet GitHub Repository](https://github.com/imatge-upc/CartNet).
|
98 |
+
""")
|
99 |
+
|
100 |
+
st.markdown("""
|
101 |
+
### How to cite
|
102 |
+
|
103 |
+
If you use CartNet in your research, please cite our paper:
|
104 |
+
|
105 |
+
```bibtex
|
106 |
+
@article{your_paper_citation,
|
107 |
+
title={Title of the Paper},
|
108 |
+
author={Author1 and Author2 and Author3},
|
109 |
+
journal={Journal Name},
|
110 |
+
year={2023},
|
111 |
+
volume={XX},
|
112 |
+
number={YY},
|
113 |
+
pages={ZZZ}
|
114 |
+
}
|
115 |
+
```
|
116 |
+
""")
|
117 |
+
|
118 |
+
if __name__ == "__main__":
|
119 |
+
main()
|
models/cartnet.py
ADDED
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright Universitat Politècnica de Catalunya 2024 https://imatge.upc.edu
|
2 |
+
# Distributed under the MIT License.
|
3 |
+
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch_geometric.nn as pyg_nn
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torch_scatter import scatter
|
10 |
+
from models.utils import ExpNormalSmearing, CosineCutoff
|
11 |
+
|
12 |
+
|
13 |
+
class CartNet(torch.nn.Module):
|
14 |
+
"""
|
15 |
+
CartNet model from Cartesian Encoding Graph Neural Network for Crystal Structures Property Prediction: Application to Thermal Ellipsoid Estimation.
|
16 |
+
Args:
|
17 |
+
dim_in (int): Dimensionality of the input features.
|
18 |
+
dim_rbf (int): Dimensionality of the radial basis function embeddings.
|
19 |
+
num_layers (int): Number of CartNet layers in the model.
|
20 |
+
radius (float, optional): Radius cutoff for neighbor interactions. Default is 5.0.
|
21 |
+
invariant (bool, optional): If `True`, enforces rotational invariance in the encoder. Default is `False`.
|
22 |
+
temperature (bool, optional): If `True`, includes temperature information in the encoder. Default is `True`.
|
23 |
+
use_envelope (bool, optional): If `True`, applies an envelope function to the interactions. Default is `True`.
|
24 |
+
cholesky (bool, optional): If `True`, uses a Cholesky head for the output. If `False`, uses a scalar head. Default is `True`.
|
25 |
+
Methods:
|
26 |
+
forward(batch):
|
27 |
+
Performs a forward pass of the model.
|
28 |
+
Args:
|
29 |
+
batch: A batch of input data.
|
30 |
+
Returns:
|
31 |
+
pred: The model's predictions.
|
32 |
+
true: The ground truth values corresponding to the input batch.
|
33 |
+
"""
|
34 |
+
|
35 |
+
|
36 |
+
def __init__(self,
|
37 |
+
dim_in: int,
|
38 |
+
dim_rbf: int,
|
39 |
+
num_layers: int,
|
40 |
+
radius: float = 5.0,
|
41 |
+
invariant: bool = False,
|
42 |
+
temperature: bool = True,
|
43 |
+
use_envelope: bool = True,
|
44 |
+
cholesky: bool = True):
|
45 |
+
super().__init__()
|
46 |
+
|
47 |
+
self.encoder = Encoder(dim_in, dim_rbf=dim_rbf, radius=radius, invariant=invariant, temperature=temperature)
|
48 |
+
self.dim_in = dim_in
|
49 |
+
|
50 |
+
layers = []
|
51 |
+
for _ in range(num_layers):
|
52 |
+
layers.append(CartNet_layer(
|
53 |
+
dim_in=dim_in,
|
54 |
+
use_envelope=use_envelope,
|
55 |
+
))
|
56 |
+
self.layers = torch.nn.Sequential(*layers)
|
57 |
+
|
58 |
+
if cholesky:
|
59 |
+
self.head = Cholesky_head(dim_in)
|
60 |
+
else:
|
61 |
+
self.head = Scalar_head(dim_in)
|
62 |
+
|
63 |
+
def forward(self, batch):
|
64 |
+
batch = self.encoder(batch)
|
65 |
+
|
66 |
+
for layer in self.layers:
|
67 |
+
batch = layer(batch)
|
68 |
+
|
69 |
+
pred = self.head(batch)
|
70 |
+
|
71 |
+
return pred
|
72 |
+
|
73 |
+
class Encoder(torch.nn.Module):
|
74 |
+
"""
|
75 |
+
Encoder module for the CartNet model.
|
76 |
+
This module encodes node and edge features for input into the CartNet model, incorporating optional temperature information and rotational invariance.
|
77 |
+
Args:
|
78 |
+
dim_in (int): Dimension of the input features after embedding.
|
79 |
+
dim_rbf (int): Dimension of the radial basis function used for edge attributes.
|
80 |
+
radius (float, optional): Cutoff radius for neighbor interactions. Defaults to 5.0.
|
81 |
+
invariant (bool, optional): If True, the encoder enforces rotational invariance by excluding directional information from edge attributes. Defaults to False.
|
82 |
+
temperature (bool, optional): If True, includes temperature data in the node embeddings. Defaults to True.
|
83 |
+
Attributes:
|
84 |
+
dim_in (int): Dimension of the input features.
|
85 |
+
invariant (bool): Indicates if rotational invariance is enforced.
|
86 |
+
temperature (bool): Indicates if temperature information is included.
|
87 |
+
embedding (nn.Embedding): Embedding layer mapping atomic numbers to feature vectors.
|
88 |
+
temperature_proj_atom (pyg_nn.Linear): Linear layer projecting temperature to embedding dimensions (used if temperature is True).
|
89 |
+
bias (nn.Parameter): Bias term added to embeddings (used if temperature is False).
|
90 |
+
activation (nn.Module): Activation function (SiLU).
|
91 |
+
encoder_atom (nn.Sequential): Sequential network encoding node features.
|
92 |
+
encoder_edge (nn.Sequential): Sequential network encoding edge features.
|
93 |
+
rbf (ExpNormalSmearing): Radial basis function for encoding distances.
|
94 |
+
"""
|
95 |
+
|
96 |
+
def __init__(
|
97 |
+
self,
|
98 |
+
dim_in: int,
|
99 |
+
dim_rbf: int,
|
100 |
+
radius: float = 5.0,
|
101 |
+
invariant: bool = False,
|
102 |
+
temperature: bool = True,
|
103 |
+
):
|
104 |
+
super(Encoder, self).__init__()
|
105 |
+
self.dim_in = dim_in
|
106 |
+
self.invariant = invariant
|
107 |
+
self.temperature = temperature
|
108 |
+
self.embedding = nn.Embedding(119, self.dim_in*2)
|
109 |
+
if self.temperature:
|
110 |
+
self.temperature_proj_atom = pyg_nn.Linear(1, self.dim_in*2, bias=True)
|
111 |
+
else:
|
112 |
+
self.bias = nn.Parameter(torch.zeros(self.dim_in*2))
|
113 |
+
self.activation = nn.SiLU(inplace=True)
|
114 |
+
|
115 |
+
|
116 |
+
self.encoder_atom = nn.Sequential(self.activation,
|
117 |
+
pyg_nn.Linear(self.dim_in*2, self.dim_in),
|
118 |
+
self.activation)
|
119 |
+
if self.invariant:
|
120 |
+
dim_edge = dim_rbf
|
121 |
+
else:
|
122 |
+
dim_edge = dim_rbf + 3
|
123 |
+
|
124 |
+
self.encoder_edge = nn.Sequential(pyg_nn.Linear(dim_edge, self.dim_in*2),
|
125 |
+
self.activation,
|
126 |
+
pyg_nn.Linear(self.dim_in*2, self.dim_in),
|
127 |
+
self.activation)
|
128 |
+
|
129 |
+
self.rbf = ExpNormalSmearing(0.0,radius,dim_rbf,False)
|
130 |
+
|
131 |
+
torch.nn.init.xavier_uniform_(self.embedding.weight.data)
|
132 |
+
|
133 |
+
def forward(self, batch):
|
134 |
+
|
135 |
+
x = self.embedding(batch.x) + self.temperature_proj_atom(batch.temperature.unsqueeze(-1))[batch.batch]
|
136 |
+
|
137 |
+
|
138 |
+
batch.x = self.encoder_atom(x)
|
139 |
+
|
140 |
+
batch.edge_attr = self.encoder_edge(torch.cat([self.rbf(batch.cart_dist), batch.cart_dir], dim=-1))
|
141 |
+
|
142 |
+
return batch
|
143 |
+
|
144 |
+
class CartNet_layer(pyg_nn.conv.MessagePassing):
|
145 |
+
"""
|
146 |
+
The message-passing layer used in the CartNet architecture.
|
147 |
+
Parameters:
|
148 |
+
dim_in (int): Dimension of the input node features.
|
149 |
+
use_envelope (bool, optional): If True, applies an envelope function to the distances. Defaults to True.
|
150 |
+
Attributes:
|
151 |
+
dim_in (int): Dimension of the input node features.
|
152 |
+
activation (nn.Module): Activation function (SiLU) used in the layer.
|
153 |
+
MLP_aggr (nn.Sequential): MLP used for aggregating messages.
|
154 |
+
MLP_gate (nn.Sequential): MLP used for computing gating coefficients.
|
155 |
+
norm (nn.BatchNorm1d): Batch normalization applied to the gating coefficients.
|
156 |
+
norm2 (nn.BatchNorm1d): Batch normalization applied to the aggregated messages.
|
157 |
+
use_envelope (bool): Indicates if the envelope function is used.
|
158 |
+
envelope (CosineCutoff): Envelope function applied to the distances.
|
159 |
+
"""
|
160 |
+
|
161 |
+
def __init__(self,
|
162 |
+
dim_in: int,
|
163 |
+
use_envelope: bool = True
|
164 |
+
):
|
165 |
+
super().__init__()
|
166 |
+
self.dim_in = dim_in
|
167 |
+
self.activation = nn.SiLU(inplace=True)
|
168 |
+
self.MLP_aggr = nn.Sequential(
|
169 |
+
pyg_nn.Linear(dim_in*3, dim_in, bias=True),
|
170 |
+
self.activation,
|
171 |
+
pyg_nn.Linear(dim_in, dim_in, bias=True),
|
172 |
+
)
|
173 |
+
self.MLP_gate = nn.Sequential(
|
174 |
+
pyg_nn.Linear(dim_in*3, dim_in, bias=True),
|
175 |
+
self.activation,
|
176 |
+
pyg_nn.Linear(dim_in, dim_in, bias=True),
|
177 |
+
)
|
178 |
+
|
179 |
+
self.norm = nn.BatchNorm1d(dim_in)
|
180 |
+
self.norm2 = nn.BatchNorm1d(dim_in)
|
181 |
+
self.use_envelope = use_envelope
|
182 |
+
self.envelope = CosineCutoff(0, 5.0)
|
183 |
+
|
184 |
+
|
185 |
+
def forward(self, batch):
|
186 |
+
|
187 |
+
x, e, edge_index, dist = batch.x, batch.edge_attr, batch.edge_index, batch.cart_dist
|
188 |
+
"""
|
189 |
+
x : [n_nodes, dim_in]
|
190 |
+
e : [n_edges, dim_in]
|
191 |
+
edge_index : [2, n_edges]
|
192 |
+
dist : [n_edges]
|
193 |
+
batch : [n_nodes]
|
194 |
+
"""
|
195 |
+
|
196 |
+
x_in = x
|
197 |
+
e_in = e
|
198 |
+
|
199 |
+
x, e = self.propagate(edge_index,
|
200 |
+
Xx=x, Ee=e,
|
201 |
+
He=dist,
|
202 |
+
)
|
203 |
+
|
204 |
+
batch.x = self.activation(x) + x_in
|
205 |
+
|
206 |
+
batch.edge_attr = e_in + e
|
207 |
+
|
208 |
+
return batch
|
209 |
+
|
210 |
+
|
211 |
+
def message(self, Xx_i, Ee, Xx_j, He):
|
212 |
+
"""
|
213 |
+
x_i : [n_edges, dim_in]
|
214 |
+
x_j : [n_edges, dim_in]
|
215 |
+
e : [n_edges, dim_in]
|
216 |
+
"""
|
217 |
+
|
218 |
+
e_ij = self.MLP_gate(torch.cat([Xx_i, Xx_j, Ee], dim=-1))
|
219 |
+
e_ij = F.sigmoid(self.norm(e_ij))
|
220 |
+
|
221 |
+
if self.use_envelope:
|
222 |
+
sigma_ij = self.envelope(He).unsqueeze(-1)*e_ij
|
223 |
+
else:
|
224 |
+
sigma_ij = e_ij
|
225 |
+
|
226 |
+
self.e = sigma_ij
|
227 |
+
return sigma_ij
|
228 |
+
|
229 |
+
def aggregate(self, sigma_ij, index, Xx_i, Xx_j, Ee, Xx):
|
230 |
+
"""
|
231 |
+
sigma_ij : [n_edges, dim_in] ; is the output from message() function
|
232 |
+
index : [n_edges]
|
233 |
+
x_j : [n_edges, dim_in]
|
234 |
+
"""
|
235 |
+
dim_size = Xx.shape[0]
|
236 |
+
|
237 |
+
sender = self.MLP_aggr(torch.cat([Xx_i, Xx_j, Ee], dim=-1))
|
238 |
+
|
239 |
+
|
240 |
+
out = scatter(sigma_ij*sender, index, 0, None, dim_size,
|
241 |
+
reduce='sum')
|
242 |
+
|
243 |
+
return out
|
244 |
+
|
245 |
+
def update(self, aggr_out):
|
246 |
+
"""
|
247 |
+
aggr_out : [n_nodes, dim_in] ; is the output from aggregate() function after the aggregation
|
248 |
+
x : [n_nodes, dim_in]
|
249 |
+
"""
|
250 |
+
x = self.norm2(aggr_out)
|
251 |
+
|
252 |
+
e_out = self.e
|
253 |
+
del self.e
|
254 |
+
|
255 |
+
return x, e_out
|
256 |
+
|
257 |
+
class Cholesky_head(torch.nn.Module):
|
258 |
+
"""
|
259 |
+
The Cholesky head used in the CartNet model.
|
260 |
+
It enforce the positive definiteness of the output covariance matrix.
|
261 |
+
|
262 |
+
Args:
|
263 |
+
dim_in (int): The input dimension of the features.
|
264 |
+
"""
|
265 |
+
|
266 |
+
def __init__(self,
|
267 |
+
dim_in: int
|
268 |
+
):
|
269 |
+
super(Cholesky_head, self).__init__()
|
270 |
+
self.MLP = nn.Sequential(pyg_nn.Linear(dim_in, dim_in//2),
|
271 |
+
nn.SiLU(inplace=True),
|
272 |
+
pyg_nn.Linear(dim_in//2, 6))
|
273 |
+
|
274 |
+
def forward(self, batch):
|
275 |
+
pred = self.MLP(batch.x[batch.non_H_mask])
|
276 |
+
|
277 |
+
diag_elements = F.softplus(pred[:, :3])
|
278 |
+
|
279 |
+
i,j = torch.tensor([0,1,2,0,0,1]), torch.tensor([0,1,2,1,2,2])
|
280 |
+
L_matrix = torch.zeros(pred.size(0),3,3, device=pred.device, dtype=pred.dtype)
|
281 |
+
L_matrix[:,i[:3], i[:3]] = diag_elements
|
282 |
+
L_matrix[:,i[3:], j[3:]] = pred[:,3:]
|
283 |
+
|
284 |
+
U = torch.bmm(L_matrix.transpose(1, 2), L_matrix)
|
285 |
+
|
286 |
+
return U
|
287 |
+
|
288 |
+
|
289 |
+
|
models/master.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import streamlit as st
|
3 |
+
from models.cartnet import CartNet
|
4 |
+
|
5 |
+
# We cache the loading function to make is very fast on reload.
|
6 |
+
@st.cache_resource
|
7 |
+
def create_model():
|
8 |
+
model = CartNet(dim_in=256, dim_rbf=64, num_layers=4, radius=5.0, invariant=False, temperature=True, use_envelope=True, cholesky=True)
|
9 |
+
ckpt_path = "cpkt/cartnet_adp.ckpt"
|
10 |
+
load = torch.load(ckpt_path, map_location=torch.device('cpu'))["model_state"]
|
11 |
+
|
12 |
+
model.load_state_dict(load)
|
13 |
+
model.eval()
|
14 |
+
return model
|
models/utils.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
from torch import nn, Tensor
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from typing import Optional
|
7 |
+
|
8 |
+
# Implementation from TensorNet
|
9 |
+
# https://github.com/torchmd/torchmd-net
|
10 |
+
class ExpNormalSmearing(nn.Module):
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
cutoff_lower=0.0,
|
14 |
+
cutoff_upper=5.0,
|
15 |
+
num_rbf=50,
|
16 |
+
trainable=True,
|
17 |
+
dtype=torch.float32,
|
18 |
+
):
|
19 |
+
super(ExpNormalSmearing, self).__init__()
|
20 |
+
self.cutoff_lower = cutoff_lower
|
21 |
+
self.cutoff_upper = cutoff_upper
|
22 |
+
self.num_rbf = num_rbf
|
23 |
+
self.trainable = trainable
|
24 |
+
self.dtype = dtype
|
25 |
+
self.cutoff_fn = CosineCutoff(0, cutoff_upper)
|
26 |
+
self.alpha = 5.0 / (cutoff_upper - cutoff_lower)
|
27 |
+
|
28 |
+
means, betas = self._initial_params()
|
29 |
+
if trainable:
|
30 |
+
self.register_parameter("means", nn.Parameter(means))
|
31 |
+
self.register_parameter("betas", nn.Parameter(betas))
|
32 |
+
else:
|
33 |
+
self.register_buffer("means", means)
|
34 |
+
self.register_buffer("betas", betas)
|
35 |
+
|
36 |
+
def _initial_params(self):
|
37 |
+
# initialize means and betas according to the default values in PhysNet
|
38 |
+
# https://pubs.acs.org/doi/10.1021/acs.jctc.9b00181
|
39 |
+
start_value = torch.exp(
|
40 |
+
torch.scalar_tensor(
|
41 |
+
-self.cutoff_upper + self.cutoff_lower, dtype=self.dtype
|
42 |
+
)
|
43 |
+
)
|
44 |
+
means = torch.linspace(start_value, 1, self.num_rbf, dtype=self.dtype)
|
45 |
+
betas = torch.tensor(
|
46 |
+
[(2 / self.num_rbf * (1 - start_value)) ** -2] * self.num_rbf,
|
47 |
+
dtype=self.dtype,
|
48 |
+
)
|
49 |
+
return means, betas
|
50 |
+
|
51 |
+
def reset_parameters(self):
|
52 |
+
means, betas = self._initial_params()
|
53 |
+
self.means.data.copy_(means)
|
54 |
+
self.betas.data.copy_(betas)
|
55 |
+
|
56 |
+
def forward(self, dist):
|
57 |
+
dist = dist.unsqueeze(-1)
|
58 |
+
return self.cutoff_fn(dist) * torch.exp(
|
59 |
+
-self.betas
|
60 |
+
* (torch.exp(self.alpha * (-dist + self.cutoff_lower)) - self.means) ** 2
|
61 |
+
)
|
62 |
+
|
63 |
+
class CosineCutoff(nn.Module):
|
64 |
+
def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0):
|
65 |
+
super(CosineCutoff, self).__init__()
|
66 |
+
self.cutoff_lower = cutoff_lower
|
67 |
+
self.cutoff_upper = cutoff_upper
|
68 |
+
|
69 |
+
def forward(self, distances: Tensor) -> Tensor:
|
70 |
+
if self.cutoff_lower > 0:
|
71 |
+
cutoffs = 0.5 * (
|
72 |
+
torch.cos(
|
73 |
+
math.pi
|
74 |
+
* (
|
75 |
+
2
|
76 |
+
* (distances - self.cutoff_lower)
|
77 |
+
/ (self.cutoff_upper - self.cutoff_lower)
|
78 |
+
+ 1.0
|
79 |
+
)
|
80 |
+
)
|
81 |
+
+ 1.0
|
82 |
+
)
|
83 |
+
# remove contributions below the cutoff radius
|
84 |
+
cutoffs = cutoffs * (distances < self.cutoff_upper)
|
85 |
+
cutoffs = cutoffs * (distances > self.cutoff_lower)
|
86 |
+
return cutoffs
|
87 |
+
else:
|
88 |
+
cutoffs = 0.5 * (torch.cos(distances * math.pi / self.cutoff_upper) + 1.0)
|
89 |
+
# remove contributions beyond the cutoff radius
|
90 |
+
cutoffs = cutoffs * (distances < self.cutoff_upper)
|
91 |
+
return cutoffs
|
92 |
+
|
93 |
+
|
94 |
+
# Implementation from Comformer
|
95 |
+
# https://github.com/divelab/AIRS/tree/main/OpenMat/ComFormer
|
96 |
+
class RBFExpansion(nn.Module):
|
97 |
+
"""Expand interatomic distances with radial basis functions."""
|
98 |
+
|
99 |
+
def __init__(
|
100 |
+
self,
|
101 |
+
vmin: float = 0,
|
102 |
+
vmax: float = 8,
|
103 |
+
bins: int = 40,
|
104 |
+
lengthscale: Optional[float] = None,
|
105 |
+
):
|
106 |
+
"""Register torch parameters for RBF expansion."""
|
107 |
+
super().__init__()
|
108 |
+
self.vmin = vmin
|
109 |
+
self.vmax = vmax
|
110 |
+
self.bins = bins
|
111 |
+
self.register_buffer(
|
112 |
+
"centers", torch.linspace(self.vmin, self.vmax, self.bins)
|
113 |
+
)
|
114 |
+
|
115 |
+
if lengthscale is None:
|
116 |
+
# SchNet-style
|
117 |
+
# set lengthscales relative to granularity of RBF expansion
|
118 |
+
self.lengthscale = np.diff(self.centers).mean()
|
119 |
+
self.gamma = 1 / self.lengthscale
|
120 |
+
|
121 |
+
else:
|
122 |
+
self.lengthscale = lengthscale
|
123 |
+
self.gamma = 1 / (lengthscale ** 2)
|
124 |
+
|
125 |
+
def forward(self, distance: torch.Tensor) -> torch.Tensor:
|
126 |
+
"""Apply RBF expansion to interatomic distance tensor."""
|
127 |
+
return torch.exp(
|
128 |
+
-self.gamma * (distance.unsqueeze(1) - self.centers) ** 2
|
129 |
+
)
|
predict.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from ase.io import read
|
4 |
+
from CifFile import ReadCif
|
5 |
+
from torch_geometric.data import Data, Batch
|
6 |
+
import torch
|
7 |
+
from models.master import create_model
|
8 |
+
from process import process_data
|
9 |
+
from utils import radius_graph_pbc
|
10 |
+
import gc
|
11 |
+
|
12 |
+
MEAN_TEMP = torch.tensor(192.1785) #training temp mean
|
13 |
+
STD_TEMP = torch.tensor(81.2135) #training temp std
|
14 |
+
|
15 |
+
@torch.no_grad()
|
16 |
+
def process_cif(input_file, output_file):
|
17 |
+
model = create_model()
|
18 |
+
|
19 |
+
try:
|
20 |
+
# Read the CIF file using ASE
|
21 |
+
atoms = read(input_file, format="cif")
|
22 |
+
cif = ReadCif(input_file)
|
23 |
+
cif_data = cif.first_block()
|
24 |
+
if "_diffrn_ambient_temperature" in cif_data.keys():
|
25 |
+
temperature = float(cif_data["_diffrn_ambient_temperature"])
|
26 |
+
else:
|
27 |
+
raise ValueError("Temperature not found in the CIF file. \
|
28 |
+
Please provide a temperature in the field _diffrn_ambient_temperature from the CIF file.")
|
29 |
+
|
30 |
+
data = Data()
|
31 |
+
data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
|
32 |
+
|
33 |
+
if len(atoms.positions) > 300:
|
34 |
+
raise ValueError("This implementation is not optimized for large systems. For large systems, please use the local version.")
|
35 |
+
|
36 |
+
data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
|
37 |
+
data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
|
38 |
+
data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
|
39 |
+
data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
|
40 |
+
|
41 |
+
data.pbc = torch.tensor([True, True, True])
|
42 |
+
data.natoms = len(atoms)
|
43 |
+
|
44 |
+
del atoms
|
45 |
+
gc.collect()
|
46 |
+
batch = Batch.from_data_list([data])
|
47 |
+
|
48 |
+
edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
|
49 |
+
del batch
|
50 |
+
gc.collect()
|
51 |
+
data.cart_dist = torch.norm(edge_attr, dim=-1)
|
52 |
+
data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
|
53 |
+
data.edge_index = edge_index
|
54 |
+
data.non_H_mask = data.x != 1
|
55 |
+
delattr(data, "pbc")
|
56 |
+
delattr(data, "natoms")
|
57 |
+
batch = Batch.from_data_list([data])
|
58 |
+
del data, edge_index, edge_attr
|
59 |
+
gc.collect()
|
60 |
+
|
61 |
+
process_data(batch, model, output_file)
|
62 |
+
|
63 |
+
gc.collect()
|
64 |
+
except Exception as e:
|
65 |
+
print(f"An error occurred while processing the CIF file: {e}")
|
66 |
+
|
67 |
+
if __name__ == "__main__":
|
68 |
+
parser = argparse.ArgumentParser(description="Process a CIF file and output the result.")
|
69 |
+
parser.add_argument("input_file", type=str, help="Path to the input CIF file.")
|
70 |
+
parser.add_argument("output_file", type=str, help="Path to the output CIF file.")
|
71 |
+
args = parser.parse_args()
|
72 |
+
|
73 |
+
process_cif(args.input_file, args.output_file)
|
process.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from ase.io import write
|
3 |
+
from ase import Atoms
|
4 |
+
import gc
|
5 |
+
|
6 |
+
@torch.no_grad()
|
7 |
+
def process_data(batch, model, output_file="output.cif"):
|
8 |
+
atoms = batch.x.numpy().astype(int) # Atomic numbers
|
9 |
+
positions = batch.pos.numpy() # Atomic positions
|
10 |
+
cell = batch.cell.squeeze(0).numpy() # Cell parameters
|
11 |
+
temperature = batch.temperature_og.numpy()[0]
|
12 |
+
|
13 |
+
|
14 |
+
adps = model(batch)
|
15 |
+
|
16 |
+
# Convert Ucart to Ucif
|
17 |
+
M = batch.cell.squeeze(0)
|
18 |
+
N = torch.diag(torch.linalg.norm(torch.linalg.inv(M.transpose(-1,-2)).squeeze(0), dim=-1))
|
19 |
+
|
20 |
+
M = torch.linalg.inv(M)
|
21 |
+
N = torch.linalg.inv(N)
|
22 |
+
|
23 |
+
adps = M.transpose(-1,-2)@adps@M
|
24 |
+
adps = N.transpose(-1,-2)@adps@N
|
25 |
+
del M, N
|
26 |
+
gc.collect()
|
27 |
+
|
28 |
+
|
29 |
+
non_H_mask = batch.non_H_mask.numpy()
|
30 |
+
indices = torch.arange(len(atoms))[non_H_mask].numpy()
|
31 |
+
indices = {indices[i]: i for i in range(len(indices))}
|
32 |
+
# Create ASE Atoms object
|
33 |
+
ase_atoms = Atoms(numbers=atoms, positions=positions, cell=cell, pbc=True)
|
34 |
+
|
35 |
+
# Convert positions to fractional coordinates
|
36 |
+
fractional_positions = ase_atoms.get_scaled_positions()
|
37 |
+
|
38 |
+
# Write to CIF file
|
39 |
+
write(output_file, ase_atoms)
|
40 |
+
|
41 |
+
with open(output_file, 'r') as file:
|
42 |
+
lines = file.readlines()
|
43 |
+
|
44 |
+
# Find the line where "loop_" appears and remove lines from there to the end
|
45 |
+
for i, line in enumerate(lines):
|
46 |
+
if line.strip().startswith('loop_'):
|
47 |
+
lines = lines[:i]
|
48 |
+
break
|
49 |
+
|
50 |
+
# Write the modified lines to a new output file
|
51 |
+
with open(output_file, 'w') as file:
|
52 |
+
file.writelines(lines)
|
53 |
+
|
54 |
+
# Manually append positions and ADPs to the CIF file
|
55 |
+
with open(output_file, 'a') as cif_file:
|
56 |
+
|
57 |
+
# Write temperature
|
58 |
+
cif_file.write(f"\n_diffrn_ambient_temperature {temperature}\n")
|
59 |
+
# Write atomic positions
|
60 |
+
cif_file.write("\nloop_\n")
|
61 |
+
cif_file.write("_atom_site_label\n")
|
62 |
+
cif_file.write("_atom_site_type_symbol\n")
|
63 |
+
cif_file.write("_atom_site_fract_x\n")
|
64 |
+
cif_file.write("_atom_site_fract_y\n")
|
65 |
+
cif_file.write("_atom_site_fract_z\n")
|
66 |
+
cif_file.write("_atom_site_U_iso_or_equiv\n")
|
67 |
+
cif_file.write("_atom_site_thermal_displace_type\n")
|
68 |
+
|
69 |
+
element_count = {}
|
70 |
+
for i, (atom_number, frac_pos) in enumerate(zip(atoms, fractional_positions)):
|
71 |
+
element = ase_atoms[i].symbol
|
72 |
+
assert atom_number == ase_atoms[i].number
|
73 |
+
if element not in element_count:
|
74 |
+
element_count[element] = 0
|
75 |
+
element_count[element] += 1
|
76 |
+
label = f"{element}{element_count[element]}"
|
77 |
+
u_iso = torch.trace(adps[indices[i]]).mean() if element != 'H' else 0.01
|
78 |
+
type = "Uani" if element != 'H' else "Uiso"
|
79 |
+
cif_file.write(f"{label} {element} {frac_pos[0]} {frac_pos[1]} {frac_pos[2]} {u_iso} {type}\n")
|
80 |
+
|
81 |
+
# Write ADPs
|
82 |
+
cif_file.write("\nloop_\n")
|
83 |
+
cif_file.write("_atom_site_aniso_label\n")
|
84 |
+
cif_file.write("_atom_site_aniso_U_11\n")
|
85 |
+
cif_file.write("_atom_site_aniso_U_22\n")
|
86 |
+
cif_file.write("_atom_site_aniso_U_33\n")
|
87 |
+
cif_file.write("_atom_site_aniso_U_23\n")
|
88 |
+
cif_file.write("_atom_site_aniso_U_13\n")
|
89 |
+
cif_file.write("_atom_site_aniso_U_12\n")
|
90 |
+
|
91 |
+
element_count = {}
|
92 |
+
for i, atom_number in enumerate(atoms):
|
93 |
+
if atom_number == 1:
|
94 |
+
continue
|
95 |
+
element = ase_atoms[i].symbol
|
96 |
+
if element not in element_count:
|
97 |
+
element_count[element] = 0
|
98 |
+
element_count[element] += 1
|
99 |
+
label = f"{element}{element_count[element]}"
|
100 |
+
cif_file.write(f"{label} {adps[indices[i],0,0]} {adps[indices[i],1,1]} {adps[indices[i],2,2]} {adps[indices[i],1,2]} {adps[indices[i],0,2]} {adps[indices[i],0,1]}\n")
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy<2
|
2 |
+
torch==1.13.1
|
3 |
+
torch_geometric==2.5.2
|
4 |
+
torch-scatter==2.1.1
|
5 |
+
-f https://data.pyg.org/whl/torch-1.13.1+cpu.html
|
6 |
+
streamlit==1.40.1
|
7 |
+
ase==3.23.0
|
8 |
+
PyCifRW==4.4.6
|
utils.py
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
+
|
4 |
+
This source code is licensed under the MIT license found in the
|
5 |
+
LICENSE file in the root directory of this source tree.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from torch_scatter import segment_coo, segment_csr
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
def radius_graph_pbc(
|
16 |
+
data,
|
17 |
+
radius,
|
18 |
+
max_num_neighbors_threshold,
|
19 |
+
enforce_max_neighbors_strictly: bool = False,
|
20 |
+
pbc=[True, True, True],
|
21 |
+
):
|
22 |
+
device = data.pos.device
|
23 |
+
batch_size = len(data.natoms)
|
24 |
+
|
25 |
+
if hasattr(data, "pbc"):
|
26 |
+
data.pbc = torch.atleast_2d(data.pbc)
|
27 |
+
for i in range(3):
|
28 |
+
if not torch.any(data.pbc[:, i]).item():
|
29 |
+
pbc[i] = False
|
30 |
+
elif torch.all(data.pbc[:, i]).item():
|
31 |
+
pbc[i] = True
|
32 |
+
else:
|
33 |
+
raise RuntimeError(
|
34 |
+
"Different structures in the batch have different PBC configurations. This is not currently supported."
|
35 |
+
)
|
36 |
+
|
37 |
+
# position of the atoms
|
38 |
+
atom_pos = data.pos
|
39 |
+
|
40 |
+
# Before computing the pairwise distances between atoms, first create a list of atom indices to compare for the entire batch
|
41 |
+
num_atoms_per_image = data.natoms
|
42 |
+
num_atoms_per_image_sqr = (num_atoms_per_image**2).long()
|
43 |
+
|
44 |
+
# index offset between images
|
45 |
+
index_offset = (
|
46 |
+
torch.cumsum(num_atoms_per_image, dim=0) - num_atoms_per_image
|
47 |
+
)
|
48 |
+
|
49 |
+
index_offset_expand = torch.repeat_interleave(
|
50 |
+
index_offset, num_atoms_per_image_sqr
|
51 |
+
)
|
52 |
+
num_atoms_per_image_expand = torch.repeat_interleave(
|
53 |
+
num_atoms_per_image, num_atoms_per_image_sqr
|
54 |
+
)
|
55 |
+
|
56 |
+
# Compute a tensor containing sequences of numbers that range from 0 to num_atoms_per_image_sqr for each image
|
57 |
+
# that is used to compute indices for the pairs of atoms. This is a very convoluted way to implement
|
58 |
+
# the following (but 10x faster since it removes the for loop)
|
59 |
+
# for batch_idx in range(batch_size):
|
60 |
+
# batch_count = torch.cat([batch_count, torch.arange(num_atoms_per_image_sqr[batch_idx], device=device)], dim=0)
|
61 |
+
num_atom_pairs = torch.sum(num_atoms_per_image_sqr)
|
62 |
+
index_sqr_offset = (
|
63 |
+
torch.cumsum(num_atoms_per_image_sqr, dim=0) - num_atoms_per_image_sqr
|
64 |
+
)
|
65 |
+
index_sqr_offset = torch.repeat_interleave(
|
66 |
+
index_sqr_offset, num_atoms_per_image_sqr
|
67 |
+
)
|
68 |
+
atom_count_sqr = (
|
69 |
+
torch.arange(num_atom_pairs, device=device) - index_sqr_offset
|
70 |
+
)
|
71 |
+
|
72 |
+
# Compute the indices for the pairs of atoms (using division and mod)
|
73 |
+
# If the systems get too large this apporach could run into numerical precision issues
|
74 |
+
index1 = (
|
75 |
+
torch.div(
|
76 |
+
atom_count_sqr, num_atoms_per_image_expand, rounding_mode="floor"
|
77 |
+
)
|
78 |
+
) + index_offset_expand
|
79 |
+
index2 = (
|
80 |
+
atom_count_sqr % num_atoms_per_image_expand
|
81 |
+
) + index_offset_expand
|
82 |
+
# Get the positions for each atom
|
83 |
+
pos1 = torch.index_select(atom_pos, 0, index1)
|
84 |
+
pos2 = torch.index_select(atom_pos, 0, index2)
|
85 |
+
|
86 |
+
# Calculate required number of unit cells in each direction.
|
87 |
+
# Smallest distance between planes separated by a1 is
|
88 |
+
# 1 / ||(a2 x a3) / V||_2, since a2 x a3 is the area of the plane.
|
89 |
+
# Note that the unit cell volume V = a1 * (a2 x a3) and that
|
90 |
+
# (a2 x a3) / V is also the reciprocal primitive vector
|
91 |
+
# (crystallographer's definition).
|
92 |
+
|
93 |
+
cross_a2a3 = torch.cross(data.cell[:, 1], data.cell[:, 2], dim=-1)
|
94 |
+
cell_vol = torch.sum(data.cell[:, 0] * cross_a2a3, dim=-1, keepdim=True)
|
95 |
+
|
96 |
+
if pbc[0]:
|
97 |
+
inv_min_dist_a1 = torch.norm(cross_a2a3 / cell_vol, p=2, dim=-1)
|
98 |
+
rep_a1 = torch.ceil(radius * inv_min_dist_a1)
|
99 |
+
else:
|
100 |
+
rep_a1 = data.cell.new_zeros(1)
|
101 |
+
|
102 |
+
if pbc[1]:
|
103 |
+
cross_a3a1 = torch.cross(data.cell[:, 2], data.cell[:, 0], dim=-1)
|
104 |
+
inv_min_dist_a2 = torch.norm(cross_a3a1 / cell_vol, p=2, dim=-1)
|
105 |
+
rep_a2 = torch.ceil(radius * inv_min_dist_a2)
|
106 |
+
else:
|
107 |
+
rep_a2 = data.cell.new_zeros(1)
|
108 |
+
|
109 |
+
if pbc[2]:
|
110 |
+
cross_a1a2 = torch.cross(data.cell[:, 0], data.cell[:, 1], dim=-1)
|
111 |
+
inv_min_dist_a3 = torch.norm(cross_a1a2 / cell_vol, p=2, dim=-1)
|
112 |
+
rep_a3 = torch.ceil(radius * inv_min_dist_a3)
|
113 |
+
else:
|
114 |
+
rep_a3 = data.cell.new_zeros(1)
|
115 |
+
|
116 |
+
# Take the max over all images for uniformity. This is essentially padding.
|
117 |
+
# Note that this can significantly increase the number of computed distances
|
118 |
+
# if the required repetitions are very different between images
|
119 |
+
# (which they usually are). Changing this to sparse (scatter) operations
|
120 |
+
# might be worth the effort if this function becomes a bottleneck.
|
121 |
+
max_rep = [rep_a1.max(), rep_a2.max(), rep_a3.max()]
|
122 |
+
|
123 |
+
# Tensor of unit cells
|
124 |
+
cells_per_dim = [
|
125 |
+
torch.arange(-rep, rep + 1, device=device, dtype=torch.float)
|
126 |
+
for rep in max_rep
|
127 |
+
]
|
128 |
+
unit_cell = torch.cartesian_prod(*cells_per_dim)
|
129 |
+
num_cells = len(unit_cell)
|
130 |
+
unit_cell_per_atom = unit_cell.view(1, num_cells, 3).repeat(
|
131 |
+
len(index2), 1, 1
|
132 |
+
)
|
133 |
+
unit_cell = torch.transpose(unit_cell, 0, 1)
|
134 |
+
unit_cell_batch = unit_cell.view(1, 3, num_cells).expand(
|
135 |
+
batch_size, -1, -1
|
136 |
+
)
|
137 |
+
|
138 |
+
# Compute the x, y, z positional offsets for each cell in each image
|
139 |
+
data_cell = torch.transpose(data.cell, 1, 2)
|
140 |
+
pbc_offsets = torch.bmm(data_cell, unit_cell_batch)
|
141 |
+
pbc_offsets_per_atom = torch.repeat_interleave(
|
142 |
+
pbc_offsets, num_atoms_per_image_sqr, dim=0
|
143 |
+
)
|
144 |
+
|
145 |
+
# Expand the positions and indices for the 9 cells
|
146 |
+
pos1 = pos1.view(-1, 3, 1).expand(-1, -1, num_cells)
|
147 |
+
pos2 = pos2.view(-1, 3, 1).expand(-1, -1, num_cells)
|
148 |
+
index1 = index1.view(-1, 1).repeat(1, num_cells).view(-1)
|
149 |
+
index2 = index2.view(-1, 1).repeat(1, num_cells).view(-1)
|
150 |
+
# Add the PBC offsets for the second atom
|
151 |
+
pos2 = pos2 + pbc_offsets_per_atom
|
152 |
+
|
153 |
+
# Compute the squared distance between atoms
|
154 |
+
direction = pos1 - pos2
|
155 |
+
atom_distance_sqr = torch.sum((direction) ** 2, dim=1)
|
156 |
+
direction = direction.permute(0, 2, 1).reshape(-1, 3)
|
157 |
+
atom_distance_sqr = atom_distance_sqr.view(-1)
|
158 |
+
|
159 |
+
# Remove pairs that are too far apart
|
160 |
+
mask_within_radius = torch.le(atom_distance_sqr, radius * radius)
|
161 |
+
# Remove pairs with the same atoms (distance = 0.0)
|
162 |
+
mask_not_same = torch.gt(atom_distance_sqr, 0.0001)
|
163 |
+
mask = torch.logical_and(mask_within_radius, mask_not_same)
|
164 |
+
index1 = torch.masked_select(index1, mask)
|
165 |
+
index2 = torch.masked_select(index2, mask)
|
166 |
+
unit_cell = torch.masked_select(
|
167 |
+
unit_cell_per_atom.view(-1, 3), mask.view(-1, 1).expand(-1, 3)
|
168 |
+
)
|
169 |
+
unit_cell = unit_cell.view(-1, 3)
|
170 |
+
atom_distance_sqr = torch.masked_select(atom_distance_sqr, mask)
|
171 |
+
direction = torch.masked_select(direction, mask.view(-1, 1).expand(-1, 3)).view(-1, 3)
|
172 |
+
|
173 |
+
if max_num_neighbors_threshold is not None:
|
174 |
+
mask_num_neighbors, num_neighbors_image = get_max_neighbors_mask(
|
175 |
+
natoms=data.natoms,
|
176 |
+
index=index1,
|
177 |
+
atom_distance=atom_distance_sqr,
|
178 |
+
max_num_neighbors_threshold=max_num_neighbors_threshold,
|
179 |
+
enforce_max_strictly=enforce_max_neighbors_strictly,
|
180 |
+
)
|
181 |
+
|
182 |
+
if not torch.all(mask_num_neighbors):
|
183 |
+
# Mask out the atoms to ensure each atom has at most max_num_neighbors_threshold neighbors
|
184 |
+
index1 = torch.masked_select(index1, mask_num_neighbors)
|
185 |
+
index2 = torch.masked_select(index2, mask_num_neighbors)
|
186 |
+
atom_distance_sqr = torch.masked_select(atom_distance_sqr, mask_num_neighbors)
|
187 |
+
direction = torch.masked_select(direction, mask_num_neighbors.view(-1, 1).expand(-1, 3)).view(-1, 3)
|
188 |
+
unit_cell = torch.masked_select(
|
189 |
+
unit_cell.view(-1, 3), mask_num_neighbors.view(-1, 1).expand(-1, 3)
|
190 |
+
)
|
191 |
+
unit_cell = unit_cell.view(-1, 3)
|
192 |
+
|
193 |
+
edge_index = torch.stack((index2, index1))
|
194 |
+
|
195 |
+
return edge_index, unit_cell, torch.sqrt(atom_distance_sqr), direction
|
196 |
+
|
197 |
+
|
198 |
+
def get_max_neighbors_mask(
|
199 |
+
natoms,
|
200 |
+
index,
|
201 |
+
atom_distance,
|
202 |
+
max_num_neighbors_threshold,
|
203 |
+
degeneracy_tolerance: float = 0.01,
|
204 |
+
enforce_max_strictly: bool = False,
|
205 |
+
):
|
206 |
+
"""
|
207 |
+
Give a mask that filters out edges so that each atom has at most
|
208 |
+
`max_num_neighbors_threshold` neighbors.
|
209 |
+
Assumes that `index` is sorted.
|
210 |
+
|
211 |
+
Enforcing the max strictly can force the arbitrary choice between
|
212 |
+
degenerate edges. This can lead to undesired behaviors; for
|
213 |
+
example, bulk formation energies which are not invariant to
|
214 |
+
unit cell choice.
|
215 |
+
|
216 |
+
A degeneracy tolerance can help prevent sudden changes in edge
|
217 |
+
existence from small changes in atom position, for example,
|
218 |
+
rounding errors, slab relaxation, temperature, etc.
|
219 |
+
"""
|
220 |
+
|
221 |
+
device = natoms.device
|
222 |
+
num_atoms = natoms.sum()
|
223 |
+
|
224 |
+
# Get number of neighbors
|
225 |
+
# segment_coo assumes sorted index
|
226 |
+
ones = index.new_ones(1).expand_as(index)
|
227 |
+
num_neighbors = segment_coo(ones, index, dim_size=num_atoms)
|
228 |
+
max_num_neighbors = num_neighbors.max()
|
229 |
+
num_neighbors_thresholded = num_neighbors.clamp(
|
230 |
+
max=max_num_neighbors_threshold
|
231 |
+
)
|
232 |
+
|
233 |
+
# Get number of (thresholded) neighbors per image
|
234 |
+
image_indptr = torch.zeros(
|
235 |
+
natoms.shape[0] + 1, device=device, dtype=torch.long
|
236 |
+
)
|
237 |
+
image_indptr[1:] = torch.cumsum(natoms, dim=0)
|
238 |
+
num_neighbors_image = segment_csr(num_neighbors_thresholded, image_indptr)
|
239 |
+
|
240 |
+
# If max_num_neighbors is below the threshold, return early
|
241 |
+
if (
|
242 |
+
max_num_neighbors <= max_num_neighbors_threshold
|
243 |
+
or max_num_neighbors_threshold <= 0
|
244 |
+
):
|
245 |
+
mask_num_neighbors = torch.tensor(
|
246 |
+
[True], dtype=bool, device=device
|
247 |
+
).expand_as(index)
|
248 |
+
return mask_num_neighbors, num_neighbors_image
|
249 |
+
|
250 |
+
# Create a tensor of size [num_atoms, max_num_neighbors] to sort the distances of the neighbors.
|
251 |
+
# Fill with infinity so we can easily remove unused distances later.
|
252 |
+
distance_sort = torch.full(
|
253 |
+
[num_atoms * max_num_neighbors], np.inf, device=device
|
254 |
+
)
|
255 |
+
|
256 |
+
# Create an index map to map distances from atom_distance to distance_sort
|
257 |
+
# index_sort_map assumes index to be sorted
|
258 |
+
index_neighbor_offset = torch.cumsum(num_neighbors, dim=0) - num_neighbors
|
259 |
+
index_neighbor_offset_expand = torch.repeat_interleave(
|
260 |
+
index_neighbor_offset, num_neighbors
|
261 |
+
)
|
262 |
+
index_sort_map = (
|
263 |
+
index * max_num_neighbors
|
264 |
+
+ torch.arange(len(index), device=device)
|
265 |
+
- index_neighbor_offset_expand
|
266 |
+
)
|
267 |
+
print(index_sort_map.dtype, atom_distance.dtype)
|
268 |
+
distance_sort.index_copy_(0, index_sort_map, atom_distance)
|
269 |
+
distance_sort = distance_sort.view(num_atoms, max_num_neighbors)
|
270 |
+
|
271 |
+
# Sort neighboring atoms based on distance
|
272 |
+
distance_sort, index_sort = torch.sort(distance_sort, dim=1)
|
273 |
+
|
274 |
+
# Select the max_num_neighbors_threshold neighbors that are closest
|
275 |
+
if enforce_max_strictly:
|
276 |
+
distance_sort = distance_sort[:, :max_num_neighbors_threshold]
|
277 |
+
index_sort = index_sort[:, :max_num_neighbors_threshold]
|
278 |
+
max_num_included = max_num_neighbors_threshold
|
279 |
+
|
280 |
+
else:
|
281 |
+
effective_cutoff = (
|
282 |
+
distance_sort[:, max_num_neighbors_threshold]
|
283 |
+
+ degeneracy_tolerance
|
284 |
+
)
|
285 |
+
is_included = torch.le(distance_sort.T, effective_cutoff)
|
286 |
+
|
287 |
+
# Set all undesired edges to infinite length to be removed later
|
288 |
+
distance_sort[~is_included.T] = np.inf
|
289 |
+
|
290 |
+
# Subselect tensors for efficiency
|
291 |
+
num_included_per_atom = torch.sum(is_included, dim=0)
|
292 |
+
max_num_included = torch.max(num_included_per_atom)
|
293 |
+
distance_sort = distance_sort[:, :max_num_included]
|
294 |
+
index_sort = index_sort[:, :max_num_included]
|
295 |
+
|
296 |
+
# Recompute the number of neighbors
|
297 |
+
num_neighbors_thresholded = num_neighbors.clamp(
|
298 |
+
max=num_included_per_atom
|
299 |
+
)
|
300 |
+
|
301 |
+
num_neighbors_image = segment_csr(
|
302 |
+
num_neighbors_thresholded, image_indptr
|
303 |
+
)
|
304 |
+
|
305 |
+
# Offset index_sort so that it indexes into index
|
306 |
+
index_sort = index_sort + index_neighbor_offset.view(-1, 1).expand(
|
307 |
+
-1, max_num_included
|
308 |
+
)
|
309 |
+
# Remove "unused pairs" with infinite distances
|
310 |
+
mask_finite = torch.isfinite(distance_sort)
|
311 |
+
index_sort = torch.masked_select(index_sort, mask_finite)
|
312 |
+
|
313 |
+
# At this point index_sort contains the index into index of the
|
314 |
+
# closest max_num_neighbors_threshold neighbors per atom
|
315 |
+
# Create a mask to remove all pairs not in index_sort
|
316 |
+
mask_num_neighbors = torch.zeros(len(index), device=device, dtype=bool)
|
317 |
+
mask_num_neighbors.index_fill_(0, index_sort, True)
|
318 |
+
|
319 |
+
return mask_num_neighbors, num_neighbors_image
|
320 |
+
|
321 |
+
|
322 |
+
|
323 |
+
|