Spaces:
Sleeping
Sleeping
Àlex Solé
fixed output of wrong predictions, limiting the max volume possible for a prediction
58deabf
import streamlit as st | |
import torch | |
from torch_geometric.data import Data, Batch | |
from ase.io import write | |
from ase import Atoms | |
import gc | |
from io import BytesIO, StringIO | |
from utils import radius_graph_pbc | |
# Normalization statistics for the temperature input, taken from the training set.
# `process_ase` standardizes the requested temperature with these before inference.
MEAN_TEMP = torch.tensor(192.1785)  # training temperature mean
STD_TEMP = torch.tensor(81.2135)  # training temperature std
def process_ase(atoms, temperature, model):
    """Convert an ASE structure into a periodic graph and predict its ADPs.

    Builds a ``torch_geometric`` graph from ``atoms``, standardizes the
    temperature with the training-set statistics (MEAN_TEMP / STD_TEMP),
    constructs the periodic-boundary radius graph, and hands the result to
    ``process_data`` for ADP prediction and CIF generation.

    Args:
        atoms: ASE ``Atoms`` object with a periodic cell (consumed; deleted
            after its arrays are copied into tensors).
        temperature: experiment temperature (same units as the training data).
        model: callable ADP predictor, forwarded to ``process_data``.

    Returns:
        The CIF buffer produced by ``process_data``.
    """
    graph = Data()
    graph.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
    graph.pos = torch.tensor(atoms.positions, dtype=torch.float32)
    graph.temperature_og = torch.tensor([temperature], dtype=torch.float32)
    # standardize with the training-set mean/std
    graph.temperature = (graph.temperature_og - MEAN_TEMP) / STD_TEMP
    graph.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
    graph.pbc = torch.tensor([True, True, True])
    graph.natoms = len(atoms)
    # the ASE object is no longer needed; release it eagerly
    del atoms
    gc.collect()

    # radius_graph_pbc needs a Batch; build a throwaway one just for edge creation
    # (cutoff 5.0, cap of 64 — per radius_graph_pbc's signature)
    edge_batch = Batch.from_data_list([graph])
    edge_index, _, _, edge_attr = radius_graph_pbc(edge_batch, 5.0, 64)
    del edge_batch
    gc.collect()

    graph.cart_dist = torch.norm(edge_attr, dim=-1)
    graph.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
    graph.edge_index = edge_index
    graph.non_H_mask = graph.x != 1  # True for every non-hydrogen site

    # pbc / natoms were only required by radius_graph_pbc; drop them before
    # the final batching so they don't travel with the model input
    delattr(graph, "pbc")
    delattr(graph, "natoms")

    model_batch = Batch.from_data_list([graph])
    del graph, edge_index, edge_attr
    gc.collect()

    st.success("Graph successfully created.")
    cif_file = process_data(model_batch, model)
    st.success("ADPs successfully predicted.")
    return cif_file
def _next_label(element, counts):
    """Return the next sequential site label for *element* (e.g. 'C1', 'C2'),
    updating the running per-element tally in *counts* in place."""
    counts[element] = counts.get(element, 0) + 1
    return f"{element}{counts[element]}"


