File size: 3,691 Bytes
744c6a1
 
 
 
fa790e2
744c6a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa790e2
 
 
 
 
744c6a1
 
 
 
 
 
 
fa790e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
744c6a1
fa790e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
from ase.io import write
from ase import Atoms
import gc
from io import BytesIO, StringIO

def _ucart_to_ucif(ucart, cell):
    """Convert Cartesian ADP tensors to the CIF U^ij convention.

    With A = cell^T (lattice vectors as columns) and N = diag(a*, b*, c*),
    U_cart = A N U_cif N A^T, so U_cif = N^-1 cell^-T U_cart cell^-1 N^-1.

    Args:
        ucart: tensor of shape (..., 3, 3) holding Cartesian ADPs.
        cell: (3, 3) tensor whose rows are the lattice vectors (as passed to
            ``ase.Atoms(cell=...)``).

    Returns:
        Tensor with the same shape as ``ucart`` holding U_cif.
    """
    cell_inv = torch.linalg.inv(cell)
    # The columns of cell^-1 are the reciprocal lattice vectors, so their
    # norms are a*, b*, c*; N^-1 is simply their reciprocal on the diagonal.
    n_inv = torch.diag(1.0 / torch.linalg.norm(cell_inv, dim=0))
    ucif = cell_inv.transpose(-1, -2) @ ucart @ cell_inv
    return n_inv @ ucif @ n_inv


def _site_labels(symbols):
    """Return CIF site labels numbered per element, e.g. ["C1", "O1", "C2"]."""
    counts = {}
    labels = []
    for symbol in symbols:
        counts[symbol] = counts.get(symbol, 0) + 1
        labels.append(f"{symbol}{counts[symbol]}")
    return labels


@torch.no_grad()
def process_data(batch, model, output_file="output.cif"):
    """Predict ADPs for ``batch`` with ``model`` and render the structure as CIF text.

    Args:
        batch: object exposing ``x`` (atomic numbers), ``pos`` (Cartesian
            positions), ``cell`` (lattice vectors as rows; presumably shape
            (1, 3, 3) -- a leading axis is squeezed), ``temperature_og`` and
            ``non_H_mask`` (selects the atoms that ``model`` predicts ADPs for).
        model: callable; ``model(batch)`` is expected to return Cartesian ADP
            tensors of shape (n_non_H, 3, 3), ordered like the atoms selected
            by ``non_H_mask`` -- NOTE(review): confirm against the model.
        output_file: currently unused; kept for backward compatibility.

    Returns:
        ``io.StringIO`` containing the CIF. The cursor is left at the end of
        the buffer, so read it with ``getvalue()``.
    """
    atoms = batch.x.cpu().numpy().astype(int)  # Atomic numbers
    positions = batch.pos.cpu().numpy()  # Cartesian atomic positions
    cell = batch.cell.squeeze(0).cpu().numpy()  # Lattice vectors (rows)
    temperature = batch.temperature_og.cpu().numpy()[0]

    ucart = model(batch)
    # U_eq is one third of the trace of the *Cartesian* tensor (it is rotation
    # invariant). Averaging the diagonal of U_cif is only valid for orthogonal
    # cells, and the plain trace is three times too large.
    u_eq = torch.diagonal(ucart, dim1=-2, dim2=-1).mean(dim=-1).cpu().tolist()
    ucif = _ucart_to_ucif(ucart, batch.cell.squeeze(0)).cpu().tolist()

    # Map a full-structure atom index to its row in the per-non-H ADP arrays.
    non_h_atoms = [i for i, keep in enumerate(batch.non_H_mask.cpu().tolist()) if keep]
    non_h_row = {atom_idx: row for row, atom_idx in enumerate(non_h_atoms)}

    ase_atoms = Atoms(numbers=atoms, positions=positions, cell=cell, pbc=True)
    fractional_positions = ase_atoms.get_scaled_positions()
    symbols = ase_atoms.get_chemical_symbols()
    labels = _site_labels(symbols)

    # Let ASE render the cell/formula header, then drop everything from the
    # first loop_ on; the atom-site loops are written below instead.
    with BytesIO() as ase_cif:
        write(ase_cif, ase_atoms, format='cif')
        header_lines = ase_cif.getvalue().decode('utf-8').splitlines(True)
    for i, line in enumerate(header_lines):
        if line.strip().startswith('loop_'):
            header_lines = header_lines[:i]
            break

    cif_file = StringIO()
    cif_file.writelines(header_lines)
    cif_file.write(f"\n_diffrn_ambient_temperature    {temperature}\n")

    # Atomic positions with isotropic / equivalent displacement parameters.
    cif_file.write(
        "\nloop_\n"
        "_atom_site_label\n"
        "_atom_site_type_symbol\n"
        "_atom_site_fract_x\n"
        "_atom_site_fract_y\n"
        "_atom_site_fract_z\n"
        "_atom_site_U_iso_or_equiv\n"
        "_atom_site_thermal_displace_type\n"
    )
    for i, (label, symbol, frac) in enumerate(zip(labels, symbols, fractional_positions)):
        if symbol == 'H':
            # Hydrogens get a fixed isotropic value; the model predicts no ADP for them.
            u_iso, adp_type = 0.01, "Uiso"
        else:
            u_iso, adp_type = u_eq[non_h_row[i]], "Uani"
        cif_file.write(f"{label} {symbol} {frac[0]} {frac[1]} {frac[2]} {u_iso} {adp_type}\n")

    # Anisotropic displacement parameters (U_cif) for non-H atoms.
    cif_file.write(
        "\nloop_\n"
        "_atom_site_aniso_label\n"
        "_atom_site_aniso_U_11\n"
        "_atom_site_aniso_U_22\n"
        "_atom_site_aniso_U_33\n"
        "_atom_site_aniso_U_23\n"
        "_atom_site_aniso_U_13\n"
        "_atom_site_aniso_U_12\n"
    )
    for i, (label, symbol) in enumerate(zip(labels, symbols)):
        if symbol == 'H':
            continue
        u = ucif[non_h_row[i]]
        cif_file.write(f"{label} {u[0][0]} {u[1][1]} {u[2][2]} {u[1][2]} {u[0][2]} {u[0][1]}\n")
    return cif_file