Àlex Solé commited on
Commit
fa790e2
·
1 Parent(s): 4e099cc

fixed bug multiple structures csd

Browse files
Files changed (4) hide show
  1. main.py +88 -69
  2. main_local.py +76 -78
  3. process.py +53 -54
  4. utils.py +1 -1
main.py CHANGED
@@ -8,6 +8,7 @@ from models.master import create_model
8
  from process import process_data
9
  from utils import radius_graph_pbc
10
  import gc
 
11
 
12
  MEAN_TEMP = torch.tensor(192.1785) #training temp mean
13
  STD_TEMP = torch.tensor(81.2135) #training temp std
@@ -16,7 +17,7 @@ STD_TEMP = torch.tensor(81.2135) #training temp std
16
  @torch.no_grad()
17
  def main():
18
  model = create_model()
19
- st.title("CartNet ADP Prediction")
20
  st.image('fig/pipeline.png')
21
 
22
  st.markdown("""
@@ -24,85 +25,101 @@ def main():
24
  """)
25
 
26
  uploaded_file = st.file_uploader("Upload a CIF file", type=["cif"], accept_multiple_files=False)
27
- # uploaded_file = "ABABEM.cif"
28
  if uploaded_file is not None:
29
  try:
30
- with open(uploaded_file.name, "wb") as f:
31
- f.write(uploaded_file.getbuffer())
32
  filename = str(uploaded_file.name)
