Spaces:
Sleeping
Sleeping
File size: 5,064 Bytes
f2b6066 744c6a1 f2b6066 744c6a1 fa790e2 f2b6066 744c6a1 fa790e2 744c6a1 fa790e2 f2b6066 fa790e2 744c6a1 fa790e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import streamlit as st
import torch
from torch_geometric.data import Data, Batch
from ase.io import write
from ase import Atoms
import gc
from io import BytesIO, StringIO
from utils import radius_graph_pbc
# Mean and standard deviation of the temperatures seen during training; the
# input temperature is standardised with these in process_ase.
MEAN_TEMP, STD_TEMP = torch.tensor(192.1785), torch.tensor(81.2135)
def process_ase(atoms, temperature, model):
    """Build a single-structure graph from an ASE structure and predict its ADPs.

    The structure becomes a one-graph torch_geometric ``Batch``. Edges come
    from ``radius_graph_pbc`` called with 5.0 and 64 (presumably the cutoff
    radius and a neighbour limit; confirm in utils). The batch is then handed
    to ``process_data``. Progress is reported with ``st.success``.

    Args:
        atoms: ASE ``Atoms`` object; its numbers, positions and cell are read.
        temperature: scalar temperature; stored raw as ``temperature_og`` and
            standardised with MEAN_TEMP / STD_TEMP as ``temperature``.
        model: forwarded unchanged to ``process_data``.

    Returns:
        The value returned by ``process_data``.
    """
    graph = Data()
    graph.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
    graph.pos = torch.tensor(atoms.positions, dtype=torch.float32)
    raw_temperature = torch.tensor([temperature], dtype=torch.float32)
    graph.temperature_og = raw_temperature
    graph.temperature = (raw_temperature - MEAN_TEMP) / STD_TEMP
    graph.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)
    # pbc / natoms are only needed for the neighbour search and are removed
    # again before the model batch is built.
    graph.pbc = torch.tensor([True, True, True])
    graph.natoms = len(atoms)
    del atoms
    gc.collect()

    neighbour_batch = Batch.from_data_list([graph])
    edge_index, _, _, edge_vectors = radius_graph_pbc(neighbour_batch, 5.0, 64)
    del neighbour_batch
    gc.collect()

    # Edge features: length and unit direction of each edge vector.
    graph.cart_dist = edge_vectors.norm(dim=-1)
    graph.cart_dir = torch.nn.functional.normalize(edge_vectors, dim=-1)
    graph.edge_index = edge_index
    graph.non_H_mask = graph.x != 1
    for key in ("pbc", "natoms"):
        delattr(graph, key)

    model_batch = Batch.from_data_list([graph])
    del graph, edge_index, edge_vectors
    gc.collect()

    st.success("Graph successfully created.")
    cif_file = process_data(model_batch, model)
    st.success("ADPs successfully predicted.")
    return cif_file
def _ucart_to_ucif(u_cart, cell):
    """Convert Cartesian ADP matrices to the CIF ``U^ij`` convention.

    Computes ``N^-1 A^-1 U A^-T N^-1``. Here ``A = cell^T`` holds the lattice
    vectors as columns, and ``N = diag(|a*|, |b*|, |c*|)`` holds the lengths
    of the reciprocal lattice vectors.

    Args:
        u_cart: tensor of shape ``(..., 3, 3)`` with Cartesian U matrices.
        cell: ``(3, 3)`` tensor whose rows are the lattice vectors (the
            layout of ASE's ``atoms.cell.array``).

    Returns:
        Tensor with the same shape as ``u_cart``.
    """
    # inv(cell)^T == inv(cell^T) == A^-1; its rows are a*, b*, c* (no 2*pi).
    recip = torch.linalg.inv(cell).transpose(-1, -2)
    n_inv = torch.diag(1.0 / torch.linalg.norm(recip, dim=-1))
    u_frac = recip @ u_cart @ recip.transpose(-1, -2)
    return n_inv @ u_frac @ n_inv


def _ase_cif_header(ase_atoms):
    """Return ASE's CIF output for *ase_atoms* as lines, up to its first ``loop_``."""
    with BytesIO() as buffer:
        write(buffer, ase_atoms, format='cif')
        lines = buffer.getvalue().decode('utf-8').splitlines(True)
    for i, line in enumerate(lines):
        if line.strip().startswith('loop_'):
            return lines[:i]
    return lines


