cartnet-demo / process.py
Àlex Solé
fixed bug multiple structures csd
fa790e2
raw
history blame
3.69 kB
import torch
from ase.io import write
from ase import Atoms
import gc
from io import BytesIO, StringIO
def _cartesian_to_cif_adps(adps, cell):
    """Express Cartesian ADP tensors (Ucart) in the CIF convention (Ucif).

    Args:
        adps: Tensor of 3x3 displacement tensors in Cartesian axes.
        cell: Cell matrix; the lattice vectors are taken to be its rows
            (the same matrix is handed to ASE as ``cell`` by the caller).

    Returns:
        Tensor with the same leading shape as ``adps`` holding Ucif.
    """
    # Rows of inv(cell^T) are the reciprocal lattice vectors; their lengths
    # are the a*, b*, c* factors that the CIF convention divides out.
    reciprocal_lengths = torch.linalg.norm(
        torch.linalg.inv(cell.transpose(-1, -2)).squeeze(0), dim=-1
    )
    cell_inv = torch.linalg.inv(cell)
    scale_inv = torch.linalg.inv(torch.diag(reciprocal_lengths))
    # Change of basis to the lattice frame, then removal of the
    # reciprocal-length scaling.
    adps = cell_inv.transpose(-1, -2) @ adps @ cell_inv
    return scale_inv.transpose(-1, -2) @ adps @ scale_inv


def _site_labels(symbols):
    """Return per-atom labels such as ``C1, C2, H1`` (1-based count per element)."""
    seen = {}
    labels = []
    for symbol in symbols:
        seen[symbol] = seen.get(symbol, 0) + 1
        labels.append(f"{symbol}{seen[symbol]}")
    return labels


@torch.no_grad()
def process_data(batch, model, output_file: str = "output.cif") -> StringIO:
    """Predict ADPs for one structure and render the result as CIF text.

    The model is evaluated on ``batch``; its output is indexed as one 3x3
    tensor per atom selected by ``batch.non_H_mask`` and is treated as
    Cartesian (Ucart). The tensors are converted to the CIF convention and
    written together with fractional coordinates and the measurement
    temperature.

    Args:
        batch: Object exposing ``x`` (atomic numbers), ``pos`` (Cartesian
            positions), ``cell``, ``temperature_og`` and ``non_H_mask`` as
            CPU tensors (``.numpy()`` is called on them).
        model: Callable applied to ``batch`` that returns the ADP tensors.
        output_file: Unused; kept for backward compatibility. The CIF is
            returned in memory and never written to this path.

    Returns:
        A ``StringIO`` containing the CIF text. The stream position is left
        at the end of the written data, so use ``getvalue()`` or ``seek(0)``
        before reading.

    Raises:
        KeyError: If an atom that is not hydrogen is missing from
            ``batch.non_H_mask`` and therefore has no predicted ADP.
    """
    atoms = batch.x.numpy().astype(int)  # Atomic numbers
    positions = batch.pos.numpy()  # Atomic positions
    cell = batch.cell.squeeze(0).numpy()  # Cell parameters
    temperature = batch.temperature_og.numpy()[0]

    adps_cart = model(batch)

    # U_eq is one third of the trace of the Cartesian (orthogonalised)
    # tensor. It has to be taken before the change of basis below, because
    # the trace of Ucif differs from it for non-orthogonal cells.
    u_equiv = torch.diagonal(adps_cart, dim1=-2, dim2=-1).mean(dim=-1).tolist()

    # Convert Ucart to Ucif
    adps_cif = _cartesian_to_cif_adps(adps_cart, batch.cell.squeeze(0)).tolist()
    del adps_cart
    # Kept from the original implementation: force a collection pass once
    # the conversion temporaries are gone.
    gc.collect()

    # Map each atom index selected by the mask to its row in the model output.
    non_H_mask = batch.non_H_mask.numpy()
    masked_atoms = torch.arange(len(atoms))[non_H_mask].tolist()
    adp_row = {atom_idx: row for row, atom_idx in enumerate(masked_atoms)}

    # Create ASE Atoms object
    ase_atoms = Atoms(numbers=atoms, positions=positions, cell=cell, pbc=True)
    # Convert positions to fractional coordinates
    fractional_positions = ase_atoms.get_scaled_positions()
    symbols = ase_atoms.get_chemical_symbols()
    labels = _site_labels(symbols)

    # Let ASE render the CIF in memory; only its header (everything before
    # the first "loop_") is reused, the atom loops are rewritten below.
    with BytesIO() as ase_buffer:
        write(ase_buffer, ase_atoms, format='cif')
        ase_lines = ase_buffer.getvalue().decode('utf-8').splitlines(True)

    header_lines = []
    for line in ase_lines:
        if line.strip().startswith('loop_'):
            break
        header_lines.append(line)

    cif_file = StringIO()
    cif_file.writelines(header_lines)

    # Write temperature
    cif_file.write(f"\n_diffrn_ambient_temperature {temperature}\n")

    # Write atomic positions
    cif_file.write("\nloop_\n")
    cif_file.write("_atom_site_label\n")
    cif_file.write("_atom_site_type_symbol\n")
    cif_file.write("_atom_site_fract_x\n")
    cif_file.write("_atom_site_fract_y\n")
    cif_file.write("_atom_site_fract_z\n")
    cif_file.write("_atom_site_U_iso_or_equiv\n")
    cif_file.write("_atom_site_thermal_displace_type\n")
    for i, (label, element, frac_pos) in enumerate(
        zip(labels, symbols, fractional_positions)
    ):
        if element == 'H':
            # Hydrogens get no prediction: fixed isotropic placeholder.
            u_iso = 0.01
            displace_type = "Uiso"
        else:
            u_iso = u_equiv[adp_row[i]]
            displace_type = "Uani"
        cif_file.write(
            f"{label} {element} {frac_pos[0]} {frac_pos[1]} {frac_pos[2]} "
            f"{u_iso} {displace_type}\n"
        )

    # Write ADPs
    cif_file.write("\nloop_\n")
    cif_file.write("_atom_site_aniso_label\n")
    cif_file.write("_atom_site_aniso_U_11\n")
    cif_file.write("_atom_site_aniso_U_22\n")
    cif_file.write("_atom_site_aniso_U_33\n")
    cif_file.write("_atom_site_aniso_U_23\n")
    cif_file.write("_atom_site_aniso_U_13\n")
    cif_file.write("_atom_site_aniso_U_12\n")
    for i, (label, element) in enumerate(zip(labels, symbols)):
        if element == 'H':
            continue
        u = adps_cif[adp_row[i]]
        cif_file.write(
            f"{label} {u[0][0]} {u[1][1]} {u[2][2]} "
            f"{u[1][2]} {u[0][2]} {u[0][1]}\n"
        )

    return cif_file