File size: 2,764 Bytes
744c6a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import argparse
import os
from ase.io import read
from CifFile import ReadCif
from torch_geometric.data import Data, Batch
import torch
from models.master import create_model
from process import process_data
from utils import radius_graph_pbc
import gc

# Statistics of the temperatures seen during training. process_cif uses them
# to standardise the temperature read from the CIF: (T - MEAN_TEMP) / STD_TEMP.
# NOTE(review): presumably in kelvin, the unit of _diffrn_ambient_temperature
# in the CIF core dictionary — confirm against the training pipeline.
MEAN_TEMP = torch.tensor(192.1785)  # training temperature mean
STD_TEMP = torch.tensor(81.2135)  # training temperature standard deviation

def _parse_cif_temperature(raw_value):
    """Convert the raw ``_diffrn_ambient_temperature`` entry to a float.

    CIF numeric values may carry a standard uncertainty in trailing
    parentheses (e.g. ``"293(2)"``); that suffix is discarded before the
    conversion so such entries no longer fail.

    Args:
        raw_value: The entry as returned by the CIF parser.

    Returns:
        The temperature as a float.

    Raises:
        ValueError: If the entry is not numeric (for example the CIF
            placeholders ``"?"`` or ``"."``).
    """
    text = str(raw_value).strip()
    # Keep only what precedes the "(su)" suffix, if there is one.
    number = text.split("(", 1)[0].strip()
    try:
        return float(number)
    except ValueError:
        raise ValueError(
            f"Invalid temperature {text!r} in the field "
            "_diffrn_ambient_temperature of the CIF file."
        ) from None


@torch.no_grad()
def process_cif(input_file, output_file):
    """Build a graph from a CIF file and pass it to ``process_data``.

    The structure is read with ASE, the temperature is taken from the
    ``_diffrn_ambient_temperature`` field of the first CIF block and
    standardised with ``MEAN_TEMP`` / ``STD_TEMP``, edges are generated with
    ``radius_graph_pbc``, and the resulting single-element batch is handed to
    ``process_data`` together with the model and ``output_file``.

    Any exception raised after the model has been created is caught and
    reported with ``print``; it is not re-raised. Exceptions raised by
    ``create_model`` itself propagate to the caller.

    Args:
        input_file: Path to the input CIF file.
        output_file: Output path, forwarded unchanged to ``process_data``.
    """
    model = create_model()

    try:
        # The file is parsed twice: ASE provides the atoms and the cell,
        # PyCifRW provides access to the raw CIF data items.
        atoms = read(input_file, format="cif")
        cif_data = ReadCif(input_file).first_block()
        if "_diffrn_ambient_temperature" not in cif_data.keys():
            raise ValueError(
                "Temperature not found in the CIF file. "
                "Please provide a temperature in the field "
                "_diffrn_ambient_temperature from the CIF file."
            )
        temperature = _parse_cif_temperature(
            cif_data["_diffrn_ambient_temperature"]
        )

        # Reject oversized structures before any tensor is allocated.
        if len(atoms) > 300:
            raise ValueError("This implementation is not optimized for large systems. For large systems, please use the local version.")

        data = Data()
        data.x = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.int32)
        data.pos = torch.tensor(atoms.positions, dtype=torch.float32)
        data.temperature_og = torch.tensor([temperature], dtype=torch.float32)
        # Standardise with the training-set statistics.
        data.temperature = (data.temperature_og - MEAN_TEMP) / STD_TEMP
        data.cell = torch.tensor(atoms.cell.array, dtype=torch.float32).unsqueeze(0)

        # NOTE(review): pbc and natoms are presumably only needed by
        # radius_graph_pbc; they are removed again before the final batch.
        data.pbc = torch.tensor([True, True, True])
        data.natoms = len(atoms)

        # Intermediate objects are dropped eagerly to keep peak memory low.
        del atoms
        gc.collect()
        batch = Batch.from_data_list([data])

        # NOTE(review): 5.0 and 64 are presumably the cutoff radius and the
        # neighbour cap, and edge_attr the per-edge displacement vectors —
        # confirm against utils.radius_graph_pbc.
        edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, 64)
        del batch
        gc.collect()
        data.cart_dist = torch.norm(edge_attr, dim=-1)
        data.cart_dir = torch.nn.functional.normalize(edge_attr, dim=-1)
        data.edge_index = edge_index
        # True for every atom whose atomic number is not 1 (hydrogen).
        data.non_H_mask = data.x != 1
        delattr(data, "pbc")
        delattr(data, "natoms")
        batch = Batch.from_data_list([data])
        del data, edge_index, edge_attr
        gc.collect()

        process_data(batch, model, output_file)

        gc.collect()
    except Exception as e:
        # Boundary handler: report the failure instead of raising.
        print(f"An error occurred while processing the CIF file: {e}")

if __name__ == "__main__":
    # Command-line entry point: two positional paths, the input CIF followed
    # by the output CIF, both forwarded to process_cif.
    cli = argparse.ArgumentParser(
        description="Process a CIF file and output the result."
    )
    positional_args = (
        ("input_file", "Path to the input CIF file."),
        ("output_file", "Path to the output CIF file."),
    )
    for arg_name, arg_help in positional_args:
        cli.add_argument(arg_name, type=str, help=arg_help)
    options = cli.parse_args()

    process_cif(options.input_file, options.output_file)