def process_data(batch, model):
    """Run the ADP model on a graph batch and write the result as a CIF.

    Predicts Cartesian ADPs with ``model``, converts them to the CIF (Ucif)
    basis, and assembles a CIF with an isotropic site loop plus an
    anisotropic ADP loop for non-hydrogen atoms. Predictions whose U value
    exceeds 1 are treated as unphysical: they are written as Uiso with a
    placeholder value and excluded from the aniso loop, and a Streamlit
    warning lists them.

    Args:
        batch: single-structure batch produced by ``process_ase`` (provides
            x, pos, cell, temperature_og, non_H_mask).
        model: callable mapping the batch to per-atom 3x3 ADP tensors for
            the non-hydrogen atoms, in the order given by ``non_H_mask``.

    Returns:
        io.StringIO holding the generated CIF text.
    """
    atoms = batch.x.numpy().astype(int)          # atomic numbers
    positions = batch.pos.numpy()                # Cartesian positions
    cell = batch.cell.squeeze(0).numpy()         # 3x3 cell matrix
    temperature = batch.temperature_og.numpy()[0]
    adps = model(batch)

    # Convert Ucart to Ucif: similarity-transform by the inverse cell (M) and
    # normalize by the reciprocal-axis lengths (N).
    M = batch.cell.squeeze(0)
    N = torch.diag(torch.linalg.norm(torch.linalg.inv(M.transpose(-1, -2)).squeeze(0), dim=-1))
    M = torch.linalg.inv(M)
    N = torch.linalg.inv(N)
    adps = M.transpose(-1, -2) @ adps @ M
    adps = N.transpose(-1, -2) @ adps @ N
    del M, N
    gc.collect()

    # Map absolute atom index -> row in `adps` (model predicts non-H only).
    non_H_mask = batch.non_H_mask.numpy()
    non_H_indices = torch.arange(len(atoms))[non_H_mask].numpy()
    indices = {int(idx): row for row, idx in enumerate(non_H_indices)}

    ase_atoms = Atoms(numbers=atoms, positions=positions, cell=cell, pbc=True)
    fractional_positions = ase_atoms.get_scaled_positions()

    # Get CIF header directly from ASE's writer (no temp file on disk).
    cif_content = BytesIO()
    write(cif_content, ase_atoms, format='cif')
    lines = cif_content.getvalue().decode('utf-8').splitlines(True)
    cif_content.close()

    # Keep only the header: drop everything from ASE's first "loop_" onward,
    # since we emit our own site and ADP loops below.
    for i, line in enumerate(lines):
        if line.strip().startswith('loop_'):
            lines = lines[:i]
            break

    cif_file = StringIO()
    cif_file.writelines(lines)

    cif_file.write(f"\n_diffrn_ambient_temperature {temperature}\n")

    # --- atomic site loop ---
    cif_file.write("\nloop_\n")
    cif_file.write("_atom_site_label\n")
    cif_file.write("_atom_site_type_symbol\n")
    cif_file.write("_atom_site_fract_x\n")
    cif_file.write("_atom_site_fract_y\n")
    cif_file.write("_atom_site_fract_z\n")
    cif_file.write("_atom_site_U_iso_or_equiv\n")
    cif_file.write("_atom_site_thermal_displace_type\n")
    element_count = {}
    labels_uiso = []  # labels of rejected (unphysical) predictions
    for i, (atom_number, frac_pos) in enumerate(zip(atoms, fractional_positions)):
        element = ase_atoms[i].symbol
        assert atom_number == ase_atoms[i].number
        label = _next_label(element, element_count)
        if element == 'H':
            u_iso = 0.0001  # hydrogens get a fixed placeholder Uiso
        else:
            # trace of a 3x3 is already a 0-dim scalar (the original's .mean()
            # was a no-op and has been removed — same value).
            # NOTE(review): CIF defines U_iso_or_equiv as Ueq = trace/3; the
            # full trace is written here — confirm this is intended.
            u_iso = torch.trace(adps[indices[i]])
        rejected = u_iso > 1  # unphysically large prediction
        if rejected:
            labels_uiso.append(label)
            u_iso = 0.0001
        # Rejected atoms have no aniso entry below, so they must be Uiso
        # (previously they kept "Uani", yielding an inconsistent CIF).
        adp_type = "Uiso" if (element == 'H' or rejected) else "Uani"
        cif_file.write(f"{label} {element} {frac_pos[0]} {frac_pos[1]} {frac_pos[2]} {u_iso} {adp_type}\n")

    # --- anisotropic ADP loop (non-H, accepted predictions only) ---
    cif_file.write("\nloop_\n")
    cif_file.write("_atom_site_aniso_label\n")
    cif_file.write("_atom_site_aniso_U_11\n")
    cif_file.write("_atom_site_aniso_U_22\n")
    cif_file.write("_atom_site_aniso_U_33\n")
    cif_file.write("_atom_site_aniso_U_23\n")
    cif_file.write("_atom_site_aniso_U_13\n")
    cif_file.write("_atom_site_aniso_U_12\n")
    element_count = {}
    total_adps = 0
    for i, atom_number in enumerate(atoms):
        if atom_number == 1:  # skip hydrogens
            continue
        total_adps += 1
        # Recount labels so they match the site loop above exactly.
        label = _next_label(ase_atoms[i].symbol, element_count)
        if label in labels_uiso:
            continue
        U = adps[indices[i]]
        cif_file.write(f"{label} {U[0, 0]} {U[1, 1]} {U[2, 2]} {U[1, 2]} {U[0, 2]} {U[0, 1]}\n")

    if len(labels_uiso) > 0:
        st.warning(f"Successfully predicted {100 * (total_adps - len(labels_uiso)) / total_adps:.2f} % of ADPs")
        st.warning(f"CartNet produced unexpected ADPs for the following atoms (will be ignored in the output file): \n {', '.join(labels_uiso)}")
    return cif_file