Àlex Solé commited on
Commit
6afeebe
·
1 Parent(s): 43fc1a3

Initial fix multipe structures

Browse files
Files changed (3) hide show
  1. .gitignore +2 -0
  2. main.py +9 -8
  3. main_local.py +77 -55
.gitignore CHANGED
@@ -1 +1,3 @@
1
  .DS_Store
 
 
 
1
  .DS_Store
2
+ *.pyc
3
+ *.cif
main.py CHANGED
@@ -40,14 +40,15 @@ def main():
40
  if key.lower().startswith("data"):
41
  st.markdown(cif_data[key])
42
  st.markdown(cif_data["_publ_section_comment"])
43
- if "_diffrn_ambient_temperature" in cif_data.keys():
44
- temperature = float(cif_data["_diffrn_ambient_temperature"].split("(")[0])
45
- elif "_cell_measurement_temperature" in cif_data.keys():
46
- temperature = float(cif_data["_cell_measurement_temperature"].split("(")[0])
47
- else:
48
- raise ValueError("Temperature not found in the CIF file. \
49
- Please provide a temperature in the field _diffrn_ambient_temperature o in the field _cell_measurement_temperature from the CIF file.")
50
- st.success("CIF file successfully read.")
 
51
 
52
  data = Data()
53
  data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
 
40
  if key.lower().startswith("data"):
41
  st.markdown(cif_data[key])
42
  st.markdown(cif_data["_publ_section_comment"])
43
+ temperature = 100
44
+ # if "_diffrn_ambient_temperature" in cif_data.keys():
45
+ # temperature = float(cif_data["_diffrn_ambient_temperature"].split("(")[0])
46
+ # elif "_cell_measurement_temperature" in cif_data.keys():
47
+ # temperature = float(cif_data["_cell_measurement_temperature"].split("(")[0])
48
+ # else:
49
+ # raise ValueError("Temperature not found in the CIF file. \
50
+ # Please provide a temperature in the field _diffrn_ambient_temperature o in the field _cell_measurement_temperature from the CIF file.")
51
+ # st.success("CIF file successfully read.")
52
 
53
  data = Data()
54
  data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
main_local.py CHANGED
@@ -7,6 +7,7 @@ import torch
7
  from models.master import create_model
8
  from process import process_data
9
  from utils import radius_graph_pbc
 
10
  import gc
11
 
12
  MEAN_TEMP = torch.tensor(192.1785) #training temp mean
@@ -24,71 +25,92 @@ def main():
24
  """)
25
 
26
  uploaded_file = st.file_uploader("Upload a CIF file", type=["cif"], accept_multiple_files=False)
27
- # uploaded_file = "ABABEM.cif"
28
  if uploaded_file is not None:
29
  try:
30
- with open(uploaded_file.name, "wb") as f:
31
- f.write(uploaded_file.getbuffer())
32
  filename = str(uploaded_file.name)
33
- # Read the CIF file using ASE
34
- atoms = read(filename, format="cif")
35
- cif = ReadCif(filename)
36
- cif_data = cif.first_block()
37
- if "_diffrn_ambient_temperature" in cif_data.keys():
38
- temperature = float(cif_data["_diffrn_ambient_temperature"])
39
- else:
40
- raise ValueError("Temperature not found in the CIF file. \
41
- Please provide a temperature in the field _diffrn_ambient_temperature from the CIF file.")
42
- st.success("CIF file successfully read.")
43
-
44
- data = Data()
45
- data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
48
- data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
49
- data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
50
- data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
51
 
52
- data.pbc = torch.tensor([True, True, True])
53
- data.natoms = len(atoms)
54
 
55
- del atoms
56
- gc.collect()
57
- batch = Batch.from_data_list([data])
58
 
59
 
60
- edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
61
- del batch
62
- gc.collect()
63
- data.cart_dist = torch.norm(edge_attr, dim=-1)
64
- data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
65
- data.edge_index = edge_index
66
- data.non_H_mask = data.x != 1
67
- delattr(data, "pbc")
68
- delattr(data, "natoms")
69
- batch = Batch.from_data_list([data])
70
- del data, edge_index, edge_attr
71
- gc.collect()
72
-
73
- st.success("Graph successfully created.")
74
-
75
- process_data(batch, model)
76
- st.success("ADPs successfully predicted.")
77
 
78
- # Create a download button for the processed CIF file
79
- with open("output.cif", "r") as f:
80
- cif_contents = f.read()
81
 
82
- st.download_button(
83
- label="Download processed CIF file",
84
- data=cif_contents,
85
- file_name="output.cif",
86
- mime="text/plain"
87
- )
88
-
89
- os.remove("output.cif")
90
- os.remove(filename)
91
- gc.collect()
92
  except Exception as e:
93
  st.error(f"An error occurred while reading the CIF file: {e}")
94
 
 
7
  from models.master import create_model
8
  from process import process_data
9
  from utils import radius_graph_pbc
10
+ from io import BytesIO, StringIO
11
  import gc
12
 
13
  MEAN_TEMP = torch.tensor(192.1785) #training temp mean
 
25
  """)