def _site_labels(symbols):
    """Number atoms per element in order of appearance: C, H, C -> C1, H1, C2."""
    counts = {}
    labels = []
    for symbol in symbols:
        counts[symbol] = counts.get(symbol, 0) + 1
        labels.append(f"{symbol}{counts[symbol]}")
    return labels


def process_data(batch, model, output_file="output.cif"):
    """Predict ADPs for a one-structure batch and format the result as CIF text.

    Args:
        batch: single-graph ``Batch`` as built by ``process_ase``. The fields
            read here are ``x`` (atomic numbers), ``pos``, ``cell`` (shape
            ``(1, 3, 3)``), ``temperature_og`` and ``non_H_mask``.
        model: callable taking ``batch``. Its output is indexed with one
            ``3x3`` Cartesian U matrix per non-hydrogen atom, in atom order.
            This is assumed from the row mapping below; confirm against the
            model.
        output_file: unused; kept only for backward compatibility.

    Returns:
        ``io.StringIO`` holding the CIF text. The stream position is left at
        the end of the text, so read it with ``getvalue()``.
    """
    numbers = batch.x.numpy().astype(int)
    positions = batch.pos.numpy()
    cell = batch.cell.squeeze(0)
    cell_array = cell.numpy()
    temperature = batch.temperature_og.numpy()[0]

    # Pure inference: don't build an autograd graph (saves time and memory).
    with torch.no_grad():
        u_cart = model(batch)
        # U_equiv is one third of the trace of the *Cartesian* tensor. The trace
        # of the CIF-convention U^ij is not basis-invariant for non-orthogonal
        # cells, so it must be taken before conversion.
        u_eq = torch.diagonal(u_cart, dim1=-2, dim2=-1).mean(dim=-1).tolist()
        u_cif = _ucart_to_ucif(u_cart, cell).tolist()
    del u_cart
    gc.collect()

    # Model rows are indexed by position among the non-hydrogen atoms; map each
    # such atom's index in the structure to its row.
    non_h_atoms = torch.nonzero(batch.non_H_mask).flatten().tolist()
    row_of = {atom: row for row, atom in enumerate(non_h_atoms)}

    ase_atoms = Atoms(numbers=numbers, positions=positions, cell=cell_array, pbc=True)
    fractional_positions = ase_atoms.get_scaled_positions()
    symbols = ase_atoms.get_chemical_symbols()
    labels = _site_labels(symbols)

    cif_file = StringIO()
    # Reuse ASE's CIF header (everything before its first loop_); the atom-site
    # loops are written below with ADP columns added.
    cif_file.writelines(_ase_cif_header(ase_atoms))
    cif_file.write(f"\n_diffrn_ambient_temperature {temperature}\n")

    cif_file.write(
        "\nloop_\n"
        "_atom_site_label\n"
        "_atom_site_type_symbol\n"
        "_atom_site_fract_x\n"
        "_atom_site_fract_y\n"
        "_atom_site_fract_z\n"
        "_atom_site_U_iso_or_equiv\n"
        "_atom_site_thermal_displace_type\n"
    )
    for i, (symbol, label, frac) in enumerate(zip(symbols, labels, fractional_positions)):
        if symbol == 'H':
            # Hydrogens have no predicted ADP; give them a fixed isotropic U.
            u_iso, adp_type = 0.0001, "Uiso"
        else:
            u_iso, adp_type = u_eq[row_of[i]], "Uani"
        cif_file.write(
            f"{label} {symbol} {frac[0]} {frac[1]} {frac[2]} {u_iso} {adp_type}\n"
        )

    cif_file.write(
        "\nloop_\n"
        "_atom_site_aniso_label\n"
        "_atom_site_aniso_U_11\n"
        "_atom_site_aniso_U_22\n"
        "_atom_site_aniso_U_33\n"
        "_atom_site_aniso_U_23\n"
        "_atom_site_aniso_U_13\n"
        "_atom_site_aniso_U_12\n"
    )
    for i, (symbol, label) in enumerate(zip(symbols, labels)):
        if symbol == 'H':
            continue
        u = u_cif[row_of[i]]
        cif_file.write(
            f"{label} {u[0][0]} {u[1][1]} {u[2][2]} {u[1][2]} {u[0][2]} {u[0][1]}\n"
        )
    return cif_file
|