import streamlit as st
import torch
from torch_geometric.data import Data, Batch
from ase.io import write
from ase import Atoms
import gc
from io import BytesIO, StringIO
from utils import radius_graph_pbc

MEAN_TEMP = torch.tensor(192.1785)  # mean of the training-set temperatures, used to standardize the input
STD_TEMP = torch.tensor(81.2135)    # standard deviation of the training-set temperatures

def process_ase(atoms, temperature, model):
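    """Build a periodic graph from an ASE ``Atoms`` object and predict its ADPs.

    The temperature is standardized with the training-set statistics
    (MEAN_TEMP / STD_TEMP), a periodic radius graph is built with
    ``radius_graph_pbc`` (cutoff 5.0; the second argument is assumed to be a
    max-neighbour threshold), and the resulting batch is handed to
    ``process_data``. Returns the predicted structure as an in-memory CIF
    (a ``StringIO``).
    """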
    data = Data()
    data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
    
    data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
    data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
    data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
    data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)

    data.pbc = torch.tensor([True, True, True])
    data.natoms = len(atoms)

    del atoms
    gc.collect()
    batch = Batch.from_data_list([data])
    

    edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
    del batch
    gc.collect()
    data.cart_dist = torch.norm(edge_attr, dim=-1)
    data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
    data.edge_index = edge_index
    data.non_H_mask = data.x != 1
    delattr(data, "pbc")
    delattr(data, "natoms")
    batch = Batch.from_data_list([data])
    del data, edge_index, edge_attr
    gc.collect()

    st.success("Graph successfully created.")

    cif_file = process_data(batch, model)
    st.success("ADPs successfully predicted.")
    return cif_file

def process_data(batch, model, output_file="output.cif"):
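    """Predict ADPs for ``batch`` and write the structure to an in-memory CIF.

    The model is assumed to return one 3x3 Cartesian ADP tensor (Ucart) per
    non-hydrogen atom, in the order given by ``batch.non_H_mask``. The tensors
    are converted to the CIF (Ucif) convention before being written out.
    Returns a ``StringIO`` holding the CIF text; ``output_file`` is currently
    unused.
    """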
    atoms = batch.x.numpy().astype(int)  # Atomic numbers
    positions = batch.pos.numpy()  # Atomic positions
    cell = batch.cell.squeeze(0).numpy()  # Cell parameters
    temperature = batch.temperature_og.numpy()[0]


    with torch.no_grad():  # inference only; no gradients are needed for the predictions
        adps = model(batch)
    
    # Convert Ucart to Ucif
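    # With M the (row-vector) cell matrix and N = diag(|a*|, |b*|, |c*|) built from the
    # reciprocal lattice vectors, the lines below appear to implement the standard
    # conversion Ucif = N^{-T} (M^{-T} Ucart M^{-1}) N^{-1}, i.e. Ucart -> Ustar -> Ucif.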
    M = batch.cell.squeeze(0)
    N = torch.diag(torch.linalg.norm(torch.linalg.inv(M.transpose(-1,-2)).squeeze(0), dim=-1))

    M = torch.linalg.inv(M)
    N = torch.linalg.inv(N)

    adps = M.transpose(-1,-2)@adps@M
    adps = N.transpose(-1,-2)@adps@N
    del M, N
    gc.collect()
    
    
    # Map each atom's index in the full structure to its row in `adps`,
    # which holds one 3x3 ADP per non-hydrogen atom
    non_H_mask = batch.non_H_mask.numpy()
    non_H_indices = torch.arange(len(atoms))[non_H_mask].numpy()
    indices = {int(atom_idx): row for row, atom_idx in enumerate(non_H_indices)}
    # Create ASE Atoms object
    ase_atoms = Atoms(numbers=atoms, positions=positions, cell=cell, pbc=True)

    # Convert positions to fractional coordinates
    fractional_positions = ase_atoms.get_scaled_positions()


    # Instead of reading from file, get CIF content directly from ASE's write function
    cif_content = BytesIO()
    write(cif_content, ase_atoms, format='cif')
    lines = cif_content.getvalue().decode('utf-8').splitlines(True)
    cif_content.close()

    # Find the line where "loop_" appears and remove lines from there to the end
    for i, line in enumerate(lines):
        if line.strip().startswith('loop_'):
            lines = lines[:i]
            break

    # Use StringIO to build the CIF content
    cif_file = StringIO()
    cif_file.writelines(lines)
    # Write temperature
    cif_file.write(f"\n_diffrn_ambient_temperature    {temperature}\n")
    # Write atomic positions
    cif_file.write("\nloop_\n")
    cif_file.write("_atom_site_label\n")
    cif_file.write("_atom_site_type_symbol\n")
    cif_file.write("_atom_site_fract_x\n")
    cif_file.write("_atom_site_fract_y\n")
    cif_file.write("_atom_site_fract_z\n")
    cif_file.write("_atom_site_U_iso_or_equiv\n")
    cif_file.write("_atom_site_thermal_displace_type\n")
    
    element_count = {}
    for i, (atom_number, frac_pos) in enumerate(zip(atoms, fractional_positions)):
        element = ase_atoms[i].symbol
        assert atom_number == ase_atoms[i].number
        element_count[element] = element_count.get(element, 0) + 1
        label = f"{element}{element_count[element]}"
        # Equivalent isotropic U for non-H atoms: mean of the diagonal of the ADP tensor
        u_iso = torch.diagonal(adps[indices[i]]).mean().item() if element != 'H' else 0.0001
        adp_type = "Uani" if element != 'H' else "Uiso"
        cif_file.write(f"{label} {element} {frac_pos[0]} {frac_pos[1]} {frac_pos[2]} {u_iso} {adp_type}\n")

    # Write ADPs
    cif_file.write("\nloop_\n")
    cif_file.write("_atom_site_aniso_label\n")
    cif_file.write("_atom_site_aniso_U_11\n")
    cif_file.write("_atom_site_aniso_U_22\n")
    cif_file.write("_atom_site_aniso_U_33\n")
    cif_file.write("_atom_site_aniso_U_23\n")
    cif_file.write("_atom_site_aniso_U_13\n")
    cif_file.write("_atom_site_aniso_U_12\n")
    
    element_count = {}
    for i, atom_number in enumerate(atoms):
        if atom_number == 1:
            # Hydrogens were written as Uiso above and get no anisotropic entry
            continue
        element = ase_atoms[i].symbol
        element_count[element] = element_count.get(element, 0) + 1
        label = f"{element}{element_count[element]}"
        U = adps[indices[i]]
        cif_file.write(f"{label} {U[0, 0]} {U[1, 1]} {U[2, 2]} {U[1, 2]} {U[0, 2]} {U[0, 1]}\n")
    return cif_file
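

# Hypothetical usage sketch (the ``load_model`` helper and the Streamlit wiring below
# are assumptions for illustration, not part of this module):
#
#     import ase.io
#     model = load_model("adp_model.pt")
#     uploaded = st.file_uploader("Upload a CIF", type="cif")
#     if uploaded is not None:
#         atoms = ase.io.read(StringIO(uploaded.getvalue().decode("utf-8")), format="cif")
#         cif_buffer = process_ase(atoms, temperature=150.0, model=model)
#         st.download_button("Download CIF with predicted ADPs",
#                            data=cif_buffer.getvalue(),
#                            file_name="predicted_adps.cif")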