26
 
27
  uploaded_file = st.file_uploader("Upload a CIF file", type=["cif"], accept_multiple_files=False)
28
+
29
  if uploaded_file is not None:
30
  try:
 
 
31
  filename = str(uploaded_file.name)
32
+ file = BytesIO(uploaded_file.getbuffer())
33
+ cif = ReadCif(file)
34
+ print(cif.keys())
35
+ if len(cif.keys())>1:
36
+ st.markdown("Found " + str(len(cif.keys())) + " blocks in the CIF file. We will process all of them and export as separate CIF files.")
37
+ for key in cif.keys():
38
+ print(key)
39
+ # print(cif[key])
40
+ block = "data_"+str(key)+"\n"+ cif[key].printsection()
41
+ atoms = read(StringIO(block), format="cif")
42
+ print("atoms")
43
+ print(atoms)
44
+ # atoms = read(atoms_2, format="cif")
45
+ # with open(uploaded_file.name, "wb") as f:
46
+ # f.write(uploaded_file.getbuffer())
47
+ # filename = str(uploaded_file.name)
48
+ # # Read the CIF file using ASE
49
+ # atoms = read(filename, format="cif")
50
+ # cif = ReadCif(filename)
51
+ # print(cif.keys())
52
+ # print(len(atoms))
53
+ # # st.markdown(cif)
54
+ # cif_data = cif
55
+ # st.markdown(f"### CIF file: {filename}")
56
+ # temperature = 100
57
+ # if "_diffrn_ambient_temperature" in cif_data.keys():
58
+ # temperature = float(cif_data["_diffrn_ambient_temperature"].split("(")[0])
59
+ # elif "_cell_measurement_temperature" in cif_data.keys():
60
+ # temperature = float(cif_data["_cell_measurement_temperature"].split("(")[0])
61
+ # else:
62
+ # raise ValueError("Temperature not found in the CIF file. \
63
+ # Please provide a temperature in the field _diffrn_ambient_temperature o in the field _cell_measurement_temperature from the CIF file.")
64
+ # st.success("CIF file successfully read.")
65
+
66
+ # data = Data()
67
+ # data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
68
 
69
+ # data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
70
+ # data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
71
+ # data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
72
+ # data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
73
 
74
+ # data.pbc = torch.tensor([True, True, True])
75
+ # data.natoms = len(atoms)
76
 
77
+ # del atoms
78
+ # gc.collect()
79
+ # batch = Batch.from_data_list([data])
80
 
81
 
82
+ # edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
83
+ # del batch
84
+ # gc.collect()
85
+ # data.cart_dist = torch.norm(edge_attr, dim=-1)
86
+ # data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
87
+ # data.edge_index = edge_index
88
+ # data.non_H_mask = data.x != 1
89
+ # delattr(data, "pbc")
90
+ # delattr(data, "natoms")
91
+ # batch = Batch.from_data_list([data])
92
+ # del data, edge_index, edge_attr
93
+ # gc.collect()
94
+
95
+ # st.success("Graph successfully created.")
96
+
97
+ # process_data(batch, model)
98
+ # st.success("ADPs successfully predicted.")
99
 
100
+ # # Create a download button for the processed CIF file
101
+ # with open("output.cif", "r") as f:
102
+ # cif_contents = f.read()
103
 
104
+ # st.download_button(
105
+ # label="Download processed CIF file",
106
+ # data=cif_contents,
107
+ # file_name="output.cif",
108
+ # mime="text/plain"
109
+ # )
110
+
111
+ # os.remove("output.cif")
112
+ # os.remove(filename)
113
+ # gc.collect()
114
  except Exception as e:
115
  st.error(f"An error occurred while reading the CIF file: {e}")
116