Spaces:
Sleeping
Sleeping
Àlex Solé
committed on
Commit
·
fa790e2
1
Parent(s):
4e099cc
Fixed bug with multiple structures (CSD)
Browse files
- main.py +88 -69
- main_local.py +76 -78
- process.py +53 -54
- utils.py +1 -1
main.py
CHANGED
@@ -8,6 +8,7 @@ from models.master import create_model
|
|
8 |
from process import process_data
|
9 |
from utils import radius_graph_pbc
|
10 |
import gc
|
|
|
11 |
|
12 |
MEAN_TEMP = torch.tensor(192.1785) #training temp mean
|
13 |
STD_TEMP = torch.tensor(81.2135) #training temp std
|
@@ -16,7 +17,7 @@ STD_TEMP = torch.tensor(81.2135) #training temp std
|
|
16 |
@torch.no_grad()
|
17 |
def main():
|
18 |
model = create_model()
|
19 |
-
st.title("CartNet
|
20 |
st.image('fig/pipeline.png')
|
21 |
|
22 |
st.markdown("""
|
@@ -24,85 +25,101 @@ def main():
|
|
24 |
""")
|
25 |
|
26 |
uploaded_file = st.file_uploader("Upload a CIF file", type=["cif"], accept_multiple_files=False)
|
27 |
-
|
28 |
if uploaded_file is not None:
|
29 |
try:
|
30 |
-
with open(uploaded_file.name, "wb") as f:
|
31 |
-
f.write(uploaded_file.getbuffer())
|
32 |
filename = str(uploaded_file.name)
|
33 |
-
|
34 |
-
|
35 |
-
cif = ReadCif(filename)
|
36 |
-
cif_data = cif.first_block()
|
37 |
-
|
38 |
-
if "_diffrn_ambient_temperature" in cif_data.keys():
|
39 |
-
temperature = float(cif_data["_diffrn_ambient_temperature"].split("(")[0])
|
40 |
-
elif "_cell_measurement_temperature" in cif_data.keys():
|
41 |
-
temperature = float(cif_data["_cell_measurement_temperature"].split("(")[0])
|
42 |
-
else:
|
43 |
-
raise ValueError("Temperature not found in the CIF file. \
|
44 |
-
Please provide a temperature in the field _diffrn_ambient_temperature o in the field _cell_measurement_temperature from the CIF file.")
|
45 |
-
st.success("CIF file successfully read.")
|
46 |
-
|
47 |
-
data = Data()
|
48 |
-
data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
|
49 |
-
|
50 |
-
if len(atoms.positions) > 1000:
|
51 |
-
st.markdown("""
|
52 |
-
⚠️ **Warning**: The structure is too large. Please upload a smaller one or use the [local implementation of CartNet Web App](https://github.com/alexsoleg/cartnet-streamlit/).
|
53 |
-
""")
|
54 |
-
raise ValueError("Please provide a structure with less than 1000 atoms in the unit cell.")
|
55 |
-
|
56 |
-
data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
|
57 |
-
data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
|
58 |
-
data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
|
59 |
-
data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
|
60 |
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
del atoms
|
65 |
-
gc.collect()
|
66 |
-
batch = Batch.from_data_list([data])
|
67 |
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
-
st.download_button(
|
92 |
-
label="Download processed CIF file",
|
93 |
-
data=cif_contents,
|
94 |
-
file_name="output.cif",
|
95 |
-
mime="text/plain"
|
96 |
-
)
|
97 |
-
|
98 |
-
os.remove("output.cif")
|
99 |
-
os.remove(filename)
|
100 |
gc.collect()
|
101 |
except Exception as e:
|
102 |
st.error(f"An error occurred while reading the CIF file: {e}")
|
103 |
-
|
|
|
104 |
⚠️ **Warning**: This online web application is designed for structures with up to 1000 atoms in the unit cell. For larger structures, please use the [local implementation of CartNet Web App](https://github.com/alexsoleg/cartnet-streamlit/).
|
105 |
""")
|
|
|
|
|
|
|
|
|
106 |
|
107 |
st.markdown("""
|
108 |
📌 The official implementation of the paper with all experiments can be found at [CartNet GitHub Repository](https://huggingface.co/spaces/alexsoleg/cartnet-demo/tree/main).
|
@@ -128,3 +145,5 @@ def main():
|
|
128 |
|
129 |
if __name__ == "__main__":
|
130 |
main()
|
|
|
|
|
|
8 |
from process import process_data
|
9 |
from utils import radius_graph_pbc
|
10 |
import gc
|
11 |
+
from io import BytesIO, StringIO
|
12 |
|
13 |
MEAN_TEMP = torch.tensor(192.1785) #training temp mean
|
14 |
STD_TEMP = torch.tensor(81.2135) #training temp std
|
|
|
17 |
@torch.no_grad()
|
18 |
def main():
|
19 |
model = create_model()
|
20 |
+
st.title("CartNet Thermal Ellipsoid Prediction")
|
21 |
st.image('fig/pipeline.png')
|
22 |
|
23 |
st.markdown("""
|
|
|
25 |
""")
|
26 |
|
27 |
uploaded_file = st.file_uploader("Upload a CIF file", type=["cif"], accept_multiple_files=False)
|
28 |
+
|
29 |
if uploaded_file is not None:
|
30 |
try:
|
|
|
|
|
31 |
filename = str(uploaded_file.name)
|
32 |
+
file = BytesIO(uploaded_file.getbuffer())
|
33 |
+
cif = ReadCif(file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
+
if len(cif.keys())>1:
|
36 |
+
st.warning("⚠️ **Warning**: Found " + str(len(cif.keys())) + " blocks in the CIF file. We will process all of them and export as separate CIF files.")
|
|
|
|
|
|
|
|
|
37 |
|
38 |
+
st.markdown(f"### CIF file: {filename}")
|
39 |
+
for key in cif.keys():
|
40 |
+
st.markdown(f"### Block: {key}")
|
41 |
+
try:
|
42 |
+
block = "data_"+str(key)+"\n"+ cif[key].printsection()
|
43 |
+
atoms = read(StringIO(block), format="cif")
|
44 |
+
|
45 |
+
if len(atoms.positions) > 1000:
|
46 |
+
st.error("""
|
47 |
+
⚠️ **Warning**: The structure is too large. Please upload a smaller one or use the [local implementation of CartNet Web App](https://github.com/alexsoleg/cartnet-streamlit/).
|
48 |
+
""")
|
49 |
+
continue
|
50 |
+
|
51 |
+
cif_data = cif[key]
|
52 |
+
if "_diffrn_ambient_temperature" in cif_data.keys():
|
53 |
+
temperature = float(cif_data["_diffrn_ambient_temperature"].split("(")[0])
|
54 |
+
elif "_cell_measurement_temperature" in cif_data.keys():
|
55 |
+
temperature = float(cif_data["_cell_measurement_temperature"].split("(")[0])
|
56 |
+
else:
|
57 |
+
st.error("Temperature not found in the CIF file. \
|
58 |
+
Please provide a temperature in the field _diffrn_ambient_temperature o in the field _cell_measurement_temperature from the CIF file.")
|
59 |
+
continue
|
60 |
+
st.success("CIF file successfully read.")
|
61 |
+
except Exception as e:
|
62 |
+
st.error(f"Error: {e}")
|
63 |
+
st.error(f"We couldn't find any structure for the block {key}. Please make sure the cif is compatible with ASE. If the error message is a blank line, it means ASE didn't found any coordinates.")
|
64 |
+
|
65 |
+
continue
|
66 |
+
|
67 |
+
data = Data()
|
68 |
+
data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
|
69 |
+
|
70 |
+
data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
|
71 |
+
data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
|
72 |
+
data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
|
73 |
+
data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
|
74 |
+
|
75 |
+
data.pbc = torch.tensor([True, True, True])
|
76 |
+
data.natoms = len(atoms)
|
77 |
+
|
78 |
+
del atoms
|
79 |
+
gc.collect()
|
80 |
+
batch = Batch.from_data_list([data])
|
81 |
+
|
82 |
+
|
83 |
+
edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
|
84 |
+
del batch
|
85 |
+
gc.collect()
|
86 |
+
data.cart_dist = torch.norm(edge_attr, dim=-1)
|
87 |
+
data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
|
88 |
+
data.edge_index = edge_index
|
89 |
+
data.non_H_mask = data.x != 1
|
90 |
+
delattr(data, "pbc")
|
91 |
+
delattr(data, "natoms")
|
92 |
+
batch = Batch.from_data_list([data])
|
93 |
+
del data, edge_index, edge_attr
|
94 |
+
gc.collect()
|
95 |
+
|
96 |
+
st.success("Graph successfully created.")
|
97 |
+
|
98 |
+
cif_file = process_data(batch, model)
|
99 |
+
st.success("ADPs successfully predicted.")
|
100 |
+
|
101 |
+
cif_file = BytesIO(cif_file.getvalue().encode())
|
102 |
+
st.download_button(
|
103 |
+
label="Download processed CIF file",
|
104 |
+
data=cif_file,
|
105 |
+
file_name=f"output_{key}.cif",
|
106 |
+
mime="text/plain",
|
107 |
+
key=f"download_button_{key}"
|
108 |
+
)
|
109 |
+
|
110 |
+
gc.collect()
|
111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
gc.collect()
|
113 |
except Exception as e:
|
114 |
st.error(f"An error occurred while reading the CIF file: {e}")
|
115 |
+
|
116 |
+
st.warning("""
|
117 |
⚠️ **Warning**: This online web application is designed for structures with up to 1000 atoms in the unit cell. For larger structures, please use the [local implementation of CartNet Web App](https://github.com/alexsoleg/cartnet-streamlit/).
|
118 |
""")
|
119 |
+
|
120 |
+
st.warning("""
|
121 |
+
⚠️ **Warning**: We use [ASE library](https://wiki.fysik.dtu.dk/ase/) for reading the cif files, please make sure it is compatible.
|
122 |
+
""")
|
123 |
|
124 |
st.markdown("""
|
125 |
📌 The official implementation of the paper with all experiments can be found at [CartNet GitHub Repository](https://huggingface.co/spaces/alexsoleg/cartnet-demo/tree/main).
|
|
|
145 |
|
146 |
if __name__ == "__main__":
|
147 |
main()
|
148 |
+
|
149 |
+
|
main_local.py
CHANGED
@@ -17,7 +17,7 @@ STD_TEMP = torch.tensor(81.2135) #training temp std
|
|
17 |
@torch.no_grad()
|
18 |
def main():
|
19 |
model = create_model()
|
20 |
-
st.title("CartNet
|
21 |
st.image('fig/pipeline.png')
|
22 |
|
23 |
st.markdown("""
|
@@ -25,92 +25,86 @@ def main():
|
|
25 |
""")
|
26 |
|
27 |
uploaded_file = st.file_uploader("Upload a CIF file", type=["cif"], accept_multiple_files=False)
|
|
|
28 |
|
29 |
if uploaded_file is not None:
|
30 |
try:
|
31 |
filename = str(uploaded_file.name)
|
32 |
file = BytesIO(uploaded_file.getbuffer())
|
33 |
cif = ReadCif(file)
|
34 |
-
print(cif.keys())
|
35 |
-
if len(cif.keys())>1:
|
36 |
-
st.markdown("Found " + str(len(cif.keys())) + " blocks in the CIF file. We will process all of them and export as separate CIF files.")
|
37 |
-
for key in cif.keys():
|
38 |
-
print(key)
|
39 |
-
# print(cif[key])
|
40 |
-
block = "data_"+str(key)+"\n"+ cif[key].printsection()
|
41 |
-
atoms = read(StringIO(block), format="cif")
|
42 |
-
print("atoms")
|
43 |
-
print(atoms)
|
44 |
-
# atoms = read(atoms_2, format="cif")
|
45 |
-
# with open(uploaded_file.name, "wb") as f:
|
46 |
-
# f.write(uploaded_file.getbuffer())
|
47 |
-
# filename = str(uploaded_file.name)
|
48 |
-
# # Read the CIF file using ASE
|
49 |
-
# atoms = read(filename, format="cif")
|
50 |
-
# cif = ReadCif(filename)
|
51 |
-
# print(cif.keys())
|
52 |
-
# print(len(atoms))
|
53 |
-
# # st.markdown(cif)
|
54 |
-
# cif_data = cif
|
55 |
-
# st.markdown(f"### CIF file: {filename}")
|
56 |
-
# temperature = 100
|
57 |
-
# if "_diffrn_ambient_temperature" in cif_data.keys():
|
58 |
-
# temperature = float(cif_data["_diffrn_ambient_temperature"].split("(")[0])
|
59 |
-
# elif "_cell_measurement_temperature" in cif_data.keys():
|
60 |
-
# temperature = float(cif_data["_cell_measurement_temperature"].split("(")[0])
|
61 |
-
# else:
|
62 |
-
# raise ValueError("Temperature not found in the CIF file. \
|
63 |
-
# Please provide a temperature in the field _diffrn_ambient_temperature o in the field _cell_measurement_temperature from the CIF file.")
|
64 |
-
# st.success("CIF file successfully read.")
|
65 |
-
|
66 |
-
# data = Data()
|
67 |
-
# data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
|
68 |
-
|
69 |
-
# data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
|
70 |
-
# data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
|
71 |
-
# data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
|
72 |
-
# data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
|
73 |
-
|
74 |
-
# data.pbc = torch.tensor([True, True, True])
|
75 |
-
# data.natoms = len(atoms)
|
76 |
|
77 |
-
|
78 |
-
|
79 |
-
# batch = Batch.from_data_list([data])
|
80 |
-
|
81 |
-
|
82 |
-
# edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
|
83 |
-
# del batch
|
84 |
-
# gc.collect()
|
85 |
-
# data.cart_dist = torch.norm(edge_attr, dim=-1)
|
86 |
-
# data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
|
87 |
-
# data.edge_index = edge_index
|
88 |
-
# data.non_H_mask = data.x != 1
|
89 |
-
# delattr(data, "pbc")
|
90 |
-
# delattr(data, "natoms")
|
91 |
-
# batch = Batch.from_data_list([data])
|
92 |
-
# del data, edge_index, edge_attr
|
93 |
-
# gc.collect()
|
94 |
-
|
95 |
-
# st.success("Graph successfully created.")
|
96 |
-
|
97 |
-
# process_data(batch, model)
|
98 |
-
# st.success("ADPs successfully predicted.")
|
99 |
|
100 |
-
|
101 |
-
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
|
104 |
-
|
105 |
-
# label="Download processed CIF file",
|
106 |
-
# data=cif_contents,
|
107 |
-
# file_name="output.cif",
|
108 |
-
# mime="text/plain"
|
109 |
-
# )
|
110 |
-
|
111 |
-
# os.remove("output.cif")
|
112 |
-
# os.remove(filename)
|
113 |
-
# gc.collect()
|
114 |
except Exception as e:
|
115 |
st.error(f"An error occurred while reading the CIF file: {e}")
|
116 |
|
@@ -119,6 +113,10 @@ def main():
|
|
119 |
📌 The official implementation of the paper with all experiments can be found at [CartNet GitHub Repository](https://github.com/imatge-upc/CartNet).
|
120 |
""")
|
121 |
|
|
|
|
|
|
|
|
|
122 |
st.markdown("""
|
123 |
### How to cite
|
124 |
|
|
|
17 |
@torch.no_grad()
|
18 |
def main():
|
19 |
model = create_model()
|
20 |
+
st.title("CartNet Thermal Ellipsoid Prediction")
|
21 |
st.image('fig/pipeline.png')
|
22 |
|
23 |
st.markdown("""
|
|
|
25 |
""")
|
26 |
|
27 |
uploaded_file = st.file_uploader("Upload a CIF file", type=["cif"], accept_multiple_files=False)
|
28 |
+
|
29 |
|
30 |
if uploaded_file is not None:
|
31 |
try:
|
32 |
filename = str(uploaded_file.name)
|
33 |
file = BytesIO(uploaded_file.getbuffer())
|
34 |
cif = ReadCif(file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
+
if len(cif.keys())>1:
|
37 |
+
st.warning("⚠️ **Warning**: Found " + str(len(cif.keys())) + " blocks in the CIF file. We will process all of them and export as separate CIF files.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
+
st.markdown(f"### CIF file: {filename}")
|
40 |
+
for key in cif.keys():
|
41 |
+
st.markdown(f"### Block: {key}")
|
42 |
+
try:
|
43 |
+
block = "data_"+str(key)+"\n"+ cif[key].printsection()
|
44 |
+
atoms = read(StringIO(block), format="cif")
|
45 |
+
|
46 |
+
cif_data = cif[key]
|
47 |
+
if "_diffrn_ambient_temperature" in cif_data.keys():
|
48 |
+
temperature = float(cif_data["_diffrn_ambient_temperature"].split("(")[0])
|
49 |
+
elif "_cell_measurement_temperature" in cif_data.keys():
|
50 |
+
temperature = float(cif_data["_cell_measurement_temperature"].split("(")[0])
|
51 |
+
else:
|
52 |
+
st.error("Temperature not found in the CIF file. \
|
53 |
+
Please provide a temperature in the field _diffrn_ambient_temperature o in the field _cell_measurement_temperature from the CIF file.")
|
54 |
+
continue
|
55 |
+
st.success("CIF file successfully read.")
|
56 |
+
except Exception as e:
|
57 |
+
st.error(f"Error: {e}")
|
58 |
+
st.error(f"We couldn't find any structure for the block {key}. Please make sure the cif is compatible with ASE. If the error message is a blank line, it means ASE didn't found any coordinates.")
|
59 |
+
|
60 |
+
continue
|
61 |
+
|
62 |
+
data = Data()
|
63 |
+
data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
|
64 |
+
|
65 |
+
data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
|
66 |
+
data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
|
67 |
+
data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
|
68 |
+
data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
|
69 |
+
|
70 |
+
data.pbc = torch.tensor([True, True, True])
|
71 |
+
data.natoms = len(atoms)
|
72 |
+
|
73 |
+
del atoms
|
74 |
+
gc.collect()
|
75 |
+
batch = Batch.from_data_list([data])
|
76 |
+
|
77 |
+
|
78 |
+
edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
|
79 |
+
del batch
|
80 |
+
gc.collect()
|
81 |
+
data.cart_dist = torch.norm(edge_attr, dim=-1)
|
82 |
+
data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
|
83 |
+
data.edge_index = edge_index
|
84 |
+
data.non_H_mask = data.x != 1
|
85 |
+
delattr(data, "pbc")
|
86 |
+
delattr(data, "natoms")
|
87 |
+
batch = Batch.from_data_list([data])
|
88 |
+
del data, edge_index, edge_attr
|
89 |
+
gc.collect()
|
90 |
+
|
91 |
+
st.success("Graph successfully created.")
|
92 |
+
|
93 |
+
cif_file = process_data(batch, model)
|
94 |
+
st.success("ADPs successfully predicted.")
|
95 |
+
|
96 |
+
cif_file = BytesIO(cif_file.getvalue().encode())
|
97 |
+
st.download_button(
|
98 |
+
label="Download processed CIF file",
|
99 |
+
data=cif_file,
|
100 |
+
file_name=f"output_{key}.cif",
|
101 |
+
mime="text/plain",
|
102 |
+
key=f"download_button_{key}"
|
103 |
+
)
|
104 |
+
|
105 |
+
gc.collect()
|
106 |
|
107 |
+
gc.collect()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
except Exception as e:
|
109 |
st.error(f"An error occurred while reading the CIF file: {e}")
|
110 |
|
|
|
113 |
📌 The official implementation of the paper with all experiments can be found at [CartNet GitHub Repository](https://github.com/imatge-upc/CartNet).
|
114 |
""")
|
115 |
|
116 |
+
st.warning("""
|
117 |
+
⚠️ **Warning**: We use [ASE library](https://wiki.fysik.dtu.dk/ase/) for reading the cif files, please make sure it is compatible.
|
118 |
+
""")
|
119 |
+
|
120 |
st.markdown("""
|
121 |
### How to cite
|
122 |
|
process.py
CHANGED
@@ -2,6 +2,7 @@ import torch
|
|
2 |
from ase.io import write
|
3 |
from ase import Atoms
|
4 |
import gc
|
|
|
5 |
|
6 |
@torch.no_grad()
|
7 |
def process_data(batch, model, output_file="output.cif"):
|
@@ -35,11 +36,12 @@ def process_data(batch, model, output_file="output.cif"):
|
|
35 |
# Convert positions to fractional coordinates
|
36 |
fractional_positions = ase_atoms.get_scaled_positions()
|
37 |
|
38 |
-
# Write to CIF file
|
39 |
-
write(output_file, ase_atoms)
|
40 |
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
43 |
|
44 |
# Find the line where "loop_" appears and remove lines from there to the end
|
45 |
for i, line in enumerate(lines):
|
@@ -47,54 +49,51 @@ def process_data(batch, model, output_file="output.cif"):
|
|
47 |
lines = lines[:i]
|
48 |
break
|
49 |
|
50 |
-
#
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
label = f"{element}{element_count[element]}"
|
77 |
-
u_iso = torch.trace(adps[indices[i]]).mean() if element != 'H' else 0.01
|
78 |
-
type = "Uani" if element != 'H' else "Uiso"
|
79 |
-
cif_file.write(f"{label} {element} {frac_pos[0]} {frac_pos[1]} {frac_pos[2]} {u_iso} {type}\n")
|
80 |
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
|
|
|
2 |
from ase.io import write
|
3 |
from ase import Atoms
|
4 |
import gc
|
5 |
+
from io import BytesIO, StringIO
|
6 |
|
7 |
@torch.no_grad()
|
8 |
def process_data(batch, model, output_file="output.cif"):
|
|
|
36 |
# Convert positions to fractional coordinates
|
37 |
fractional_positions = ase_atoms.get_scaled_positions()
|
38 |
|
|
|
|
|
39 |
|
40 |
+
# Instead of reading from file, get CIF content directly from ASE's write function
|
41 |
+
cif_content = BytesIO()
|
42 |
+
write(cif_content, ase_atoms, format='cif')
|
43 |
+
lines = cif_content.getvalue().decode('utf-8').splitlines(True)
|
44 |
+
cif_content.close()
|
45 |
|
46 |
# Find the line where "loop_" appears and remove lines from there to the end
|
47 |
for i, line in enumerate(lines):
|
|
|
49 |
lines = lines[:i]
|
50 |
break
|
51 |
|
52 |
+
# Use StringIO to build the CIF content
|
53 |
+
cif_file = StringIO()
|
54 |
+
cif_file.writelines(lines)
|
55 |
+
# Write temperature
|
56 |
+
cif_file.write(f"\n_diffrn_ambient_temperature {temperature}\n")
|
57 |
+
# Write atomic positions
|
58 |
+
cif_file.write("\nloop_\n")
|
59 |
+
cif_file.write("_atom_site_label\n")
|
60 |
+
cif_file.write("_atom_site_type_symbol\n")
|
61 |
+
cif_file.write("_atom_site_fract_x\n")
|
62 |
+
cif_file.write("_atom_site_fract_y\n")
|
63 |
+
cif_file.write("_atom_site_fract_z\n")
|
64 |
+
cif_file.write("_atom_site_U_iso_or_equiv\n")
|
65 |
+
cif_file.write("_atom_site_thermal_displace_type\n")
|
66 |
+
|
67 |
+
element_count = {}
|
68 |
+
for i, (atom_number, frac_pos) in enumerate(zip(atoms, fractional_positions)):
|
69 |
+
element = ase_atoms[i].symbol
|
70 |
+
assert atom_number == ase_atoms[i].number
|
71 |
+
if element not in element_count:
|
72 |
+
element_count[element] = 0
|
73 |
+
element_count[element] += 1
|
74 |
+
label = f"{element}{element_count[element]}"
|
75 |
+
u_iso = torch.trace(adps[indices[i]]).mean() if element != 'H' else 0.01
|
76 |
+
type = "Uani" if element != 'H' else "Uiso"
|
77 |
+
cif_file.write(f"{label} {element} {frac_pos[0]} {frac_pos[1]} {frac_pos[2]} {u_iso} {type}\n")
|
|
|
|
|
|
|
|
|
78 |
|
79 |
+
# Write ADPs
|
80 |
+
cif_file.write("\nloop_\n")
|
81 |
+
cif_file.write("_atom_site_aniso_label\n")
|
82 |
+
cif_file.write("_atom_site_aniso_U_11\n")
|
83 |
+
cif_file.write("_atom_site_aniso_U_22\n")
|
84 |
+
cif_file.write("_atom_site_aniso_U_33\n")
|
85 |
+
cif_file.write("_atom_site_aniso_U_23\n")
|
86 |
+
cif_file.write("_atom_site_aniso_U_13\n")
|
87 |
+
cif_file.write("_atom_site_aniso_U_12\n")
|
88 |
+
|
89 |
+
element_count = {}
|
90 |
+
for i, atom_number in enumerate(atoms):
|
91 |
+
if atom_number == 1:
|
92 |
+
continue
|
93 |
+
element = ase_atoms[i].symbol
|
94 |
+
if element not in element_count:
|
95 |
+
element_count[element] = 0
|
96 |
+
element_count[element] += 1
|
97 |
+
label = f"{element}{element_count[element]}"
|
98 |
+
cif_file.write(f"{label} {adps[indices[i],0,0]} {adps[indices[i],1,1]} {adps[indices[i],2,2]} {adps[indices[i],1,2]} {adps[indices[i],0,2]} {adps[indices[i],0,1]}\n")
|
99 |
+
return cif_file
|
utils.py
CHANGED
@@ -264,7 +264,7 @@ def get_max_neighbors_mask(
|
|
264 |
+ torch.arange(len(index), device=device)
|
265 |
- index_neighbor_offset_expand
|
266 |
)
|
267 |
-
|
268 |
distance_sort.index_copy_(0, index_sort_map, atom_distance)
|
269 |
distance_sort = distance_sort.view(num_atoms, max_num_neighbors)
|
270 |
|
|
|
264 |
+ torch.arange(len(index), device=device)
|
265 |
- index_neighbor_offset_expand
|
266 |
)
|
267 |
+
|
268 |
distance_sort.index_copy_(0, index_sort_map, atom_distance)
|
269 |
distance_sort = distance_sort.view(num_atoms, max_num_neighbors)
|
270 |
|