Àlex Solé committed on
Commit 744c6a1 · 1 Parent(s): 3fa8a08

merged from streamlit

Files changed (15)
  1. .gitattributes +1 -0
  2. .gitignore +1 -0
  3. LICENSE +21 -0
  4. README.md +57 -13
  5. cpkt/cartnet_adp.ckpt +3 -0
  6. fig/pipeline.png +3 -0
  7. main.py +127 -0
  8. main_local.py +119 -0
  9. models/cartnet.py +289 -0
  10. models/master.py +14 -0
  11. models/utils.py +129 -0
  12. predict.py +73 -0
  13. process.py +100 -0
  14. requirements.txt +8 -0
  15. utils.py +323 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+ .DS_Store
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Àlex Solé Gómez
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,13 +1,57 @@
- ---
- title: Cartnet Demo
- emoji: 📈
- colorFrom: purple
- colorTo: red
- sdk: streamlit
- sdk_version: 1.40.1
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # CartNet Streamlit Web App
+
+ ![Pipeline](fig/pipeline.png)
+
+ ### CartNet online demo available at: [CartNet Web App](https://cartnet-adp-estimation.streamlit.app)
+
+
+ CartNet is a graph neural network specifically designed to predict Anisotropic Displacement Parameters (ADPs) in crystal structures. The model has been trained on over 220,000 molecular crystal structures from the Cambridge Structural Database (CSD), making it highly accurate and robust for ADP prediction tasks. CartNet addresses the computational challenges of traditional methods by encoding the full 3D geometry of atomic structures into a Cartesian reference frame, bypassing the need for unit cell encoding. The model incorporates innovative features, including a neighbour equalization technique to enhance interaction detection and a Cholesky-based output layer to ensure valid ADP predictions. Additionally, it introduces a rotational SO(3) data augmentation technique to improve generalization across different crystal structure orientations, making the model highly efficient and accurate in predicting ADPs while significantly reducing computational costs.
+
+ This repository contains a web application based on the official implementation of CartNet, which can be found at [imatge-upc/CartNet](https://github.com/imatge-upc/CartNet).
+
+ ⚠️ **Warning**: The online web application can only process systems with less than 300 atoms in the unit cell. For large systems, please use the local application.
+
+ ## Local Application
+ ### Installation of the local application
+
+ To set up the local application, you need to install the dependencies listed in `requirements.txt`. You can do this by running the following command:
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ### Usage
+
+ You can make predictions directly from Python using the `predict.py` script.
+
+ The script takes two arguments:
+ 1. `input_file`: Path to the input CIF file
+ 2. `output_file`: Path where you want to save the processed CIF file
+
+ Example usage:
+
+ ```bash
+ python predict.py input.cif output.cif
+ ```
+
+ Or, if you prefer, you can use the browser app on your local machine without the atom number limitation by running:
+
+ ```bash
+ streamlit run main_local.py
+ ```
+
+ ## How to cite
+
+ If you use CartNet in your research, please cite our paper:
+
+ ```bibtex
+ @article{your_paper_citation,
+ title={Title of the Paper},
+ author={Author1 and Author2 and Author3},
+ journal={Journal Name},
+ year={2023},
+ volume={XX},
+ number={YY},
+ pages={ZZZ}
+ }
+ ```
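
The README mentions a rotational SO(3) data augmentation. As a hedged illustration of what rotating a structure implies for its ADP labels (this is not the training code from the official repository), the sketch below applies the standard tensor transformation U' = R U Rᵀ to an ADP matrix expressed in a Cartesian frame. The helper names `rotation_z`, `matmul3`, `transpose3` and `rotate_adp` are hypothetical, and only the standard library is used.

```python
import math


def rotation_z(angle_rad):
    """Rotation matrix for a rotation of angle_rad about the z axis."""
    c, s = math.cos(angle_rad), math.sin(angle_rad)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


def matmul3(a, b):
    """Product of two 3x3 matrices stored as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]


def transpose3(a):
    """Transpose of a 3x3 matrix."""
    return [[a[j][i] for j in range(3)] for i in range(3)]


def rotate_adp(u, r):
    """Rotate a Cartesian ADP tensor: U' = R U R^T."""
    return matmul3(matmul3(r, u), transpose3(r))


# Diagonal ADP tensor (values are made up for the example).
u = [[0.02, 0.0, 0.0],
     [0.0, 0.05, 0.0],
     [0.0, 0.0, 0.03]]

# A 90 degree rotation about z swaps the x and y principal components.
u_rot = rotate_adp(u, rotation_z(math.pi / 2))
```

The trace and the symmetry of the tensor are preserved, so a rotated copy of a structure stays a physically consistent training sample as long as its labels are rotated with it.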
cpkt/cartnet_adp.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0829f2e8a631b380c04e2beb5c75b36767e5635321deea17fc5e7bafee643332
+ size 30066357
fig/pipeline.png ADDED

Git LFS Details

  • SHA256: 2c02f5df0375788cb98c5f80872cb0859b48a86c23927757c531e2cc21de1d96
  • Pointer size: 132 Bytes
  • Size of remote file: 1.83 MB
main.py ADDED
@@ -0,0 +1,127 @@
+ import streamlit as st
+ import os
+ from ase.io import read
+ from CifFile import ReadCif
+ from torch_geometric.data import Data, Batch
+ import torch
+ from models.master import create_model
+ from process import process_data
+ from utils import radius_graph_pbc
+ import gc
+
+ MEAN_TEMP = torch.tensor(192.1785) #training temp mean
+ STD_TEMP = torch.tensor(81.2135) #training temp std
+
+
+ @torch.no_grad()
+ def main():
+     model = create_model()
+     st.title("CartNet ADP Prediction")
+     st.image('fig/pipeline.png')
+
+     st.markdown("""
+     CartNet is a graph neural network specifically designed for predicting Anisotropic Displacement Parameters (ADPs) in crystal structures. The model has been trained on over 220,000 molecular crystal structures from the Cambridge Structural Database (CSD), making it highly accurate and robust for ADP prediction tasks. CartNet addresses the computational challenges of traditional methods by encoding the full 3D geometry of atomic structures into a Cartesian reference frame, bypassing the need for unit cell encoding. The model incorporates innovative features, including a neighbour equalization technique to enhance interaction detection and a Cholesky-based output layer to ensure valid ADP predictions. Additionally, it introduces a rotational SO(3) data augmentation technique to improve generalization across different crystal structure orientations, making the model highly efficient and accurate in predicting ADPs while significantly reducing computational costs.
+     """)
+
+     uploaded_file = st.file_uploader("Upload a CIF file", type=["cif"], accept_multiple_files=False)
+     # uploaded_file = "ABABEM.cif"
+     if uploaded_file is not None:
+         try:
+             with open(uploaded_file.name, "wb") as f:
+                 f.write(uploaded_file.getbuffer())
+             filename = str(uploaded_file.name)
+             # Read the CIF file using ASE
+             atoms = read(filename, format="cif")
+             cif = ReadCif(filename)
+             cif_data = cif.first_block()
+             if "_diffrn_ambient_temperature" in cif_data.keys():
+                 temperature = float(cif_data["_diffrn_ambient_temperature"])
+             else:
+                 raise ValueError("Temperature not found in the CIF file. \
+                 Please provide a temperature in the field _diffrn_ambient_temperature from the CIF file.")
+             st.success("CIF file successfully read.")
+
+             data = Data()
+             data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
+
+             if len(atoms.positions) > 300:
+                 st.markdown("""
+                 ⚠️ **Warning**: The structure is too large. Please upload a smaller one or use the [local implementation of CartNet Web App](https://github.com/alexsoleg/cartnet-streamlit/).
+                 """)
+                 raise ValueError("Please provide a structure with less than 300 atoms in the unit cell.")
+
+             data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
+             data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
+             data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
+             data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
+
+             data.pbc = torch.tensor([True, True, True])
+             data.natoms = len(atoms)
+
+             del atoms
+             gc.collect()
+             batch = Batch.from_data_list([data])
+
+
+             edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
+             del batch
+             gc.collect()
+             data.cart_dist = torch.norm(edge_attr, dim=-1)
+             data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
+             data.edge_index = edge_index
+             data.non_H_mask = data.x != 1
+             delattr(data, "pbc")
+             delattr(data, "natoms")
+             batch = Batch.from_data_list([data])
+             del data, edge_index, edge_attr
+             gc.collect()
+
+             st.success("Graph successfully created.")
+
+             process_data(batch, model)
+             st.success("ADPs successfully predicted.")
+
+             # Create a download button for the processed CIF file
+             with open("output.cif", "r") as f:
+                 cif_contents = f.read()
+
+             st.download_button(
+                 label="Download processed CIF file",
+                 data=cif_contents,
+                 file_name="output.cif",
+                 mime="text/plain"
+             )
+
+             os.remove("output.cif")
+             os.remove(filename)
+             gc.collect()
+         except Exception as e:
+             st.error(f"An error occurred while reading the CIF file: {e}")
+     st.markdown("""
+     ⚠️ **Warning**: This online web application is designed for structures with up to 300 atoms in the unit cell. For larger structures, please use the [local implementation of CartNet Web App](https://github.com/alexsoleg/cartnet-streamlit/).
+     """)
+
+     st.markdown("""
+     📌 The official implementation of the paper with all experiments can be found at [CartNet GitHub Repository](https://github.com/imatge-upc/CartNet).
+     """)
+
+     st.markdown("""
+     ### How to cite
+
+     If you use CartNet in your research, please cite our paper:
+
+     ```bibtex
+     @article{your_paper_citation,
+     title={Title of the Paper},
+     author={Author1 and Author2 and Author3},
+     journal={Journal Name},
+     year={2023},
+     volume={XX},
+     number={YY},
+     pages={ZZZ}
+     }
+     ```
+     """)
+
+ if __name__ == "__main__":
+     main()
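
In `main.py` the temperature is read with `float(cif_data["_diffrn_ambient_temperature"])` and then standardized with `MEAN_TEMP` and `STD_TEMP`. CIF numeric values may carry a standard uncertainty in parentheses (for example `100(2)`), which a bare `float()` call rejects. The following is a minimal sketch of a more tolerant reader, assuming that dropping the uncertainty is acceptable; `parse_cif_temperature` and `standardize_temperature` are hypothetical helpers that are not part of this commit, and plain floats stand in for the tensors used in the app.

```python
import re

MEAN_TEMP = 192.1785  # training temperature mean, value taken from main.py
STD_TEMP = 81.2135    # training temperature std, value taken from main.py


def parse_cif_temperature(raw):
    """Read the leading number of a CIF value such as '100(2)' or '293.15'.

    Raises ValueError for placeholders like '?' or '.' that carry no number.
    """
    match = re.match(r"\s*([-+]?\d+(?:\.\d+)?)", str(raw))
    if match is None:
        raise ValueError(f"Cannot read a temperature from {raw!r}")
    return float(match.group(1))


def standardize_temperature(kelvin):
    """Apply the same z-score normalization as main.py, on a plain float."""
    return (kelvin - MEAN_TEMP) / STD_TEMP


temperature = parse_cif_temperature("100(2)")
scaled = standardize_temperature(temperature)
```

Raising `ValueError` for unreadable values keeps the behaviour compatible with the existing `except Exception` branch, which reports the message through `st.error`.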
main_local.py ADDED
@@ -0,0 +1,119 @@
+ import streamlit as st
+ import os
+ from ase.io import read
+ from CifFile import ReadCif
+ from torch_geometric.data import Data, Batch
+ import torch
+ from models.master import create_model
+ from process import process_data
+ from utils import radius_graph_pbc
+ import gc
+
+ MEAN_TEMP = torch.tensor(192.1785) #training temp mean
+ STD_TEMP = torch.tensor(81.2135) #training temp std
+
+
+ @torch.no_grad()
+ def main():
+     model = create_model()
+     st.title("CartNet ADP Prediction")
+     st.image('fig/pipeline.png')
+
+     st.markdown("""
+     CartNet is a graph neural network specifically designed for predicting Anisotropic Displacement Parameters (ADPs) in crystal structures. The model has been trained on over 220,000 molecular crystal structures from the Cambridge Structural Database (CSD), making it highly accurate and robust for ADP prediction tasks. CartNet addresses the computational challenges of traditional methods by encoding the full 3D geometry of atomic structures into a Cartesian reference frame, bypassing the need for unit cell encoding. The model incorporates innovative features, including a neighbour equalization technique to enhance interaction detection and a Cholesky-based output layer to ensure valid ADP predictions. Additionally, it introduces a rotational SO(3) data augmentation technique to improve generalization across different crystal structure orientations, making the model highly efficient and accurate in predicting ADPs while significantly reducing computational costs.
+     """)
+
+     uploaded_file = st.file_uploader("Upload a CIF file", type=["cif"], accept_multiple_files=False)
+     # uploaded_file = "ABABEM.cif"
+     if uploaded_file is not None:
+         try:
+             with open(uploaded_file.name, "wb") as f:
+                 f.write(uploaded_file.getbuffer())
+             filename = str(uploaded_file.name)
+             # Read the CIF file using ASE
+             atoms = read(filename, format="cif")
+             cif = ReadCif(filename)
+             cif_data = cif.first_block()
+             if "_diffrn_ambient_temperature" in cif_data.keys():
+                 temperature = float(cif_data["_diffrn_ambient_temperature"])
+             else:
+                 raise ValueError("Temperature not found in the CIF file. \
+                 Please provide a temperature in the field _diffrn_ambient_temperature from the CIF file.")
+             st.success("CIF file successfully read.")
+
+             data = Data()
+             data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
+
+             data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
+             data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
+             data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
+             data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
+
+             data.pbc = torch.tensor([True, True, True])
+             data.natoms = len(atoms)
+
+             del atoms
+             gc.collect()
+             batch = Batch.from_data_list([data])
+
+
+             edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
+             del batch
+             gc.collect()
+             data.cart_dist = torch.norm(edge_attr, dim=-1)
+             data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
+             data.edge_index = edge_index
+             data.non_H_mask = data.x != 1
+             delattr(data, "pbc")
+             delattr(data, "natoms")
+             batch = Batch.from_data_list([data])
+             del data, edge_index, edge_attr
+             gc.collect()
+
+             st.success("Graph successfully created.")
+
+             process_data(batch, model)
+             st.success("ADPs successfully predicted.")
+
+             # Create a download button for the processed CIF file
+             with open("output.cif", "r") as f:
+                 cif_contents = f.read()
+
+             st.download_button(
+                 label="Download processed CIF file",
+                 data=cif_contents,
+                 file_name="output.cif",
+                 mime="text/plain"
+             )
+
+             os.remove("output.cif")
+             os.remove(filename)
+             gc.collect()
+         except Exception as e:
+             st.error(f"An error occurred while reading the CIF file: {e}")
+
+
+     st.markdown("""
+     📌 The official implementation of the paper with all experiments can be found at [CartNet GitHub Repository](https://github.com/imatge-upc/CartNet).
+     """)
+
+     st.markdown("""
+     ### How to cite
+
+     If you use CartNet in your research, please cite our paper:
+
+     ```bibtex
+     @article{your_paper_citation,
+     title={Title of the Paper},
+     author={Author1 and Author2 and Author3},
+     journal={Journal Name},
+     year={2023},
+     volume={XX},
+     number={YY},
+     pages={ZZZ}
+     }
+     ```
+     """)
+
+ if __name__ == "__main__":
+     main()
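
Both Streamlit entry points write the upload into the working directory under its original name and call `os.remove` only at the end of the `try` block, so a failure partway through leaves the file behind. One possible alternative is sketched below, under the assumption that a uniquely named temporary file is acceptable to the CIF readers; `temporary_cif` is a hypothetical helper and is not what this commit does.

```python
import os
import tempfile
from contextlib import contextmanager


@contextmanager
def temporary_cif(payload: bytes):
    """Write payload to a uniquely named .cif file and delete it on exit."""
    handle = tempfile.NamedTemporaryFile(suffix=".cif", delete=False)
    try:
        handle.write(payload)
        handle.close()  # close before handing the path to other readers
        yield handle.name
    finally:
        if not handle.closed:
            handle.close()
        # Runs on success and on exceptions raised inside the with block.
        if os.path.exists(handle.name):
            os.remove(handle.name)


# Usage with placeholder bytes; in the app this would be the uploaded buffer.
with temporary_cif(b"data_example\n") as path:
    existed_inside = os.path.exists(path)
exists_after = os.path.exists(path)
```

Because the cleanup lives in a `finally` clause, the temporary file is removed whether or not graph construction or prediction raises, and two simultaneous uploads with the same name no longer overwrite each other.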
models/cartnet.py ADDED
@@ -0,0 +1,289 @@
+ # Copyright Universitat Politècnica de Catalunya 2024 https://imatge.upc.edu
+ # Distributed under the MIT License.
+ # (See accompanying README.md file or copy at http://opensource.org/licenses/MIT)
+
+ import torch
+ import torch_geometric.nn as pyg_nn
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch_scatter import scatter
+ from models.utils import ExpNormalSmearing, CosineCutoff
+
+
+ class CartNet(torch.nn.Module):
+     """
+     CartNet model from Cartesian Encoding Graph Neural Network for Crystal Structures Property Prediction: Application to Thermal Ellipsoid Estimation.
+     Args:
+         dim_in (int): Dimensionality of the input features.
+         dim_rbf (int): Dimensionality of the radial basis function embeddings.
+         num_layers (int): Number of CartNet layers in the model.
+         radius (float, optional): Radius cutoff for neighbor interactions. Default is 5.0.
+         invariant (bool, optional): If `True`, enforces rotational invariance in the encoder. Default is `False`.
+         temperature (bool, optional): If `True`, includes temperature information in the encoder. Default is `True`.
+         use_envelope (bool, optional): If `True`, applies an envelope function to the interactions. Default is `True`.
+         cholesky (bool, optional): If `True`, uses a Cholesky head for the output. If `False`, uses a scalar head. Default is `True`.
+     Methods:
+         forward(batch):
+             Performs a forward pass of the model.
+             Args:
+                 batch: A batch of input data.
+             Returns:
+                 pred: The model's predictions.
+                 true: The ground truth values corresponding to the input batch.
+     """
+
+
+     def __init__(self,
+         dim_in: int,
+         dim_rbf: int,
+         num_layers: int,
+         radius: float = 5.0,
+         invariant: bool = False,
+         temperature: bool = True,
+         use_envelope: bool = True,
+         cholesky: bool = True):
+         super().__init__()
+
+         self.encoder = Encoder(dim_in, dim_rbf=dim_rbf, radius=radius, invariant=invariant, temperature=temperature)
+         self.dim_in = dim_in
+
+         layers = []
+         for _ in range(num_layers):
+             layers.append(CartNet_layer(
+                 dim_in=dim_in,
+                 use_envelope=use_envelope,
+             ))
+         self.layers = torch.nn.Sequential(*layers)
+
+         if cholesky:
+             self.head = Cholesky_head(dim_in)
+         else:
+             self.head = Scalar_head(dim_in)
+
+     def forward(self, batch):
+         batch = self.encoder(batch)
+
+         for layer in self.layers:
+             batch = layer(batch)
+
+         pred = self.head(batch)
+
+         return pred
+
+ class Encoder(torch.nn.Module):
+     """
+     Encoder module for the CartNet model.
+     This module encodes node and edge features for input into the CartNet model, incorporating optional temperature information and rotational invariance.
+     Args:
+         dim_in (int): Dimension of the input features after embedding.
+         dim_rbf (int): Dimension of the radial basis function used for edge attributes.
+         radius (float, optional): Cutoff radius for neighbor interactions. Defaults to 5.0.
+         invariant (bool, optional): If True, the encoder enforces rotational invariance by excluding directional information from edge attributes. Defaults to False.
+         temperature (bool, optional): If True, includes temperature data in the node embeddings. Defaults to True.
+     Attributes:
+         dim_in (int): Dimension of the input features.
+         invariant (bool): Indicates if rotational invariance is enforced.
+         temperature (bool): Indicates if temperature information is included.
+         embedding (nn.Embedding): Embedding layer mapping atomic numbers to feature vectors.
+         temperature_proj_atom (pyg_nn.Linear): Linear layer projecting temperature to embedding dimensions (used if temperature is True).
+         bias (nn.Parameter): Bias term added to embeddings (used if temperature is False).
+         activation (nn.Module): Activation function (SiLU).
+         encoder_atom (nn.Sequential): Sequential network encoding node features.
+         encoder_edge (nn.Sequential): Sequential network encoding edge features.
+         rbf (ExpNormalSmearing): Radial basis function for encoding distances.
+     """
+
+     def __init__(
+         self,
+         dim_in: int,
+         dim_rbf: int,
+         radius: float = 5.0,
+         invariant: bool = False,
+         temperature: bool = True,
+     ):
+         super(Encoder, self).__init__()
+         self.dim_in = dim_in
+         self.invariant = invariant
+         self.temperature = temperature
+         self.embedding = nn.Embedding(119, self.dim_in*2)
+         if self.temperature:
+             self.temperature_proj_atom = pyg_nn.Linear(1, self.dim_in*2, bias=True)
+         else:
+             self.bias = nn.Parameter(torch.zeros(self.dim_in*2))
+         self.activation = nn.SiLU(inplace=True)
+
+
+         self.encoder_atom = nn.Sequential(self.activation,
+                                           pyg_nn.Linear(self.dim_in*2, self.dim_in),
+                                           self.activation)
+         if self.invariant:
+             dim_edge = dim_rbf
+         else:
+             dim_edge = dim_rbf + 3
+
+         self.encoder_edge = nn.Sequential(pyg_nn.Linear(dim_edge, self.dim_in*2),
+                                           self.activation,
+                                           pyg_nn.Linear(self.dim_in*2, self.dim_in),
+                                           self.activation)
+
+         self.rbf = ExpNormalSmearing(0.0,radius,dim_rbf,False)
+
+         torch.nn.init.xavier_uniform_(self.embedding.weight.data)
+
+     def forward(self, batch):
+
+         x = self.embedding(batch.x) + self.temperature_proj_atom(batch.temperature.unsqueeze(-1))[batch.batch]
+
+
+         batch.x = self.encoder_atom(x)
+
+         batch.edge_attr = self.encoder_edge(torch.cat([self.rbf(batch.cart_dist), batch.cart_dir], dim=-1))
+
+         return batch
+
+ class CartNet_layer(pyg_nn.conv.MessagePassing):
+     """
+     The message-passing layer used in the CartNet architecture.
+     Parameters:
+         dim_in (int): Dimension of the input node features.
+         use_envelope (bool, optional): If True, applies an envelope function to the distances. Defaults to True.
+     Attributes:
+         dim_in (int): Dimension of the input node features.
+         activation (nn.Module): Activation function (SiLU) used in the layer.
+         MLP_aggr (nn.Sequential): MLP used for aggregating messages.
+         MLP_gate (nn.Sequential): MLP used for computing gating coefficients.
+         norm (nn.BatchNorm1d): Batch normalization applied to the gating coefficients.
+         norm2 (nn.BatchNorm1d): Batch normalization applied to the aggregated messages.
+         use_envelope (bool): Indicates if the envelope function is used.
+         envelope (CosineCutoff): Envelope function applied to the distances.
+     """
+
+     def __init__(self,
+         dim_in: int,
+         use_envelope: bool = True
+     ):
+         super().__init__()
+         self.dim_in = dim_in
+         self.activation = nn.SiLU(inplace=True)
+         self.MLP_aggr = nn.Sequential(
+             pyg_nn.Linear(dim_in*3, dim_in, bias=True),
+             self.activation,
+             pyg_nn.Linear(dim_in, dim_in, bias=True),
+         )
+         self.MLP_gate = nn.Sequential(
+             pyg_nn.Linear(dim_in*3, dim_in, bias=True),
+             self.activation,
+             pyg_nn.Linear(dim_in, dim_in, bias=True),
+         )
+
+         self.norm = nn.BatchNorm1d(dim_in)
+         self.norm2 = nn.BatchNorm1d(dim_in)
+         self.use_envelope = use_envelope
+         self.envelope = CosineCutoff(0, 5.0)
+
+
+     def forward(self, batch):
+
+         x, e, edge_index, dist = batch.x, batch.edge_attr, batch.edge_index, batch.cart_dist
+         """
+         x : [n_nodes, dim_in]
+         e : [n_edges, dim_in]
+         edge_index : [2, n_edges]
+         dist : [n_edges]
+         batch : [n_nodes]
+         """
+
+         x_in = x
+         e_in = e
+
+         x, e = self.propagate(edge_index,
+                               Xx=x, Ee=e,
+                               He=dist,
+                               )
+
+         batch.x = self.activation(x) + x_in
+
+         batch.edge_attr = e_in + e
+
+         return batch
+
+
+     def message(self, Xx_i, Ee, Xx_j, He):
+         """
+         x_i : [n_edges, dim_in]
+         x_j : [n_edges, dim_in]
+         e : [n_edges, dim_in]
+         """
+
+         e_ij = self.MLP_gate(torch.cat([Xx_i, Xx_j, Ee], dim=-1))
+         e_ij = F.sigmoid(self.norm(e_ij))
+
+         if self.use_envelope:
+             sigma_ij = self.envelope(He).unsqueeze(-1)*e_ij
+         else:
+             sigma_ij = e_ij
+
+         self.e = sigma_ij
+         return sigma_ij
+
+     def aggregate(self, sigma_ij, index, Xx_i, Xx_j, Ee, Xx):
+         """
+         sigma_ij : [n_edges, dim_in] ; is the output from message() function
+         index : [n_edges]
+         x_j : [n_edges, dim_in]
+         """
+         dim_size = Xx.shape[0]
+
+         sender = self.MLP_aggr(torch.cat([Xx_i, Xx_j, Ee], dim=-1))
+
+
+         out = scatter(sigma_ij*sender, index, 0, None, dim_size,
+                       reduce='sum')
+
+         return out
+
+     def update(self, aggr_out):
+         """
+         aggr_out : [n_nodes, dim_in] ; is the output from aggregate() function after the aggregation
+         x : [n_nodes, dim_in]
+         """
+         x = self.norm2(aggr_out)
+
+         e_out = self.e
+         del self.e
+
+         return x, e_out
+
+ class Cholesky_head(torch.nn.Module):
+     """
+     The Cholesky head used in the CartNet model.
+     It enforces the positive definiteness of the output covariance matrix.
+
+     Args:
+         dim_in (int): The input dimension of the features.
+     """
+
+     def __init__(self,
+         dim_in: int
+     ):
+         super(Cholesky_head, self).__init__()
+         self.MLP = nn.Sequential(pyg_nn.Linear(dim_in, dim_in//2),
+                                  nn.SiLU(inplace=True),
+                                  pyg_nn.Linear(dim_in//2, 6))
+
+     def forward(self, batch):
+         pred = self.MLP(batch.x[batch.non_H_mask])
+
+         diag_elements = F.softplus(pred[:, :3])
+
+         i,j = torch.tensor([0,1,2,0,0,1]), torch.tensor([0,1,2,1,2,2])
+         L_matrix = torch.zeros(pred.size(0),3,3, device=pred.device, dtype=pred.dtype)
+         L_matrix[:,i[:3], i[:3]] = diag_elements
+         L_matrix[:,i[3:], j[3:]] = pred[:,3:]
+
+         U = torch.bmm(L_matrix.transpose(1, 2), L_matrix)
+
+         return U
+
+
+
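
To make the role of `Cholesky_head` concrete, here is a dependency-free sketch of the same construction for a single atom: three softplus-activated diagonal entries and three free off-diagonal entries fill a triangular factor, and the product of its transpose with itself yields a symmetric positive-definite matrix. The index layout mirrors `L_matrix` in `forward` above; the function names `softplus`, `cholesky_head` and `leading_minors` are illustrative only, and plain lists stand in for tensors.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x)); always strictly positive."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def cholesky_head(raw):
    """Map six unconstrained numbers to a symmetric positive-definite 3x3 matrix."""
    d0, d1, d2 = (softplus(v) for v in raw[:3])
    o01, o02, o12 = raw[3:]
    # Triangular factor with positions (0,1), (0,2), (1,2) filled,
    # matching the i/j index tensors used in Cholesky_head.forward.
    r = [[d0, o01, o02],
         [0.0, d1, o12],
         [0.0, 0.0, d2]]
    # U = R^T R, the plain-Python analogue of bmm(L.transpose(1, 2), L).
    return [[sum(r[k][i] * r[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]


def leading_minors(m):
    """Leading principal minors; all positive means positive definite."""
    m1 = m[0][0]
    m2 = m[0][0] * m[1][1] - m[0][1] * m[1][0]
    m3 = (m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
          - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
          + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]))
    return m1, m2, m3


# Arbitrary raw network outputs, chosen only for the example.
u = cholesky_head([-3.0, 0.5, 2.0, 1.5, -0.7, 0.3])
minors = leading_minors(u)
```

Because the diagonal of the factor is strictly positive, the determinant of the result equals the squared product of the diagonal entries and cannot vanish, which is why every predicted ADP matrix corresponds to a valid ellipsoid regardless of the raw outputs.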
models/master.py ADDED
@@ -0,0 +1,14 @@
+ import torch
+ import streamlit as st
+ from models.cartnet import CartNet
+
+ # We cache the loading function to make it very fast on reload.
+ @st.cache_resource
+ def create_model():
+     model = CartNet(dim_in=256, dim_rbf=64, num_layers=4, radius=5.0, invariant=False, temperature=True, use_envelope=True, cholesky=True)
+     ckpt_path = "cpkt/cartnet_adp.ckpt"
+     load = torch.load(ckpt_path, map_location=torch.device('cpu'))["model_state"]
+
+     model.load_state_dict(load)
+     model.eval()
+     return model
models/utils.py ADDED
@@ -0,0 +1,129 @@
+ import torch
+ import math
+ import numpy as np
+ from torch import nn, Tensor
+ import torch.nn.functional as F
+ from typing import Optional
+
+ # Implementation from TensorNet
+ # https://github.com/torchmd/torchmd-net
+ class ExpNormalSmearing(nn.Module):
+     def __init__(
+         self,
+         cutoff_lower=0.0,
+         cutoff_upper=5.0,
+         num_rbf=50,
+         trainable=True,
+         dtype=torch.float32,
+     ):
+         super(ExpNormalSmearing, self).__init__()
+         self.cutoff_lower = cutoff_lower
+         self.cutoff_upper = cutoff_upper
+         self.num_rbf = num_rbf
+         self.trainable = trainable
+         self.dtype = dtype
+         self.cutoff_fn = CosineCutoff(0, cutoff_upper)
+         self.alpha = 5.0 / (cutoff_upper - cutoff_lower)
+
+         means, betas = self._initial_params()
+         if trainable:
+             self.register_parameter("means", nn.Parameter(means))
+             self.register_parameter("betas", nn.Parameter(betas))
+         else:
+             self.register_buffer("means", means)
+             self.register_buffer("betas", betas)
+
+     def _initial_params(self):
+         # initialize means and betas according to the default values in PhysNet
+         # https://pubs.acs.org/doi/10.1021/acs.jctc.9b00181
+         start_value = torch.exp(
+             torch.scalar_tensor(
+                 -self.cutoff_upper + self.cutoff_lower, dtype=self.dtype
+             )
+         )
+         means = torch.linspace(start_value, 1, self.num_rbf, dtype=self.dtype)
+         betas = torch.tensor(
+             [(2 / self.num_rbf * (1 - start_value)) ** -2] * self.num_rbf,
+             dtype=self.dtype,
+         )
+         return means, betas
+
+     def reset_parameters(self):
+         means, betas = self._initial_params()
+         self.means.data.copy_(means)
+         self.betas.data.copy_(betas)
+
+     def forward(self, dist):
+         dist = dist.unsqueeze(-1)
+         return self.cutoff_fn(dist) * torch.exp(
+             -self.betas
+             * (torch.exp(self.alpha * (-dist + self.cutoff_lower)) - self.means) ** 2
+         )
+
+ class CosineCutoff(nn.Module):
+     def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0):
+         super(CosineCutoff, self).__init__()
+         self.cutoff_lower = cutoff_lower
+         self.cutoff_upper = cutoff_upper
+
+     def forward(self, distances: Tensor) -> Tensor:
+         if self.cutoff_lower > 0:
+             cutoffs = 0.5 * (
+                 torch.cos(
+                     math.pi
+                     * (
+                         2
+                         * (distances - self.cutoff_lower)
+                         / (self.cutoff_upper - self.cutoff_lower)
+                         + 1.0
+                     )
+                 )
+                 + 1.0
+             )
+             # remove contributions below the cutoff radius
+             cutoffs = cutoffs * (distances < self.cutoff_upper)
+             cutoffs = cutoffs * (distances > self.cutoff_lower)
+             return cutoffs
+         else:
+             cutoffs = 0.5 * (torch.cos(distances * math.pi / self.cutoff_upper) + 1.0)
+             # remove contributions beyond the cutoff radius
+             cutoffs = cutoffs * (distances < self.cutoff_upper)
+             return cutoffs
+
+
+ # Implementation from ComFormer
+ # https://github.com/divelab/AIRS/tree/main/OpenMat/ComFormer
+ class RBFExpansion(nn.Module):
+     """Expand interatomic distances with radial basis functions."""
+
+     def __init__(
+         self,
+         vmin: float = 0,
+         vmax: float = 8,
+         bins: int = 40,
+         lengthscale: Optional[float] = None,
+     ):
+         """Register torch parameters for RBF expansion."""
+         super().__init__()
+         self.vmin = vmin
+         self.vmax = vmax
+         self.bins = bins
+         self.register_buffer(
+             "centers", torch.linspace(self.vmin, self.vmax, self.bins)
+         )
+
+         if lengthscale is None:
+             # SchNet-style
+             # set lengthscales relative to granularity of RBF expansion
+             self.lengthscale = np.diff(self.centers).mean()
+             self.gamma = 1 / self.lengthscale
+
+         else:
+             self.lengthscale = lengthscale
+             self.gamma = 1 / (lengthscale ** 2)
+
+     def forward(self, distance: torch.Tensor) -> torch.Tensor:
+         """Apply RBF expansion to interatomic distance tensor."""
+         return torch.exp(
+             -self.gamma * (distance.unsqueeze(1) - self.centers) ** 2
+         )
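
As a reading aid for `CosineCutoff` and `ExpNormalSmearing`, the sketch below evaluates the same formulas for a single scalar distance with `cutoff_lower = 0`, using only the standard library. It is an illustration rather than a replacement for the tensor code; the function names and the small `num_rbf=8` default are choices made for this example.

```python
import math


def cosine_cutoff(distance, cutoff_upper=5.0):
    """Smooth envelope: 1 at distance 0, 0 at and beyond cutoff_upper."""
    if distance >= cutoff_upper:
        return 0.0
    return 0.5 * (math.cos(distance * math.pi / cutoff_upper) + 1.0)


def exp_normal_smearing(distance, cutoff_upper=5.0, num_rbf=8):
    """Exponential-normal radial basis values for one distance."""
    start = math.exp(-cutoff_upper)  # smallest mean, as in _initial_params
    step = (1.0 - start) / (num_rbf - 1)
    means = [start + i * step for i in range(num_rbf)]
    beta = (2.0 / num_rbf * (1.0 - start)) ** -2
    alpha = 5.0 / cutoff_upper
    envelope = cosine_cutoff(distance, cutoff_upper)
    return [
        envelope * math.exp(-beta * (math.exp(-alpha * distance) - mu) ** 2)
        for mu in means
    ]


features_near = exp_normal_smearing(0.0)  # basis with mean 1 responds most
features_far = exp_normal_smearing(6.0)   # beyond the cutoff: all zeros
```

Since the basis functions are Gaussians in `exp(-alpha * d)` rather than in `d`, the centres are packed more densely at short distances, where small changes in separation matter most.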
predict.py ADDED
@@ -0,0 +1,73 @@
+ import argparse
+ import os
+ from ase.io import read
+ from CifFile import ReadCif
+ from torch_geometric.data import Data, Batch
+ import torch
+ from models.master import create_model
+ from process import process_data
+ from utils import radius_graph_pbc
+ import gc
+
+ MEAN_TEMP = torch.tensor(192.1785)  # mean temperature of the training set
+ STD_TEMP = torch.tensor(81.2135)  # standard deviation of the training temperatures
+
+ @torch.no_grad()
+ def process_cif(input_file, output_file):
+     model = create_model()
+
+     try:
+         # Read the CIF file using ASE
+         atoms = read(input_file, format="cif")
+         cif = ReadCif(input_file)
+         cif_data = cif.first_block()
+         if "_diffrn_ambient_temperature" in cif_data.keys():
+             temperature = float(str(cif_data["_diffrn_ambient_temperature"]).split("(")[0])  # drop any standard uncertainty, e.g. "293(2)"
+         else:
+             raise ValueError("Temperature not found in the CIF file. "
+                              "Please provide a temperature in the field _diffrn_ambient_temperature from the CIF file.")
+
+         data = Data()
+         data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
+
+         if len(atoms.positions) > 300:
+             raise ValueError("This implementation is not optimized for large systems. For large systems, please use the local version.")
+
+         data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
+         data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
+         data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
+         data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
+
+         data.pbc = torch.tensor([True, True, True])
+         data.natoms = len(atoms)
+
+         del atoms
+         gc.collect()
+         batch = Batch.from_data_list([data])
+
+         edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
+         del batch
+         gc.collect()
+         data.cart_dist = torch.norm(edge_attr, dim=-1)
+         data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
+         data.edge_index = edge_index
+         data.non_H_mask = data.x != 1
+         delattr(data, "pbc")
+         delattr(data, "natoms")
+         batch = Batch.from_data_list([data])
+         del data, edge_index, edge_attr
+         gc.collect()
+
+         process_data(batch, model, output_file)
+
+         gc.collect()
+     except Exception as e:
+         print(f"An error occurred while processing the CIF file: {e}")
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Process a CIF file and output the result.")
+     parser.add_argument("input_file", type=str, help="Path to the input CIF file.")
+     parser.add_argument("output_file", type=str, help="Path to the output CIF file.")
+     args = parser.parse_args()
+
+     process_cif(args.input_file, args.output_file)
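The temperature handling in `predict.py` can be tried in isolation. The sketch below is a plain-Python illustration, not code from the repository: `parse_cif_number` and `normalize_temperature` are hypothetical helper names, and the two constants are copied from `MEAN_TEMP` and `STD_TEMP` above. It assumes the CIF value may carry a standard uncertainty in parentheses (for example `293(2)`), which a bare `float()` call rejects.

```python
MEAN_TEMP = 192.1785  # training-set mean, copied from predict.py
STD_TEMP = 81.2135    # training-set standard deviation, copied from predict.py


def parse_cif_number(raw):
    """Return the numeric part of a CIF value such as '293(2)' or '150'."""
    return float(str(raw).split("(")[0])


def normalize_temperature(kelvin):
    """Standardize a temperature the same way predict.py does before inference."""
    return (kelvin - MEAN_TEMP) / STD_TEMP


temperature = parse_cif_number("293(2)")
scaled = normalize_temperature(temperature)
```

A temperature equal to the training mean maps to zero, and room temperature lands a little over one standard deviation above it.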
process.py ADDED
@@ -0,0 +1,100 @@
+ import torch
+ from ase.io import write
+ from ase import Atoms
+ import gc
+
+ @torch.no_grad()
+ def process_data(batch, model, output_file="output.cif"):
+     atoms = batch.x.numpy().astype(int)  # Atomic numbers
+     positions = batch.pos.numpy()  # Atomic positions
+     cell = batch.cell.squeeze(0).numpy()  # Cell parameters
+     temperature = batch.temperature_og.numpy()[0]
+
+     adps = model(batch)
+     u_equiv = adps.diagonal(dim1=-2, dim2=-1).sum(-1) / 3  # U_equiv = tr(U_cart) / 3
+
+     # Convert Ucart to Ucif
+     M = batch.cell.squeeze(0)
+     N = torch.diag(torch.linalg.norm(torch.linalg.inv(M.transpose(-1,-2)).squeeze(0), dim=-1))
+
+     M = torch.linalg.inv(M)
+     N = torch.linalg.inv(N)
+
+     adps = M.transpose(-1,-2)@adps@M
+     adps = N.transpose(-1,-2)@adps@N
+     del M, N
+     gc.collect()
+
+
+     non_H_mask = batch.non_H_mask.numpy()
+     indices = torch.arange(len(atoms))[non_H_mask].numpy()
+     indices = {indices[i]: i for i in range(len(indices))}
+     # Create ASE Atoms object
+     ase_atoms = Atoms(numbers=atoms, positions=positions, cell=cell, pbc=True)
+
+     # Convert positions to fractional coordinates
+     fractional_positions = ase_atoms.get_scaled_positions()
+
+     # Write to CIF file
+     write(output_file, ase_atoms)
+
+     with open(output_file, 'r') as file:
+         lines = file.readlines()
+
+     # Find the line where "loop_" appears and remove lines from there to the end
+     for i, line in enumerate(lines):
+         if line.strip().startswith('loop_'):
+             lines = lines[:i]
+             break
+
+     # Write the truncated lines back to the same output file
+     with open(output_file, 'w') as file:
+         file.writelines(lines)
+
+     # Manually append positions and ADPs to the CIF file
+     with open(output_file, 'a') as cif_file:
+
+         # Write temperature
+         cif_file.write(f"\n_diffrn_ambient_temperature {temperature}\n")
+         # Write atomic positions
+         cif_file.write("\nloop_\n")
+         cif_file.write("_atom_site_label\n")
+         cif_file.write("_atom_site_type_symbol\n")
+         cif_file.write("_atom_site_fract_x\n")
+         cif_file.write("_atom_site_fract_y\n")
+         cif_file.write("_atom_site_fract_z\n")
+         cif_file.write("_atom_site_U_iso_or_equiv\n")
+         cif_file.write("_atom_site_thermal_displace_type\n")
+
+         element_count = {}
+         for i, (atom_number, frac_pos) in enumerate(zip(atoms, fractional_positions)):
+             element = ase_atoms[i].symbol
+             assert atom_number == ase_atoms[i].number
+             if element not in element_count:
+                 element_count[element] = 0
+             element_count[element] += 1
+             label = f"{element}{element_count[element]}"
+             u_iso = u_equiv[indices[i]] if element != 'H' else 0.01
+             adp_type = "Uani" if element != 'H' else "Uiso"
+             cif_file.write(f"{label} {element} {frac_pos[0]} {frac_pos[1]} {frac_pos[2]} {u_iso} {adp_type}\n")
+
+         # Write ADPs
+         cif_file.write("\nloop_\n")
+         cif_file.write("_atom_site_aniso_label\n")
+         cif_file.write("_atom_site_aniso_U_11\n")
+         cif_file.write("_atom_site_aniso_U_22\n")
+         cif_file.write("_atom_site_aniso_U_33\n")
+         cif_file.write("_atom_site_aniso_U_23\n")
+         cif_file.write("_atom_site_aniso_U_13\n")
+         cif_file.write("_atom_site_aniso_U_12\n")
+
+         element_count = {}
+         for i, atom_number in enumerate(atoms):
+             if atom_number == 1:
+                 continue
+             element = ase_atoms[i].symbol
+             if element not in element_count:
+                 element_count[element] = 0
+             element_count[element] += 1
+             label = f"{element}{element_count[element]}"
+             cif_file.write(f"{label} {adps[indices[i],0,0]} {adps[indices[i],1,1]} {adps[indices[i],2,2]} {adps[indices[i],1,2]} {adps[indices[i],0,2]} {adps[indices[i],0,1]}\n")
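The "Convert Ucart to Ucif" step in `process.py` can be followed with small matrices. The sketch below is a standard-library restatement under the same convention as the file (the rows of `cell` are the lattice vectors a, b, c, as in ASE); the helper names `matmul`, `transpose`, `inverse`, `ucart_to_ucif` and `u_equiv` are illustrative and not part of the repository. It also shows the equivalent isotropic value as one third of the trace of the Cartesian tensor.

```python
import math


def matmul(a, b):
    """Multiply two 3x3 matrices given as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]


def transpose(a):
    return [[a[j][i] for j in range(3)] for i in range(3)]


def inverse(m):
    """Invert a 3x3 matrix through its adjugate."""
    (a, b, c), (d, e, f), (g, h, i) = m
    det = a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)
    adjugate = [
        [e * i - f * h, c * h - b * i, b * f - c * e],
        [f * g - d * i, a * i - c * g, c * d - a * f],
        [d * h - e * g, b * g - a * h, a * e - b * d],
    ]
    return [[value / det for value in row] for row in adjugate]


def ucart_to_ucif(u_cart, cell):
    """U_cif = N^-1 (M^-1)^T U_cart M^-1 N^-1, with N = diag(|a*|, |b*|, |c*|)."""
    cell_inv = inverse(cell)
    # The columns of cell_inv are the reciprocal lattice vectors a*, b*, c*.
    recip_lengths = [
        math.sqrt(sum(cell_inv[row][col] ** 2 for row in range(3)))
        for col in range(3)
    ]
    n_inv = [[1.0 / recip_lengths[r] if r == c else 0.0 for c in range(3)]
             for r in range(3)]
    u_frac = matmul(matmul(transpose(cell_inv), u_cart), cell_inv)
    return matmul(matmul(n_inv, u_frac), n_inv)


def u_equiv(u_cart):
    """Equivalent isotropic displacement: one third of the Cartesian trace."""
    return (u_cart[0][0] + u_cart[1][1] + u_cart[2][2]) / 3.0


cell = [[5.0, 0.0, 0.0], [0.0, 7.0, 0.0], [0.0, 0.0, 9.0]]
u_cart = [[0.020, 0.001, 0.002], [0.001, 0.030, 0.003], [0.002, 0.003, 0.040]]
u_cif = ucart_to_ucif(u_cart, cell)
```

For an orthorhombic cell aligned with the Cartesian axes the two conventions coincide, so `u_cif` reproduces `u_cart` here; the tensors differ only when the cell axes are oblique.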
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ numpy<2
+ torch==1.13.1
+ torch_geometric==2.5.2
+ torch-scatter==2.1.1
+ -f https://data.pyg.org/whl/torch-1.13.1+cpu.html
+ streamlit==1.40.1
+ ase==3.23.0
+ PyCifRW==4.4.6
utils.py ADDED
@@ -0,0 +1,323 @@
+ """
+ Copyright (c) Facebook, Inc. and its affiliates.
+
+ This source code is licensed under the MIT license found in the
+ LICENSE file in the root directory of this source tree.
+ """
+
+ import numpy as np
+ import torch
+ from torch_scatter import segment_coo, segment_csr
+
+
+
+
+ def radius_graph_pbc(
+     data,
+     radius,
+     max_num_neighbors_threshold,
+     enforce_max_neighbors_strictly: bool = False,
+     pbc=(True, True, True),
+ ):
+     device = data.pos.device
+     batch_size = len(data.natoms)
+     pbc = list(pbc)  # work on a copy so the default is never mutated across calls
+     if hasattr(data, "pbc"):
+         data.pbc = torch.atleast_2d(data.pbc)
+         for i in range(3):
+             if not torch.any(data.pbc[:, i]).item():
+                 pbc[i] = False
+             elif torch.all(data.pbc[:, i]).item():
+                 pbc[i] = True
+             else:
+                 raise RuntimeError(
+                     "Different structures in the batch have different PBC configurations. This is not currently supported."
+                 )
+
+     # position of the atoms
+     atom_pos = data.pos
+
+     # Before computing the pairwise distances between atoms, first create a list of atom indices to compare for the entire batch
+     num_atoms_per_image = data.natoms
+     num_atoms_per_image_sqr = (num_atoms_per_image**2).long()
+
+     # index offset between images
+     index_offset = (
+         torch.cumsum(num_atoms_per_image, dim=0) - num_atoms_per_image
+     )
+
+     index_offset_expand = torch.repeat_interleave(
+         index_offset, num_atoms_per_image_sqr
+     )
+     num_atoms_per_image_expand = torch.repeat_interleave(
+         num_atoms_per_image, num_atoms_per_image_sqr
+     )
+
+     # Compute a tensor containing sequences of numbers that range from 0 to num_atoms_per_image_sqr for each image
+     # that is used to compute indices for the pairs of atoms. This is a very convoluted way to implement
+     # the following (but 10x faster since it removes the for loop)
+     # for batch_idx in range(batch_size):
+     #    batch_count = torch.cat([batch_count, torch.arange(num_atoms_per_image_sqr[batch_idx], device=device)], dim=0)
+     num_atom_pairs = torch.sum(num_atoms_per_image_sqr)
+     index_sqr_offset = (
+         torch.cumsum(num_atoms_per_image_sqr, dim=0) - num_atoms_per_image_sqr
+     )
+     index_sqr_offset = torch.repeat_interleave(
+         index_sqr_offset, num_atoms_per_image_sqr
+     )
+     atom_count_sqr = (
+         torch.arange(num_atom_pairs, device=device) - index_sqr_offset
+     )
+
+     # Compute the indices for the pairs of atoms (using division and mod)
+     # If the systems get too large this approach could run into numerical precision issues
+     index1 = (
+         torch.div(
+             atom_count_sqr, num_atoms_per_image_expand, rounding_mode="floor"
+         )
+     ) + index_offset_expand
+     index2 = (
+         atom_count_sqr % num_atoms_per_image_expand
+     ) + index_offset_expand
+     # Get the positions for each atom
+     pos1 = torch.index_select(atom_pos, 0, index1)
+     pos2 = torch.index_select(atom_pos, 0, index2)
+
+     # Calculate required number of unit cells in each direction.
+     # Smallest distance between planes separated by a1 is
+     # 1 / ||(a2 x a3) / V||_2, since a2 x a3 is the area of the plane.
+     # Note that the unit cell volume V = a1 * (a2 x a3) and that
+     # (a2 x a3) / V is also the reciprocal primitive vector
+     # (crystallographer's definition).
+
+     cross_a2a3 = torch.cross(data.cell[:, 1], data.cell[:, 2], dim=-1)
+     cell_vol = torch.sum(data.cell[:, 0] * cross_a2a3, dim=-1, keepdim=True)
+
+     if pbc[0]:
+         inv_min_dist_a1 = torch.norm(cross_a2a3 / cell_vol, p=2, dim=-1)
+         rep_a1 = torch.ceil(radius * inv_min_dist_a1)
+     else:
+         rep_a1 = data.cell.new_zeros(1)
+
+     if pbc[1]:
+         cross_a3a1 = torch.cross(data.cell[:, 2], data.cell[:, 0], dim=-1)
+         inv_min_dist_a2 = torch.norm(cross_a3a1 / cell_vol, p=2, dim=-1)
+         rep_a2 = torch.ceil(radius * inv_min_dist_a2)
+     else:
+         rep_a2 = data.cell.new_zeros(1)
+
+     if pbc[2]:
+         cross_a1a2 = torch.cross(data.cell[:, 0], data.cell[:, 1], dim=-1)
+         inv_min_dist_a3 = torch.norm(cross_a1a2 / cell_vol, p=2, dim=-1)
+         rep_a3 = torch.ceil(radius * inv_min_dist_a3)
+     else:
+         rep_a3 = data.cell.new_zeros(1)
+
+     # Take the max over all images for uniformity. This is essentially padding.
+     # Note that this can significantly increase the number of computed distances
+     # if the required repetitions are very different between images
+     # (which they usually are). Changing this to sparse (scatter) operations
+     # might be worth the effort if this function becomes a bottleneck.
+     max_rep = [rep_a1.max(), rep_a2.max(), rep_a3.max()]
+
+     # Tensor of unit cells
+     cells_per_dim = [
+         torch.arange(-rep, rep + 1, device=device, dtype=torch.float)
+         for rep in max_rep
+     ]
+     unit_cell = torch.cartesian_prod(*cells_per_dim)
+     num_cells = len(unit_cell)
+     unit_cell_per_atom = unit_cell.view(1, num_cells, 3).repeat(
+         len(index2), 1, 1
+     )
+     unit_cell = torch.transpose(unit_cell, 0, 1)
+     unit_cell_batch = unit_cell.view(1, 3, num_cells).expand(
+         batch_size, -1, -1
+     )
+
+     # Compute the x, y, z positional offsets for each cell in each image
+     data_cell = torch.transpose(data.cell, 1, 2)
+     pbc_offsets = torch.bmm(data_cell, unit_cell_batch)
+     pbc_offsets_per_atom = torch.repeat_interleave(
+         pbc_offsets, num_atoms_per_image_sqr, dim=0
+     )
+
+     # Expand the positions and indices over all num_cells periodic images
+     pos1 = pos1.view(-1, 3, 1).expand(-1, -1, num_cells)
+     pos2 = pos2.view(-1, 3, 1).expand(-1, -1, num_cells)
+     index1 = index1.view(-1, 1).repeat(1, num_cells).view(-1)
+     index2 = index2.view(-1, 1).repeat(1, num_cells).view(-1)
+     # Add the PBC offsets for the second atom
+     pos2 = pos2 + pbc_offsets_per_atom
+
+     # Compute the squared distance between atoms
+     direction = pos1 - pos2
+     atom_distance_sqr = torch.sum((direction) ** 2, dim=1)
+     direction = direction.permute(0, 2, 1).reshape(-1, 3)
+     atom_distance_sqr = atom_distance_sqr.view(-1)
+
+     # Remove pairs that are too far apart
+     mask_within_radius = torch.le(atom_distance_sqr, radius * radius)
+     # Remove pairs with the same atoms (distance = 0.0)
+     mask_not_same = torch.gt(atom_distance_sqr, 0.0001)
+     mask = torch.logical_and(mask_within_radius, mask_not_same)
+     index1 = torch.masked_select(index1, mask)
+     index2 = torch.masked_select(index2, mask)
+     unit_cell = torch.masked_select(
+         unit_cell_per_atom.view(-1, 3), mask.view(-1, 1).expand(-1, 3)
+     )
+     unit_cell = unit_cell.view(-1, 3)
+     atom_distance_sqr = torch.masked_select(atom_distance_sqr, mask)
+     direction = torch.masked_select(direction, mask.view(-1, 1).expand(-1, 3)).view(-1, 3)
+
+     if max_num_neighbors_threshold is not None:
+         mask_num_neighbors, num_neighbors_image = get_max_neighbors_mask(
+             natoms=data.natoms,
+             index=index1,
+             atom_distance=atom_distance_sqr,
+             max_num_neighbors_threshold=max_num_neighbors_threshold,
+             enforce_max_strictly=enforce_max_neighbors_strictly,
+         )
+
+         if not torch.all(mask_num_neighbors):
+             # Mask out the atoms to ensure each atom has at most max_num_neighbors_threshold neighbors
+             index1 = torch.masked_select(index1, mask_num_neighbors)
+             index2 = torch.masked_select(index2, mask_num_neighbors)
+             atom_distance_sqr = torch.masked_select(atom_distance_sqr, mask_num_neighbors)
+             direction = torch.masked_select(direction, mask_num_neighbors.view(-1, 1).expand(-1, 3)).view(-1, 3)
+             unit_cell = torch.masked_select(
+                 unit_cell.view(-1, 3), mask_num_neighbors.view(-1, 1).expand(-1, 3)
+             )
+             unit_cell = unit_cell.view(-1, 3)
+
+     edge_index = torch.stack((index2, index1))
+
+     return edge_index, unit_cell, torch.sqrt(atom_distance_sqr), direction
+
+
+ def get_max_neighbors_mask(
+     natoms,
+     index,
+     atom_distance,
+     max_num_neighbors_threshold,
+     degeneracy_tolerance: float = 0.01,
+     enforce_max_strictly: bool = False,
+ ):
+     """
+     Give a mask that filters out edges so that each atom has at most
+     `max_num_neighbors_threshold` neighbors.
+     Assumes that `index` is sorted.
+
+     Enforcing the max strictly can force the arbitrary choice between
+     degenerate edges. This can lead to undesired behaviors; for
+     example, bulk formation energies which are not invariant to
+     unit cell choice.
+
+     A degeneracy tolerance can help prevent sudden changes in edge
+     existence from small changes in atom position, for example,
+     rounding errors, slab relaxation, temperature, etc.
+     """
+
+     device = natoms.device
+     num_atoms = natoms.sum()
+
+     # Get number of neighbors
+     # segment_coo assumes sorted index
+     ones = index.new_ones(1).expand_as(index)
+     num_neighbors = segment_coo(ones, index, dim_size=num_atoms)
+     max_num_neighbors = num_neighbors.max()
+     num_neighbors_thresholded = num_neighbors.clamp(
+         max=max_num_neighbors_threshold
+     )
+
+     # Get number of (thresholded) neighbors per image
+     image_indptr = torch.zeros(
+         natoms.shape[0] + 1, device=device, dtype=torch.long
+     )
+     image_indptr[1:] = torch.cumsum(natoms, dim=0)
+     num_neighbors_image = segment_csr(num_neighbors_thresholded, image_indptr)
+
+     # If max_num_neighbors is below the threshold, return early
+     if (
+         max_num_neighbors <= max_num_neighbors_threshold
+         or max_num_neighbors_threshold <= 0
+     ):
+         mask_num_neighbors = torch.tensor(
+             [True], dtype=bool, device=device
+         ).expand_as(index)
+         return mask_num_neighbors, num_neighbors_image
+
+     # Create a tensor of size [num_atoms, max_num_neighbors] to sort the distances of the neighbors.
+     # Fill with infinity so we can easily remove unused distances later.
+     distance_sort = torch.full(
+         [num_atoms * max_num_neighbors], np.inf, device=device
+     )
+
+     # Create an index map to map distances from atom_distance to distance_sort
+     # index_sort_map assumes index to be sorted
+     index_neighbor_offset = torch.cumsum(num_neighbors, dim=0) - num_neighbors
+     index_neighbor_offset_expand = torch.repeat_interleave(
+         index_neighbor_offset, num_neighbors
+     )
+     index_sort_map = (
+         index * max_num_neighbors
+         + torch.arange(len(index), device=device)
+         - index_neighbor_offset_expand
+     )
+
+     distance_sort.index_copy_(0, index_sort_map, atom_distance)
+     distance_sort = distance_sort.view(num_atoms, max_num_neighbors)
+
+     # Sort neighboring atoms based on distance
+     distance_sort, index_sort = torch.sort(distance_sort, dim=1)
+
+     # Select the max_num_neighbors_threshold neighbors that are closest
+     if enforce_max_strictly:
+         distance_sort = distance_sort[:, :max_num_neighbors_threshold]
+         index_sort = index_sort[:, :max_num_neighbors_threshold]
+         max_num_included = max_num_neighbors_threshold
+
+     else:
+         effective_cutoff = (
+             distance_sort[:, max_num_neighbors_threshold]
+             + degeneracy_tolerance
+         )
+         is_included = torch.le(distance_sort.T, effective_cutoff)
+
+         # Set all undesired edges to infinite length to be removed later
+         distance_sort[~is_included.T] = np.inf
+
+         # Subselect tensors for efficiency
+         num_included_per_atom = torch.sum(is_included, dim=0)
+         max_num_included = torch.max(num_included_per_atom)
+         distance_sort = distance_sort[:, :max_num_included]
+         index_sort = index_sort[:, :max_num_included]
+
+         # Recompute the number of neighbors
+         num_neighbors_thresholded = num_neighbors.clamp(
+             max=num_included_per_atom
+         )
+
+         num_neighbors_image = segment_csr(
+             num_neighbors_thresholded, image_indptr
+         )
+
+     # Offset index_sort so that it indexes into index
+     index_sort = index_sort + index_neighbor_offset.view(-1, 1).expand(
+         -1, max_num_included
+     )
+     # Remove "unused pairs" with infinite distances
+     mask_finite = torch.isfinite(distance_sort)
+     index_sort = torch.masked_select(index_sort, mask_finite)
+
+     # At this point index_sort contains the index into index of the
+     # closest max_num_neighbors_threshold neighbors per atom
+     # Create a mask to remove all pairs not in index_sort
+     mask_num_neighbors = torch.zeros(len(index), device=device, dtype=bool)
+     mask_num_neighbors.index_fill_(0, index_sort, True)
+
+     return mask_num_neighbors, num_neighbors_image
+
+
+
+
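The comment in `radius_graph_pbc` about the "required number of unit cells in each direction" can be checked by hand for one cell. The sketch below is a plain-Python illustration with hypothetical helper names (`cross`, `dot`, `required_repetitions`); it computes, per lattice direction, `ceil(radius * ||a_j x a_k|| / |V|)`, which is the same quantity the function evaluates with batched tensors.

```python
import math


def cross(u, v):
    return (
        u[1] * v[2] - u[2] * v[1],
        u[2] * v[0] - u[0] * v[2],
        u[0] * v[1] - u[1] * v[0],
    )


def dot(u, v):
    return sum(x * y for x, y in zip(u, v))


def required_repetitions(cell, radius):
    """Number of periodic images needed along a1, a2, a3 to cover `radius`."""
    a1, a2, a3 = cell
    volume = dot(a1, cross(a2, a3))
    reps = []
    for u, v in ((a2, a3), (a3, a1), (a1, a2)):
        normal = cross(u, v)
        # ||u x v|| / |V| is the inverse spacing of the lattice planes
        inv_plane_spacing = math.sqrt(dot(normal, normal)) / abs(volume)
        reps.append(math.ceil(radius * inv_plane_spacing))
    return reps


cubic = required_repetitions([(4.0, 0.0, 0.0), (0.0, 4.0, 0.0), (0.0, 0.0, 4.0)], 5.0)
slab = required_repetitions([(3.0, 0.0, 0.0), (0.0, 6.0, 0.0), (0.0, 0.0, 12.0)], 5.0)
```

With the 5.0 Å radius used in `predict.py`, a 4 Å cubic cell needs two images on each side of every axis, while a 3 x 6 x 12 Å cell needs two only along its short axis. This is why small cells produce many more candidate pairs per atom.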