33
- # Read the CIF file using ASE
34
- atoms = read(filename, format="cif")
35
- cif = ReadCif(filename)
36
- cif_data = cif.first_block()
37
-
38
- if "_diffrn_ambient_temperature" in cif_data.keys():
39
- temperature = float(cif_data["_diffrn_ambient_temperature"].split("(")[0])
40
- elif "_cell_measurement_temperature" in cif_data.keys():
41
- temperature = float(cif_data["_cell_measurement_temperature"].split("(")[0])
42
- else:
43
- raise ValueError("Temperature not found in the CIF file. \
44
- Please provide a temperature in the field _diffrn_ambient_temperature o in the field _cell_measurement_temperature from the CIF file.")
45
- st.success("CIF file successfully read.")
46
-
47
- data = Data()
48
- data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
49
-
50
- if len(atoms.positions) > 1000:
51
- st.markdown("""
52
- ⚠️ **Warning**: The structure is too large. Please upload a smaller one or use the [local implementation of CartNet Web App](https://github.com/alexsoleg/cartnet-streamlit/).
53
- """)
54
- raise ValueError("Please provide a structure with less than 1000 atoms in the unit cell.")
55
-
56
- data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
57
- data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
58
- data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
59
- data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
60
 
61
- data.pbc = torch.tensor([True, True, True])
62
- data.natoms = len(atoms)
63
-
64
- del atoms
65
- gc.collect()
66
- batch = Batch.from_data_list([data])
67
 
68
-
69
- edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
70
- del batch
71
- gc.collect()
72
- data.cart_dist = torch.norm(edge_attr, dim=-1)
73
- data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
74
- data.edge_index = edge_index
75
- data.non_H_mask = data.x != 1
76
- delattr(data, "pbc")
77
- delattr(data, "natoms")
78
- batch = Batch.from_data_list([data])
79
- del data, edge_index, edge_attr
80
- gc.collect()
81
-
82
- st.success("Graph successfully created.")
83
-
84
- process_data(batch, model)
85
- st.success("ADPs successfully predicted.")
86
-
87
- # Create a download button for the processed CIF file
88
- with open("output.cif", "r") as f:
89
- cif_contents = f.read()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
- st.download_button(
92
- label="Download processed CIF file",
93
- data=cif_contents,
94
- file_name="output.cif",
95
- mime="text/plain"
96
- )
97
-
98
- os.remove("output.cif")
99
- os.remove(filename)
100
  gc.collect()
101
  except Exception as e:
102
  st.error(f"An error occurred while reading the CIF file: {e}")
103
- st.markdown("""
 
104
  ⚠️ **Warning**: This online web application is designed for structures with up to 1000 atoms in the unit cell. For larger structures, please use the [local implementation of CartNet Web App](https://github.com/alexsoleg/cartnet-streamlit/).
105
  """)
 
 
 
 
106
 
107
  st.markdown("""
108
  📌 The official implementation of the paper with all experiments can be found at [CartNet GitHub Repository](https://huggingface.co/spaces/alexsoleg/cartnet-demo/tree/main).
@@ -128,3 +145,5 @@ def main():
128
 
129
  if __name__ == "__main__":
130
  main()
 
 
 
8
  from process import process_data
9
  from utils import radius_graph_pbc
10
  import gc
11
+ from io import BytesIO, StringIO
12
 
13
  MEAN_TEMP = torch.tensor(192.1785) #training temp mean
14
  STD_TEMP = torch.tensor(81.2135) #training temp std
 
17
  @torch.no_grad()
18
  def main():
19
  model = create_model()
20
+ st.title("CartNet Thermal Ellipsoid Prediction")
21
  st.image('fig/pipeline.png')
22
 
23
  st.markdown("""
 
25
  """)
26
 
27
  uploaded_file = st.file_uploader("Upload a CIF file", type=["cif"], accept_multiple_files=False)
28
+
29
  if uploaded_file is not None:
30
  try:
 
 
31
  filename = str(uploaded_file.name)
32
+ file = BytesIO(uploaded_file.getbuffer())
33
+ cif = ReadCif(file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ if len(cif.keys())>1:
36
+ st.warning("⚠️ **Warning**: Found " + str(len(cif.keys())) + " blocks in the CIF file. We will process all of them and export as separate CIF files.")
 
 
 
 
37
 
38
+ st.markdown(f"### CIF file: {filename}")
39
+ for key in cif.keys():
40
+ st.markdown(f"### Block: {key}")
41
+ try:
42
+ block = "data_"+str(key)+"\n"+ cif[key].printsection()
43
+ atoms = read(StringIO(block), format="cif")
44
+
45
+ if len(atoms.positions) > 1000:
46
+ st.error("""
47
+ ⚠️ **Warning**: The structure is too large. Please upload a smaller one or use the [local implementation of CartNet Web App](https://github.com/alexsoleg/cartnet-streamlit/).
48
+ """)
49
+ continue
50
+
51
+ cif_data = cif[key]
52
+ if "_diffrn_ambient_temperature" in cif_data.keys():
53
+ temperature = float(cif_data["_diffrn_ambient_temperature"].split("(")[0])
54
+ elif "_cell_measurement_temperature" in cif_data.keys():
55
+ temperature = float(cif_data["_cell_measurement_temperature"].split("(")[0])
56
+ else:
57
+ st.error("Temperature not found in the CIF file. \
58
+ Please provide a temperature in the field _diffrn_ambient_temperature o in the field _cell_measurement_temperature from the CIF file.")
59
+ continue
60
+ st.success("CIF file successfully read.")
61
+ except Exception as e:
62
+ st.error(f"Error: {e}")
63
+ st.error(f"We couldn't find any structure for the block {key}. Please make sure the cif is compatible with ASE. If the error message is a blank line, it means ASE didn't found any coordinates.")
64
+
65
+ continue
66
+
67
+ data = Data()
68
+ data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
69
+
70
+ data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
71
+ data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
72
+ data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
73
+ data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
74
+
75
+ data.pbc = torch.tensor([True, True, True])
76
+ data.natoms = len(atoms)
77
+
78
+ del atoms
79
+ gc.collect()
80
+ batch = Batch.from_data_list([data])
81
+
82
+
83
+ edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
84
+ del batch
85
+ gc.collect()
86
+ data.cart_dist = torch.norm(edge_attr, dim=-1)
87
+ data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
88
+ data.edge_index = edge_index
89
+ data.non_H_mask = data.x != 1
90
+ delattr(data, "pbc")
91
+ delattr(data, "natoms")
92
+ batch = Batch.from_data_list([data])
93
+ del data, edge_index, edge_attr
94
+ gc.collect()
95
+
96
+ st.success("Graph successfully created.")
97
+
98
+ cif_file = process_data(batch, model)
99
+ st.success("ADPs successfully predicted.")
100
+
101
+ cif_file = BytesIO(cif_file.getvalue().encode())
102
+ st.download_button(
103
+ label="Download processed CIF file",
104
+ data=cif_file,
105
+ file_name=f"output_{key}.cif",
106
+ mime="text/plain",
107
+ key=f"download_button_{key}"
108
+ )
109
+
110
+ gc.collect()
111
 
 
 
 
 
 
 
 
 
 
112
  gc.collect()
113
  except Exception as e:
114
  st.error(f"An error occurred while reading the CIF file: {e}")
115
+
116
+ st.warning("""
117
  ⚠️ **Warning**: This online web application is designed for structures with up to 1000 atoms in the unit cell. For larger structures, please use the [local implementation of CartNet Web App](https://github.com/alexsoleg/cartnet-streamlit/).
118
  """)
119
+
120
+ st.warning("""
121
+ ⚠️ **Warning**: We use [ASE library](https://wiki.fysik.dtu.dk/ase/) for reading the cif files, please make sure it is compatible.
122
+ """)
123
 
124
  st.markdown("""
125
  📌 The official implementation of the paper with all experiments can be found at [CartNet GitHub Repository](https://huggingface.co/spaces/alexsoleg/cartnet-demo/tree/main).
 
145
 
146
  if __name__ == "__main__":
147
  main()
148
+
149
+
main_local.py CHANGED
@@ -17,7 +17,7 @@ STD_TEMP = torch.tensor(81.2135) #training temp std
17
  @torch.no_grad()
18
  def main():
19
  model = create_model()
20
- st.title("CartNet ADP Prediction")
21
  st.image('fig/pipeline.png')
22
 
23
  st.markdown("""
@@ -25,92 +25,86 @@ def main():
25
  """)
26
 
27
  uploaded_file = st.file_uploader("Upload a CIF file", type=["cif"], accept_multiple_files=False)
 
28
 
29
  if uploaded_file is not None:
30
  try:
31
  filename = str(uploaded_file.name)
32
  file = BytesIO(uploaded_file.getbuffer())
33
  cif = ReadCif(file)
34
- print(cif.keys())
35
- if len(cif.keys())>1:
36
- st.markdown("Found " + str(len(cif.keys())) + " blocks in the CIF file. We will process all of them and export as separate CIF files.")
37
- for key in cif.keys():
38
- print(key)
39
- # print(cif[key])
40
- block = "data_"+str(key)+"\n"+ cif[key].printsection()
41
- atoms = read(StringIO(block), format="cif")
42
- print("atoms")
43
- print(atoms)
44
- # atoms = read(atoms_2, format="cif")
45
- # with open(uploaded_file.name, "wb") as f:
46
- # f.write(uploaded_file.getbuffer())
47
- # filename = str(uploaded_file.name)
48
- # # Read the CIF file using ASE
49
- # atoms = read(filename, format="cif")
50
- # cif = ReadCif(filename)
51
- # print(cif.keys())
52
- # print(len(atoms))
53
- # # st.markdown(cif)
54
- # cif_data = cif
55
- # st.markdown(f"### CIF file: {filename}")
56
- # temperature = 100
57
- # if "_diffrn_ambient_temperature" in cif_data.keys():
58
- # temperature = float(cif_data["_diffrn_ambient_temperature"].split("(")[0])
59
- # elif "_cell_measurement_temperature" in cif_data.keys():
60
- # temperature = float(cif_data["_cell_measurement_temperature"].split("(")[0])
61
- # else:
62
- # raise ValueError("Temperature not found in the CIF file. \
63
- # Please provide a temperature in the field _diffrn_ambient_temperature o in the field _cell_measurement_temperature from the CIF file.")
64
- # st.success("CIF file successfully read.")
65
-
66
- # data = Data()
67
- # data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
68
-
69
- # data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
70
- # data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
71
- # data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
72
- # data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
73
-
74
- # data.pbc = torch.tensor([True, True, True])
75
- # data.natoms = len(atoms)
76
 
77
- # del atoms
78
- # gc.collect()
79
- # batch = Batch.from_data_list([data])
80
-
81
-
82
- # edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
83
- # del batch
84
- # gc.collect()
85
- # data.cart_dist = torch.norm(edge_attr, dim=-1)
86
- # data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
87
- # data.edge_index = edge_index
88
- # data.non_H_mask = data.x != 1
89
- # delattr(data, "pbc")
90
- # delattr(data, "natoms")
91
- # batch = Batch.from_data_list([data])
92
- # del data, edge_index, edge_attr
93
- # gc.collect()
94
-
95
- # st.success("Graph successfully created.")
96
-
97
- # process_data(batch, model)
98
- # st.success("ADPs successfully predicted.")
99
 
100
- # # Create a download button for the processed CIF file
101
- # with open("output.cif", "r") as f:
102
- # cif_contents = f.read()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
- # st.download_button(
105
- # label="Download processed CIF file",
106
- # data=cif_contents,
107
- # file_name="output.cif",
108
- # mime="text/plain"
109
- # )
110
-
111
- # os.remove("output.cif")
112
- # os.remove(filename)
113
- # gc.collect()
114
  except Exception as e:
115
  st.error(f"An error occurred while reading the CIF file: {e}")
116
 
@@ -119,6 +113,10 @@ def main():
119
  📌 The official implementation of the paper with all experiments can be found at [CartNet GitHub Repository](https://github.com/imatge-upc/CartNet).
120
  """)
121
 
 
 
 
 
122
  st.markdown("""
123
  ### How to cite
124
 
 
17
  @torch.no_grad()
18
  def main():
19
  model = create_model()
20
+ st.title("CartNet Thermal Ellipsoid Prediction")
21
  st.image('fig/pipeline.png')
22
 
23
  st.markdown("""
 
25
  """)
26
 
27
  uploaded_file = st.file_uploader("Upload a CIF file", type=["cif"], accept_multiple_files=False)
28
+
29
 
30
  if uploaded_file is not None:
31
  try:
32
  filename = str(uploaded_file.name)
33
  file = BytesIO(uploaded_file.getbuffer())
34
  cif = ReadCif(file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
+ if len(cif.keys())>1:
37
+ st.warning("⚠️ **Warning**: Found " + str(len(cif.keys())) + " blocks in the CIF file. We will process all of them and export as separate CIF files.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
+ st.markdown(f"### CIF file: {filename}")
40
+ for key in cif.keys():
41
+ st.markdown(f"### Block: {key}")
42
+ try:
43
+ block = "data_"+str(key)+"\n"+ cif[key].printsection()
44
+ atoms = read(StringIO(block), format="cif")
45
+
46
+ cif_data = cif[key]
47
+ if "_diffrn_ambient_temperature" in cif_data.keys():
48
+ temperature = float(cif_data["_diffrn_ambient_temperature"].split("(")[0])
49
+ elif "_cell_measurement_temperature" in cif_data.keys():
50
+ temperature = float(cif_data["_cell_measurement_temperature"].split("(")[0])
51
+ else:
52
+ st.error("Temperature not found in the CIF file. \
53
+ Please provide a temperature in the field _diffrn_ambient_temperature o in the field _cell_measurement_temperature from the CIF file.")
54
+ continue
55
+ st.success("CIF file successfully read.")
56
+ except Exception as e:
57
+ st.error(f"Error: {e}")
58
+ st.error(f"We couldn't find any structure for the block {key}. Please make sure the cif is compatible with ASE. If the error message is a blank line, it means ASE didn't found any coordinates.")
59
+
60
+ continue
61
+
62
+ data = Data()
63
+ data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
64
+
65
+ data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
66
+ data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
67
+ data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
68
+ data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
69
+
70
+ data.pbc = torch.tensor([True, True, True])
71
+ data.natoms = len(atoms)
72
+
73
+ del atoms
74
+ gc.collect()
75
+ batch = Batch.from_data_list([data])
76
+
77
+
78
+ edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
79
+ del batch
80
+ gc.collect()
81
+ data.cart_dist = torch.norm(edge_attr, dim=-1)
82
+ data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
83
+ data.edge_index = edge_index
84
+ data.non_H_mask = data.x != 1
85
+ delattr(data, "pbc")
86
+ delattr(data, "natoms")
87
+ batch = Batch.from_data_list([data])
88
+ del data, edge_index, edge_attr
89
+ gc.collect()
90
+
91
+ st.success("Graph successfully created.")
92
+
93
+ cif_file = process_data(batch, model)
94
+ st.success("ADPs successfully predicted.")
95
+
96
+ cif_file = BytesIO(cif_file.getvalue().encode())
97
+ st.download_button(
98
+ label="Download processed CIF file",
99
+ data=cif_file,
100
+ file_name=f"output_{key}.cif",
101
+ mime="text/plain",
102
+ key=f"download_button_{key}"
103
+ )
104
+
105
+ gc.collect()
106
 
107
+ gc.collect()
 
 
 
 
 
 
 
 
 
108
  except Exception as e:
109
  st.error(f"An error occurred while reading the CIF file: {e}")
110
 
 
113
  📌 The official implementation of the paper with all experiments can be found at [CartNet GitHub Repository](https://github.com/imatge-upc/CartNet).
114
  """)
115
 
116
+ st.warning("""
117
+ ⚠️ **Warning**: We use [ASE library](https://wiki.fysik.dtu.dk/ase/) for reading the cif files, please make sure it is compatible.
118
+ """)
119
+
120
  st.markdown("""
121
  ### How to cite
122
 
process.py CHANGED
@@ -2,6 +2,7 @@ import torch
2
  from ase.io import write
3
  from ase import Atoms
4
  import gc
 
5
 
6
  @torch.no_grad()
7
  def process_data(batch, model, output_file="output.cif"):
@@ -35,11 +36,12 @@ def process_data(batch, model, output_file="output.cif"):
35
  # Convert positions to fractional coordinates
36
  fractional_positions = ase_atoms.get_scaled_positions()
37
 
38
- # Write to CIF file
39
- write(output_file, ase_atoms)
40
 
41
- with open(output_file, 'r') as file:
42
- lines = file.readlines()
 
 
 
43
 
44
  # Find the line where "loop_" appears and remove lines from there to the end
45
  for i, line in enumerate(lines):
@@ -47,54 +49,51 @@ def process_data(batch, model, output_file="output.cif"):
47
  lines = lines[:i]
48
  break
49
 
50
- # Write the modified lines to a new output file
51
- with open(output_file, 'w') as file:
52
- file.writelines(lines)
53
-
54
- # Manually append positions and ADPs to the CIF file
55
- with open(output_file, 'a') as cif_file:
56
-
57
- # Write temperature
58
- cif_file.write(f"\n_diffrn_ambient_temperature {temperature}\n")
59
- # Write atomic positions
60
- cif_file.write("\nloop_\n")
61
- cif_file.write("_atom_site_label\n")
62
- cif_file.write("_atom_site_type_symbol\n")
63
- cif_file.write("_atom_site_fract_x\n")
64
- cif_file.write("_atom_site_fract_y\n")
65
- cif_file.write("_atom_site_fract_z\n")
66
- cif_file.write("_atom_site_U_iso_or_equiv\n")
67
- cif_file.write("_atom_site_thermal_displace_type\n")
68
-
69
- element_count = {}
70
- for i, (atom_number, frac_pos) in enumerate(zip(atoms, fractional_positions)):
71
- element = ase_atoms[i].symbol
72
- assert atom_number == ase_atoms[i].number
73
- if element not in element_count:
74
- element_count[element] = 0
75
- element_count[element] += 1
76
- label = f"{element}{element_count[element]}"
77
- u_iso = torch.trace(adps[indices[i]]).mean() if element != 'H' else 0.01
78
- type = "Uani" if element != 'H' else "Uiso"
79
- cif_file.write(f"{label} {element} {frac_pos[0]} {frac_pos[1]} {frac_pos[2]} {u_iso} {type}\n")
80
 
81
- # Write ADPs
82
- cif_file.write("\nloop_\n")
83
- cif_file.write("_atom_site_aniso_label\n")
84
- cif_file.write("_atom_site_aniso_U_11\n")
85
- cif_file.write("_atom_site_aniso_U_22\n")
86
- cif_file.write("_atom_site_aniso_U_33\n")
87
- cif_file.write("_atom_site_aniso_U_23\n")
88
- cif_file.write("_atom_site_aniso_U_13\n")
89
- cif_file.write("_atom_site_aniso_U_12\n")
90
-
91
- element_count = {}
92
- for i, atom_number in enumerate(atoms):
93
- if atom_number == 1:
94
- continue
95
- element = ase_atoms[i].symbol
96
- if element not in element_count:
97
- element_count[element] = 0
98
- element_count[element] += 1
99
- label = f"{element}{element_count[element]}"
100
- cif_file.write(f"{label} {adps[indices[i],0,0]} {adps[indices[i],1,1]} {adps[indices[i],2,2]} {adps[indices[i],1,2]} {adps[indices[i],0,2]} {adps[indices[i],0,1]}\n")
 
 
2
  from ase.io import write
3
  from ase import Atoms
4
  import gc
5
+ from io import BytesIO, StringIO
6
 
7
  @torch.no_grad()
8
  def process_data(batch, model, output_file="output.cif"):
 
36
  # Convert positions to fractional coordinates
37
  fractional_positions = ase_atoms.get_scaled_positions()
38
 
 
 
39
 
40
+ # Instead of reading from file, get CIF content directly from ASE's write function
41
+ cif_content = BytesIO()
42
+ write(cif_content, ase_atoms, format='cif')
43
+ lines = cif_content.getvalue().decode('utf-8').splitlines(True)
44
+ cif_content.close()
45
 
46
  # Find the line where "loop_" appears and remove lines from there to the end
47
  for i, line in enumerate(lines):
 
49
  lines = lines[:i]
50
  break
51
 
52
+ # Use StringIO to build the CIF content
53
+ cif_file = StringIO()
54
+ cif_file.writelines(lines)
55
+ # Write temperature
56
+ cif_file.write(f"\n_diffrn_ambient_temperature {temperature}\n")
57
+ # Write atomic positions
58
+ cif_file.write("\nloop_\n")
59
+ cif_file.write("_atom_site_label\n")
60
+ cif_file.write("_atom_site_type_symbol\n")
61
+ cif_file.write("_atom_site_fract_x\n")
62
+ cif_file.write("_atom_site_fract_y\n")
63
+ cif_file.write("_atom_site_fract_z\n")
64
+ cif_file.write("_atom_site_U_iso_or_equiv\n")
65
+ cif_file.write("_atom_site_thermal_displace_type\n")
66
+
67
+ element_count = {}
68
+ for i, (atom_number, frac_pos) in enumerate(zip(atoms, fractional_positions)):
69
+ element = ase_atoms[i].symbol
70
+ assert atom_number == ase_atoms[i].number
71
+ if element not in element_count:
72
+ element_count[element] = 0
73
+ element_count[element] += 1
74
+ label = f"{element}{element_count[element]}"
75
+ u_iso = torch.trace(adps[indices[i]]).mean() if element != 'H' else 0.01
76
+ type = "Uani" if element != 'H' else "Uiso"
77
+ cif_file.write(f"{label} {element} {frac_pos[0]} {frac_pos[1]} {frac_pos[2]} {u_iso} {type}\n")
 
 
 
 
78
 
79
+ # Write ADPs
80
+ cif_file.write("\nloop_\n")
81
+ cif_file.write("_atom_site_aniso_label\n")
82
+ cif_file.write("_atom_site_aniso_U_11\n")
83
+ cif_file.write("_atom_site_aniso_U_22\n")
84
+ cif_file.write("_atom_site_aniso_U_33\n")
85
+ cif_file.write("_atom_site_aniso_U_23\n")
86
+ cif_file.write("_atom_site_aniso_U_13\n")
87
+ cif_file.write("_atom_site_aniso_U_12\n")
88
+
89
+ element_count = {}
90
+ for i, atom_number in enumerate(atoms):
91
+ if atom_number == 1:
92
+ continue
93
+ element = ase_atoms[i].symbol
94
+ if element not in element_count:
95
+ element_count[element] = 0
96
+ element_count[element] += 1
97
+ label = f"{element}{element_count[element]}"
98
+ cif_file.write(f"{label} {adps[indices[i],0,0]} {adps[indices[i],1,1]} {adps[indices[i],2,2]} {adps[indices[i],1,2]} {adps[indices[i],0,2]} {adps[indices[i],0,1]}\n")
99
+ return cif_file
utils.py CHANGED
@@ -264,7 +264,7 @@ def get_max_neighbors_mask(
264
  + torch.arange(len(index), device=device)
265
  - index_neighbor_offset_expand
266
  )
267
- print(index_sort_map.dtype, atom_distance.dtype)
268
  distance_sort.index_copy_(0, index_sort_map, atom_distance)
269
  distance_sort = distance_sort.view(num_atoms, max_num_neighbors)
270
 
 
264
  + torch.arange(len(index), device=device)
265
  - index_neighbor_offset_expand
266
  )
267
+
268
  distance_sort.index_copy_(0, index_sort_map, atom_distance)
269
  distance_sort = distance_sort.view(num_atoms, max_num_neighbors)
270