import gc
import math
from io import BytesIO, StringIO

import streamlit as st
import torch
from torch_geometric.data import Data, Batch
from ase.io import write
from ase import Atoms

from utils import radius_graph_pbc

# Standardization statistics of the temperature input seen during training.
MEAN_TEMP = torch.tensor(192.1785)  # training temp mean
STD_TEMP = torch.tensor(81.2135)  # training temp std

# U value written for sites without a usable prediction (hydrogens and
# rejected predictions). Units follow the model output -- presumably Å^2.
PLACEHOLDER_U = 0.0001
# Predictions whose equivalent isotropic U exceeds this value are treated as
# unphysical and replaced by PLACEHOLDER_U (reported to the user).
MAX_PLAUSIBLE_U = 1.0

_ATOM_SITE_TAGS = (
    "_atom_site_label",
    "_atom_site_type_symbol",
    "_atom_site_fract_x",
    "_atom_site_fract_y",
    "_atom_site_fract_z",
    "_atom_site_U_iso_or_equiv",
    "_atom_site_thermal_displace_type",
)

_ATOM_SITE_ANISO_TAGS = (
    "_atom_site_aniso_label",
    "_atom_site_aniso_U_11",
    "_atom_site_aniso_U_22",
    "_atom_site_aniso_U_33",
    "_atom_site_aniso_U_23",
    "_atom_site_aniso_U_13",
    "_atom_site_aniso_U_12",
)


def process_ase(atoms, temperature, model):
    """Build a periodic graph from an ASE structure and predict its ADPs.

    Args:
        atoms: ``ase.Atoms`` structure; its cell is treated as periodic in all
            three directions.
        temperature: Temperature fed to the model (standardized with
            ``MEAN_TEMP``/``STD_TEMP``) and written to the CIF as
            ``_diffrn_ambient_temperature``.
        model: Callable taking a ``torch_geometric`` ``Batch`` and returning
            Cartesian ADPs (see ``process_data``).

    Returns:
        ``io.StringIO`` holding the CIF text (stream position at the end).
    """
    data = Data()
    data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
    data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
    data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
    data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
    data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
    data.pbc = torch.tensor([True, True, True])
    data.natoms = len(atoms)

    # Only drops this function's reference; the caller may still hold `atoms`.
    del atoms
    gc.collect()

    # radius_graph_pbc works on a batch, so build a throwaway one-graph batch.
    # The meaning of the 5.0 / 64 arguments is defined in utils.radius_graph_pbc
    # (presumably cutoff radius and neighbor cap -- confirm there).
    batch = Batch.from_data_list([data])
    edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
    del batch
    gc.collect()

    # The last output is used as per-edge vectors: split into length and
    # unit direction.
    data.cart_dist = torch.norm(edge_attr, dim=-1)
    data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
    data.edge_index = edge_index
    data.non_H_mask = data.x != 1
    # pbc / natoms were only needed for graph construction; drop them so they
    # are not carried into the model batch.
    delattr(data, "pbc")
    delattr(data, "natoms")

    batch = Batch.from_data_list([data])
    del data, edge_index, edge_attr
    gc.collect()
    st.success("Graph successfully created.")

    cif_file = process_data(batch, model)
    st.success("ADPs successfully predicted.")
    return cif_file


def _ucart_to_ucif(ucart, cell):
    """Convert Cartesian ADP tensors to the CIF U^ij convention.

    ``cell`` is the 3x3 lattice matrix ``A`` with lattice vectors as rows
    (ASE convention). With ``N = diag(|a*|, |b*|, |c*|)`` this returns
    ``N^-1 A^-T U_cart A^-1 N^-1`` for every 3x3 matrix in ``ucart``
    (batched matmul broadcasting over leading dimensions).
    """
    cell_inv = torch.linalg.inv(cell)
    # Rows of inv(A^T) are the reciprocal lattice vectors (without 2*pi).
    recip_lengths = torch.linalg.norm(torch.linalg.inv(cell.transpose(-1, -2)), dim=-1)
    n_inv = torch.linalg.inv(torch.diag(recip_lengths))
    ucif = cell_inv.transpose(-1, -2) @ ucart @ cell_inv
    return n_inv.transpose(-1, -2) @ ucif @ n_inv


def _site_labels(ase_atoms):
    """Return CIF site labels ``<symbol><running count per element>`` in atom order."""
    counts = {}
    labels = []
    for symbol in ase_atoms.get_chemical_symbols():
        counts[symbol] = counts.get(symbol, 0) + 1
        labels.append(f"{symbol}{counts[symbol]}")
    return labels


