# cartnet-demo / process.py
# Àlex Solé
# fixed output of wrong predictions, limiting the max volume possible for a prediction
# 58deabf
# raw
# history blame
# 5.6 kB
import streamlit as st
import torch
from torch_geometric.data import Data, Batch
from ase.io import write
from ase import Atoms
import gc
from io import BytesIO, StringIO
from utils import radius_graph_pbc
# Training-set temperature statistics; process_ase standardises the input
# temperature as (T - MEAN_TEMP) / STD_TEMP before it reaches the model.
MEAN_TEMP, STD_TEMP = torch.tensor(192.1785), torch.tensor(81.2135)
def process_ase(atoms, temperature, model):
    """Build the CartNet input graph for ``atoms`` and return the predicted CIF.

    Args:
        atoms: ``ase.Atoms`` structure; it is treated as periodic along all
            three cell vectors.
        temperature: ambient temperature of the structure. It is kept as
            ``temperature_og`` and also standardised with ``MEAN_TEMP`` /
            ``STD_TEMP`` into ``temperature``.
        model: CartNet model, handed to ``process_data`` unchanged.

    Returns:
        The ``io.StringIO`` CIF buffer produced by ``process_data``.
    """
    raw_temperature = torch.tensor([temperature], dtype=torch.float32)
    graph = Data(
        x=torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32),
        pos=torch.tensor(atoms.positions, dtype=torch.float32),
    )
    graph.temperature_og = raw_temperature
    graph.temperature = (raw_temperature - MEAN_TEMP) / STD_TEMP
    # Leading dimension of size 1: a single structure per batch.
    graph.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
    # `pbc` and `natoms` are only set for the neighbour search below and are
    # removed again before the graph is handed to the model.
    graph.pbc = torch.tensor([True, True, True])
    graph.natoms = len(atoms)
    # Release this function's reference to the ASE object (the caller may
    # still hold its own).
    del atoms
    gc.collect()

    # NOTE(review): 5.0 / 64 are presumably the cutoff radius and maximum
    # neighbour count of utils.radius_graph_pbc - confirm against its signature.
    search_batch = Batch.from_data_list([graph])
    edge_index, _, _, edge_vectors = radius_graph_pbc(search_batch, 5.0, 64)
    del search_batch
    gc.collect()

    # Edge features: length and unit direction of each Cartesian edge vector.
    graph.cart_dist = torch.norm(edge_vectors, dim=-1)
    graph.cart_dir = torch.nn.functional.normalize(edge_vectors, dim=-1)
    graph.edge_index = edge_index
    graph.non_H_mask = graph.x != 1  # atomic number 1 is hydrogen
    for helper_attr in ("pbc", "natoms"):
        delattr(graph, helper_attr)

    model_batch = Batch.from_data_list([graph])
    del graph, edge_index, edge_vectors
    gc.collect()
    st.success("Graph successfully created.")

    predicted_cif = process_data(model_batch, model)
    st.success("ADPs successfully predicted.")
    return predicted_cif
# Largest accepted trace of a predicted U_cif tensor; larger (or non-finite)
# predictions are treated as failures and the atom falls back to Uiso.
_MAX_ADP_TRACE = 1.0

# Placeholder isotropic U written for hydrogens and rejected predictions.
_FALLBACK_UISO = 0.0001

# Column headers of the two CIF loops written by process_data.
_ATOM_SITE_TAGS = (
    "_atom_site_label\n",
    "_atom_site_type_symbol\n",
    "_atom_site_fract_x\n",
    "_atom_site_fract_y\n",
    "_atom_site_fract_z\n",
    "_atom_site_U_iso_or_equiv\n",
    "_atom_site_thermal_displace_type\n",
)
_ATOM_SITE_ANISO_TAGS = (
    "_atom_site_aniso_label\n",
    "_atom_site_aniso_U_11\n",
    "_atom_site_aniso_U_22\n",
    "_atom_site_aniso_U_33\n",
    "_atom_site_aniso_U_23\n",
    "_atom_site_aniso_U_13\n",
    "_atom_site_aniso_U_12\n",
)


def _ucart_to_ucif(u_cart, cell):
    """Convert Cartesian ADP tensors to the CIF ``U^ij`` convention.

    ``cell`` is a 3x3 matrix whose rows are the lattice vectors. With
    ``M = cell`` this computes ``U* = M^-T U_cart M^-1`` and then
    ``U_cif = N^-1 U* N^-1`` with ``N = diag(|a*|, |b*|, |c*|)``.
    ``u_cart`` may carry leading batch dimensions (matmul broadcasts).
    """
    inv_cell = torch.linalg.inv(cell)
    inv_cell_t = inv_cell.transpose(-1, -2)
    # Rows of inv(cell)^T are the reciprocal lattice vectors a*, b*, c*.
    inv_n = torch.diag(1.0 / torch.linalg.norm(inv_cell_t, dim=-1))
    u_star = inv_cell_t @ u_cart @ inv_cell
    return inv_n @ u_star @ inv_n


