Spaces:
Sleeping
Sleeping
File size: 5,597 Bytes
f2b6066 744c6a1 f2b6066 744c6a1 fa790e2 f2b6066 744c6a1 0b6c862 744c6a1 fa790e2 744c6a1 fa790e2 58deabf fa790e2 58deabf f2b6066 58deabf fa790e2 744c6a1 fa790e2 58deabf fa790e2 58deabf fa790e2 58deabf fa790e2 58deabf fa790e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import streamlit as st
import torch
from torch_geometric.data import Data, Batch
from ase.io import write
from ase import Atoms
import gc
from io import BytesIO, StringIO
from utils import radius_graph_pbc
# Normalization statistics for the ambient temperature feature, computed on
# the training set (presumably Kelvin — confirm against training pipeline).
MEAN_TEMP = torch.tensor(192.1785) #training temp mean
STD_TEMP = torch.tensor(81.2135) #training temp std
def process_ase(atoms, temperature, model):
    """Build a periodic graph from an ASE Atoms object and predict its ADPs.

    Constructs a torch_geometric ``Data`` object with atomic numbers,
    positions, normalized temperature and cell, builds the radius graph
    under periodic boundary conditions, then delegates CIF generation to
    ``process_data``.

    Parameters
    ----------
    atoms : ase.Atoms
        Input structure (consumed; deleted after tensors are extracted).
    temperature : float
        Ambient temperature, normalized with the training-set statistics.
    model : callable
        ADP prediction model, forwarded to ``process_data``.

    Returns
    -------
    io.StringIO
        CIF file content produced by ``process_data``.
    """
    graph = Data()
    graph.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
    graph.pos = torch.tensor(atoms.positions, dtype=torch.float32)
    graph.temperature_og = torch.tensor([temperature], dtype=torch.float32)
    # Standardize temperature with the training-set mean/std.
    graph.temperature = (graph.temperature_og - MEAN_TEMP) / STD_TEMP
    graph.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
    graph.pbc = torch.tensor([True, True, True])
    graph.natoms = len(atoms)
    del atoms
    gc.collect()

    # radius_graph_pbc operates on a Batch; build a throwaway one just for
    # edge construction (cutoff 5.0, max 64 neighbors per atom).
    tmp_batch = Batch.from_data_list([graph])
    edge_index, _, _, edge_attr = radius_graph_pbc(tmp_batch, 5.0, 64)
    del tmp_batch
    gc.collect()

    # Cache per-edge distance and unit direction vectors on the graph.
    graph.cart_dist = torch.norm(edge_attr, dim=-1)
    graph.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
    graph.edge_index = edge_index
    graph.non_H_mask = graph.x != 1

    # pbc/natoms were only needed by radius_graph_pbc; drop them before
    # re-batching for the model.
    delattr(graph, "pbc")
    delattr(graph, "natoms")
    batch = Batch.from_data_list([graph])
    del graph, edge_index, edge_attr
    gc.collect()

    st.success("Graph successfully created.")
    cif_file = process_data(batch, model)
    st.success("ADPs successfully predicted.")
    return cif_file