def _cif_header_lines(ase_atoms):
    """Return ASE's CIF output for ``ase_atoms`` up to (excluding) its first ``loop_``.

    This keeps ASE's data block / cell / symmetry header and drops every loop,
    which the caller rewrites itself.
    """
    with BytesIO() as buffer:
        write(buffer, ase_atoms, format="cif")
        lines = buffer.getvalue().decode("utf-8").splitlines(True)
    for i, line in enumerate(lines):
        if line.strip().startswith("loop_"):
            return lines[:i]
    return lines


def process_data(batch, model):
    """Run ``model`` on ``batch`` and serialize the structure plus ADPs as CIF.

    Args:
        batch: Single-structure ``torch_geometric`` ``Batch`` as built by
            ``process_ase`` (reads ``x``, ``pos``, ``cell``,
            ``temperature_og`` and ``non_H_mask``).
        model: Callable returning Cartesian ADPs, one 3x3 tensor per
            non-hydrogen atom (assumed from how the output is indexed below).

    Returns:
        ``io.StringIO`` with the CIF text; the stream position is left at the
        end, so read it with ``getvalue()`` (or ``seek(0)`` first).

    Hydrogens and rejected predictions (non-finite, or U_eq above
    ``MAX_PLAUSIBLE_U``) are written as ``Uiso`` with ``PLACEHOLDER_U`` and get
    no aniso entry; rejections are reported through ``st.warning``.
    """
    atoms = batch.x.numpy().astype(int)  # atomic numbers
    positions = batch.pos.numpy()  # Cartesian positions
    cell = batch.cell.squeeze(0).numpy()  # lattice matrix
    temperature = batch.temperature_og.numpy()[0]

    # Inference only: no autograd graph is needed for the predictions.
    with torch.no_grad():
        ucart = model(batch)
        adps = _ucart_to_ucif(ucart, batch.cell.squeeze(0))

    # The model output is indexed by position among the non-hydrogen atoms;
    # map each atom index to its output row.
    heavy_indices = torch.nonzero(batch.non_H_mask).flatten().tolist()
    adp_row = {atom_idx: row for row, atom_idx in enumerate(heavy_indices)}

    ase_atoms = Atoms(numbers=atoms, positions=positions, cell=cell, pbc=True)
    fractional_positions = ase_atoms.get_scaled_positions()
    symbols = ase_atoms.get_chemical_symbols()
    labels = _site_labels(ase_atoms)

    # Decide U value and displacement type for every site once, so the
    # atom_site loop and the aniso loop always agree.
    site_u = []
    site_type = []
    rejected = []  # labels of rejected predictions, in atom order
    for i, label in enumerate(labels):
        row = adp_row.get(i)
        if row is None:  # hydrogen: no prediction available
            site_u.append(PLACEHOLDER_U)
            site_type.append("Uiso")
            continue
        # U_eq is one third of the trace of the Cartesian tensor (U_cif's
        # trace is cell-dependent, so it must not be used here).
        u_eq = torch.diagonal(ucart[row]).mean().item()
        if not math.isfinite(u_eq) or u_eq > MAX_PLAUSIBLE_U:
            rejected.append(label)
            site_u.append(PLACEHOLDER_U)
            site_type.append("Uiso")
        else:
            site_u.append(u_eq)
            site_type.append("Uani")

    cif_file = StringIO()
    cif_file.writelines(_cif_header_lines(ase_atoms))
    cif_file.write(f"\n_diffrn_ambient_temperature {temperature}\n")

    # Atomic positions and isotropic / equivalent U values.
    cif_file.write("\nloop_\n")
    for tag in _ATOM_SITE_TAGS:
        cif_file.write(f"{tag}\n")
    for label, symbol, frac_pos, u_value, adp_type in zip(
        labels, symbols, fractional_positions, site_u, site_type
    ):
        cif_file.write(
            f"{label} {symbol} {frac_pos[0]} {frac_pos[1]} {frac_pos[2]} {u_value} {adp_type}\n"
        )

    # Anisotropic ADPs, only for sites declared Uani above.
    cif_file.write("\nloop_\n")
    for tag in _ATOM_SITE_ANISO_TAGS:
        cif_file.write(f"{tag}\n")
    for i, (label, adp_type) in enumerate(zip(labels, site_type)):
        if adp_type != "Uani":
            continue
        u = adps[adp_row[i]]
        cif_file.write(f"{label} {u[0, 0]} {u[1, 1]} {u[2, 2]} {u[1, 2]} {u[0, 2]} {u[0, 1]}\n")

    if rejected:
        total_adps = len(adp_row)
        st.warning(f"Successfully predicted {100*(total_adps-len(rejected))/total_adps:.2f} % of ADPs")
        st.warning(f"CartNet produced unexpected ADPs for the following atoms (will be ignored in the output file): \n {', '.join(rejected)}")
    return cif_file