def _atom_labels(symbols):
    """Label each site as its symbol plus a 1-based per-element count (C1, C2, H1, ...)."""
    seen = {}
    labels = []
    for symbol in symbols:
        seen[symbol] = seen.get(symbol, 0) + 1
        labels.append(f"{symbol}{seen[symbol]}")
    return labels


def _ase_cif_header(ase_atoms):
    """Return ASE's CIF text for ``ase_atoms`` as lines, cut before the first ``loop_``.

    process_data writes its own atom-site loops after this header.
    """
    with BytesIO() as buffer:
        write(buffer, ase_atoms, format='cif')
        lines = buffer.getvalue().decode('utf-8').splitlines(True)
    for i, line in enumerate(lines):
        if line.strip().startswith('loop_'):
            return lines[:i]
    return lines


def process_data(batch, model):
    """Predict ADPs for a single-structure graph batch and render them as CIF.

    Args:
        batch: batch built by ``process_ase``. This function reads ``x``
            (atomic numbers), ``pos``, ``cell`` (1x3x3), ``temperature_og`` and
            ``non_H_mask`` from it.
        model: called once as ``model(batch)``. Its output is treated as
            Cartesian ADPs, one 3x3 tensor per non-hydrogen atom in atom order.
            This layout is inferred from the indexing below; confirm it
            against the model.

    Returns:
        An ``io.StringIO`` with the CIF text. It contains:

        * ASE's header up to its first ``loop_``;
        * ``_diffrn_ambient_temperature``;
        * an ``_atom_site`` loop;
        * an ``_atom_site_aniso`` loop.

        Hydrogens are written as Uiso ``_FALLBACK_UISO`` with no aniso
        record. So are atoms whose predicted U_cif has a non-finite trace or
        a trace above ``_MAX_ADP_TRACE``; those rejected atoms are also
        reported through ``st.warning``. For accepted atoms, U_equiv is
        written as trace(U_cart) / 3.
    """
    numbers = batch.x.numpy().astype(int)
    positions = batch.pos.numpy()
    cell = batch.cell.squeeze(0)
    temperature = batch.temperature_og.numpy()[0]
    heavy_mask = batch.non_H_mask.numpy()

    # Pure inference: skip autograd bookkeeping to keep memory down.
    with torch.no_grad():
        adps_cart = model(batch)
    adps_cif = _ucart_to_ucif(adps_cart, cell)

    ase_atoms = Atoms(numbers=numbers, positions=positions, cell=cell.numpy(), pbc=True)
    symbols = ase_atoms.get_chemical_symbols()
    labels = _atom_labels(symbols)
    fractional_positions = ase_atoms.get_scaled_positions()

    site_rows = []
    aniso_rows = []
    rejected = []
    adp_row = 0  # row of adps_cart / adps_cif belonging to the next non-H atom
    for label, symbol, frac, is_heavy in zip(labels, symbols, fractional_positions, heavy_mask):
        site = f"{label} {symbol} {frac[0]} {frac[1]} {frac[2]}"
        if not is_heavy:
            site_rows.append(f"{site} {_FALLBACK_UISO} Uiso\n")
            continue
        u_cif = adps_cif[adp_row]
        u_cart = adps_cart[adp_row]
        adp_row += 1
        trace = torch.trace(u_cif)
        if not torch.isfinite(trace) or trace > _MAX_ADP_TRACE:
            # Rejected prediction: isotropic placeholder, no aniso record, so
            # the site must be typed Uiso to keep the CIF consistent.
            rejected.append(label)
            site_rows.append(f"{site} {_FALLBACK_UISO} Uiso\n")
            continue
        # U_equiv is one third of the trace of the Cartesian tensor.
        u_eq = torch.diagonal(u_cart).mean()
        site_rows.append(f"{site} {u_eq} Uani\n")
        aniso_rows.append(
            f"{label} {u_cif[0, 0]} {u_cif[1, 1]} {u_cif[2, 2]} "
            f"{u_cif[1, 2]} {u_cif[0, 2]} {u_cif[0, 1]}\n"
        )
    n_heavy = adp_row

    cif_file = StringIO()
    cif_file.writelines(_ase_cif_header(ase_atoms))
    cif_file.write(f"\n_diffrn_ambient_temperature {temperature}\n")
    cif_file.write("\nloop_\n")
    cif_file.writelines(_ATOM_SITE_TAGS)
    cif_file.writelines(site_rows)
    cif_file.write("\nloop_\n")
    cif_file.writelines(_ATOM_SITE_ANISO_TAGS)
    cif_file.writelines(aniso_rows)

    # `rejected` is non-empty only if at least one non-H atom exists, so
    # n_heavy > 0 here.
    if rejected:
        st.warning(f"Successfully predicted {100*(n_heavy-len(rejected))/n_heavy:.2f} % of ADPs")
        st.warning(f"CartNet produced unexpected ADPs for the following atoms (will be ignored in the output file): \n {', '.join(rejected)}")
    return cif_file