def process_data(batch, model):
    """Predict anisotropic displacement parameters (ADPs) and emit a CIF.

    Runs ``model`` on the batched graph, converts the predicted Cartesian
    ADP matrices to the CIF (fractional) convention, and writes a CIF file
    with an atom-site loop (positions + equivalent isotropic U) and an
    anisotropic-ADP loop for non-hydrogen atoms.

    Parameters
    ----------
    batch : torch_geometric Batch
        Must carry ``x`` (atomic numbers), ``pos`` (Cartesian positions),
        ``cell`` (1x3x3 lattice), ``temperature_og`` and ``non_H_mask``,
        as built by ``process_ase``.
    model : callable
        Maps the batch to one 3x3 Cartesian ADP matrix per non-H atom,
        ordered consistently with ``non_H_mask``.

    Returns
    -------
    io.StringIO
        The CIF content (not rewound; caller should seek/getvalue).
    """
    atoms = batch.x.numpy().astype(int)    # atomic numbers
    positions = batch.pos.numpy()          # Cartesian positions
    cell = batch.cell.squeeze(0).numpy()   # 3x3 lattice matrix
    temperature = batch.temperature_og.numpy()[0]

    # Pure inference: no autograd graph needed (saves memory).
    with torch.no_grad():
        adps = model(batch)

    # Convert U_cart -> U_cif:  U_cif = N^-T M^-T U_cart M^-1 N^-1, where M
    # is the cell matrix and N normalizes to the CIF aniso convention.
    M = batch.cell.squeeze(0)
    N = torch.diag(torch.linalg.norm(torch.linalg.inv(M.transpose(-1, -2)).squeeze(0), dim=-1))
    M = torch.linalg.inv(M)
    N = torch.linalg.inv(N)
    adps = M.transpose(-1, -2) @ adps @ M
    adps = N.transpose(-1, -2) @ adps @ N
    del M, N
    gc.collect()

    # Map global atom index -> row of `adps` (which only covers non-H atoms).
    non_H_mask = batch.non_H_mask.numpy()
    non_h_indices = torch.arange(len(atoms))[non_H_mask].numpy()
    indices = {int(idx): row for row, idx in enumerate(non_h_indices)}

    # Let ASE produce the CIF header (cell parameters, symmetry), then
    # re-emit the atom-site loops ourselves.
    ase_atoms = Atoms(numbers=atoms, positions=positions, cell=cell, pbc=True)
    fractional_positions = ase_atoms.get_scaled_positions()
    cif_content = BytesIO()
    write(cif_content, ase_atoms, format='cif')
    lines = cif_content.getvalue().decode('utf-8').splitlines(True)
    cif_content.close()
    # Keep everything before the first "loop_" (header only).
    for i, line in enumerate(lines):
        if line.strip().startswith('loop_'):
            lines = lines[:i]
            break

    cif_file = StringIO()
    cif_file.writelines(lines)
    cif_file.write(f"\n_diffrn_ambient_temperature {temperature}\n")

    # Atom-site loop: label, element, fractional coords, U_iso/equiv, type.
    cif_file.write("\nloop_\n")
    cif_file.write("_atom_site_label\n")
    cif_file.write("_atom_site_type_symbol\n")
    cif_file.write("_atom_site_fract_x\n")
    cif_file.write("_atom_site_fract_y\n")
    cif_file.write("_atom_site_fract_z\n")
    cif_file.write("_atom_site_U_iso_or_equiv\n")
    cif_file.write("_atom_site_thermal_displace_type\n")
    element_count = {}
    labels_uiso = []  # labels whose predicted ADP was rejected (U_eq > 1)
    for i, (atom_number, frac_pos) in enumerate(zip(atoms, fractional_positions)):
        element = ase_atoms[i].symbol
        assert atom_number == ase_atoms[i].number
        element_count[element] = element_count.get(element, 0) + 1
        label = f"{element}{element_count[element]}"
        if element == 'H':
            # Hydrogens get no predicted ADP; use a tiny fixed isotropic U.
            u_iso = 0.0001
        else:
            # Equivalent isotropic U is the mean of the diagonal (trace/3).
            # BUG FIX: the original used torch.trace(...).mean(), but .mean()
            # of a 0-dim trace is a no-op, so U_eq came out 3x too large.
            u_iso = (torch.trace(adps[indices[i]]) / 3.0).item()
        if u_iso > 1:
            # Physically implausible prediction: fall back to a tiny U_iso
            # and exclude this atom from the anisotropic loop below.
            labels_uiso.append(label)
            u_iso = 0.0001
        # BUG FIX: the original tagged rejected atoms "Uani" (its `u_iso > 1`
        # test was dead code after the reset above) even though they get no
        # _atom_site_aniso row; tag them "Uiso" for a consistent CIF.
        adp_type = "Uiso" if (element == 'H' or label in labels_uiso) else "Uani"
        cif_file.write(f"{label} {element} {frac_pos[0]} {frac_pos[1]} {frac_pos[2]} {u_iso} {adp_type}\n")

    # Anisotropic-ADP loop (non-hydrogen atoms with accepted predictions).
    cif_file.write("\nloop_\n")
    cif_file.write("_atom_site_aniso_label\n")
    cif_file.write("_atom_site_aniso_U_11\n")
    cif_file.write("_atom_site_aniso_U_22\n")
    cif_file.write("_atom_site_aniso_U_33\n")
    cif_file.write("_atom_site_aniso_U_23\n")
    cif_file.write("_atom_site_aniso_U_13\n")
    cif_file.write("_atom_site_aniso_U_12\n")
    element_count = {}
    total_adps = 0
    for i, atom_number in enumerate(atoms):
        if atom_number == 1:  # skip hydrogens
            continue
        total_adps += 1
        # Re-derive the same label as in the first loop (H atoms do not
        # affect the counters of other elements, so labels stay in sync).
        element = ase_atoms[i].symbol
        element_count[element] = element_count.get(element, 0) + 1
        label = f"{element}{element_count[element]}"
        if label in labels_uiso:  # rejected prediction -> no aniso entry
            continue
        U = adps[indices[i]]
        cif_file.write(f"{label} {U[0, 0]} {U[1, 1]} {U[2, 2]} {U[1, 2]} {U[0, 2]} {U[0, 1]}\n")

    if labels_uiso:
        # Typo fix: "Succesfully" -> "Successfully" in the user-facing message.
        st.warning(f"Successfully predicted {100 * (total_adps - len(labels_uiso)) / total_adps:.2f} % of ADPs")
        st.warning(f"CartNet produced unexpected ADPs for the following atoms (will be ignored in the output file): \n {', '.join(labels_uiso)}")
    return cif_file
|