# Copyright Universitat Politècnica de Catalunya 2024 https://imatge.upc.edu
# Distributed under the MIT License.
# (See accompanying file README.md or copy at http://opensource.org/licenses/MIT)

import torch
import torch_geometric.nn as pyg_nn
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter

from models.utils import ExpNormalSmearing, CosineCutoff


class CartNet(torch.nn.Module):
    """
    CartNet model from "Cartesian Encoding Graph Neural Network for Crystal Structures Property Prediction: Application to Thermal Ellipsoid Estimation".

    Args:
        dim_in (int): Dimensionality of the input features.
        dim_rbf (int): Dimensionality of the radial basis function embeddings.
        num_layers (int): Number of CartNet layers in the model.
        radius (float, optional): Radius cutoff for neighbor interactions. Default is 5.0.
        invariant (bool, optional): If `True`, enforces rotational invariance in the encoder. Default is `False`.
        temperature (bool, optional): If `True`, includes temperature information in the encoder. Default is `True`.
        use_envelope (bool, optional): If `True`, applies an envelope function to the interactions. Default is `True`.
        cholesky (bool, optional): If `True`, uses a Cholesky head for the output. If `False`, uses a scalar head. Default is `True`.

    Methods:
        forward(batch):
            Performs a forward pass of the model.

            Args:
                batch: A batch of input data.

            Returns:
                pred: The model's predictions.
    """

    def __init__(self,
                 dim_in: int,
                 dim_rbf: int,
                 num_layers: int,
                 radius: float = 5.0,
                 invariant: bool = False,
                 temperature: bool = True,
                 use_envelope: bool = True,
                 cholesky: bool = True):
        super().__init__()
        self.encoder = Encoder(dim_in, dim_rbf=dim_rbf, radius=radius, invariant=invariant, temperature=temperature)
        self.dim_in = dim_in

        layers = []
        for _ in range(num_layers):
            layers.append(CartNet_layer(
                dim_in=dim_in,
                use_envelope=use_envelope,
            ))
        self.layers = torch.nn.Sequential(*layers)

        if cholesky:
            self.head = Cholesky_head(dim_in)
        else:
            self.head = Scalar_head(dim_in)

    def forward(self, batch):
        batch = self.encoder(batch)
        for layer in self.layers:
            batch = layer(batch)
        pred = self.head(batch)
        return pred
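
# Expected fields of the input `batch` (as read by the modules below; shapes are
# inferred from this file, not from an external spec):
#   batch.x           [n_nodes]        atomic numbers, used as nn.Embedding indices
#   batch.batch       [n_nodes]        graph index of each node
#   batch.temperature [n_graphs]       per-crystal temperature (only if temperature=True)
#   batch.edge_index  [2, n_edges]     neighbor pairs (typically within the radius cutoff)
#   batch.cart_dist   [n_edges]        Cartesian edge lengths
#   batch.cart_dir    [n_edges, 3]     Cartesian edge directions (unused if invariant=True)
#   batch.non_H_mask  [n_nodes] bool   selects non-hydrogen atoms for the Cholesky head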

class Encoder(torch.nn.Module):
    """
    Encoder module for the CartNet model.

    This module encodes node and edge features for input into the CartNet model, incorporating optional temperature information and rotational invariance.

    Args:
        dim_in (int): Dimension of the input features after embedding.
        dim_rbf (int): Dimension of the radial basis function used for edge attributes.
        radius (float, optional): Cutoff radius for neighbor interactions. Defaults to 5.0.
        invariant (bool, optional): If True, the encoder enforces rotational invariance by excluding directional information from edge attributes. Defaults to False.
        temperature (bool, optional): If True, includes temperature data in the node embeddings. Defaults to True.

    Attributes:
        dim_in (int): Dimension of the input features.
        invariant (bool): Indicates if rotational invariance is enforced.
        temperature (bool): Indicates if temperature information is included.
        embedding (nn.Embedding): Embedding layer mapping atomic numbers to feature vectors.
        temperature_proj_atom (pyg_nn.Linear): Linear layer projecting temperature to embedding dimensions (used if temperature is True).
        bias (nn.Parameter): Bias term added to embeddings (used if temperature is False).
        activation (nn.Module): Activation function (SiLU).
        encoder_atom (nn.Sequential): Sequential network encoding node features.
        encoder_edge (nn.Sequential): Sequential network encoding edge features.
        rbf (ExpNormalSmearing): Radial basis function for encoding distances.
    """

    def __init__(
        self,
        dim_in: int,
        dim_rbf: int,
        radius: float = 5.0,
        invariant: bool = False,
        temperature: bool = True,
    ):
        super(Encoder, self).__init__()
        self.dim_in = dim_in
        self.invariant = invariant
        self.temperature = temperature
        self.embedding = nn.Embedding(119, self.dim_in*2)

        if self.temperature:
            self.temperature_proj_atom = pyg_nn.Linear(1, self.dim_in*2, bias=True)
        else:
            self.bias = nn.Parameter(torch.zeros(self.dim_in*2))

        self.activation = nn.SiLU(inplace=True)
        self.encoder_atom = nn.Sequential(self.activation,
                                          pyg_nn.Linear(self.dim_in*2, self.dim_in),
                                          self.activation)

        if self.invariant:
            dim_edge = dim_rbf
        else:
            dim_edge = dim_rbf + 3

        self.encoder_edge = nn.Sequential(pyg_nn.Linear(dim_edge, self.dim_in*2),
                                          self.activation,
                                          pyg_nn.Linear(self.dim_in*2, self.dim_in),
                                          self.activation)

        self.rbf = ExpNormalSmearing(0.0, radius, dim_rbf, False)
        torch.nn.init.xavier_uniform_(self.embedding.weight.data)

    def forward(self, batch):
        if self.temperature:
            x = self.embedding(batch.x) + self.temperature_proj_atom(batch.temperature.unsqueeze(-1))[batch.batch]
        else:
            x = self.embedding(batch.x) + self.bias
        batch.x = self.encoder_atom(x)

        if self.invariant:
            batch.edge_attr = self.encoder_edge(self.rbf(batch.cart_dist))
        else:
            batch.edge_attr = self.encoder_edge(torch.cat([self.rbf(batch.cart_dist), batch.cart_dir], dim=-1))
        return batch
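
# After the encoder, batch.x has shape [n_nodes, dim_in] and batch.edge_attr has
# shape [n_edges, dim_in]; these are the tensors consumed by CartNet_layer below.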

class CartNet_layer(pyg_nn.conv.MessagePassing):
    """
    The message-passing layer used in the CartNet architecture.

    Args:
        dim_in (int): Dimension of the input node features.
        use_envelope (bool, optional): If True, applies an envelope function to the distances. Defaults to True.

    Attributes:
        dim_in (int): Dimension of the input node features.
        activation (nn.Module): Activation function (SiLU) used in the layer.
        MLP_aggr (nn.Sequential): MLP used for aggregating messages.
        MLP_gate (nn.Sequential): MLP used for computing gating coefficients.
        norm (nn.BatchNorm1d): Batch normalization applied to the gating coefficients.
        norm2 (nn.BatchNorm1d): Batch normalization applied to the aggregated messages.
        use_envelope (bool): Indicates if the envelope function is used.
        envelope (CosineCutoff): Envelope function applied to the distances.
    """

    def __init__(self,
                 dim_in: int,
                 use_envelope: bool = True
                 ):
        super().__init__()
        self.dim_in = dim_in
        self.activation = nn.SiLU(inplace=True)
        self.MLP_aggr = nn.Sequential(
            pyg_nn.Linear(dim_in*3, dim_in, bias=True),
            self.activation,
            pyg_nn.Linear(dim_in, dim_in, bias=True),
        )
        self.MLP_gate = nn.Sequential(
            pyg_nn.Linear(dim_in*3, dim_in, bias=True),
            self.activation,
            pyg_nn.Linear(dim_in, dim_in, bias=True),
        )
        self.norm = nn.BatchNorm1d(dim_in)
        self.norm2 = nn.BatchNorm1d(dim_in)
        self.use_envelope = use_envelope
        self.envelope = CosineCutoff(0, 5.0)

    def forward(self, batch):
        x, e, edge_index, dist = batch.x, batch.edge_attr, batch.edge_index, batch.cart_dist
        """
        x          : [n_nodes, dim_in]
        e          : [n_edges, dim_in]
        edge_index : [2, n_edges]
        dist       : [n_edges]
        batch      : [n_nodes]
        """
        x_in = x
        e_in = e

        x, e = self.propagate(edge_index,
                              Xx=x, Ee=e,
                              He=dist,
                              )

        batch.x = self.activation(x) + x_in
        batch.edge_attr = e_in + e
        return batch

    def message(self, Xx_i, Ee, Xx_j, He):
        """
        Xx_i : [n_edges, dim_in]
        Xx_j : [n_edges, dim_in]
        Ee   : [n_edges, dim_in]
        He   : [n_edges]
        """
        e_ij = self.MLP_gate(torch.cat([Xx_i, Xx_j, Ee], dim=-1))
        e_ij = torch.sigmoid(self.norm(e_ij))

        if self.use_envelope:
            sigma_ij = self.envelope(He).unsqueeze(-1) * e_ij
        else:
            sigma_ij = e_ij

        self.e = sigma_ij
        return sigma_ij

    def aggregate(self, sigma_ij, index, Xx_i, Xx_j, Ee, Xx):
        """
        sigma_ij : [n_edges, dim_in] ; output of the message() function
        index    : [n_edges]
        Xx_j     : [n_edges, dim_in]
        """
        dim_size = Xx.shape[0]
        sender = self.MLP_aggr(torch.cat([Xx_i, Xx_j, Ee], dim=-1))
        out = scatter(sigma_ij * sender, index, 0, None, dim_size,
                      reduce='sum')
        return out

    def update(self, aggr_out):
        """
        aggr_out : [n_nodes, dim_in] ; output of the aggregate() function after aggregation
        x        : [n_nodes, dim_in]
        """
        x = self.norm2(aggr_out)
        e_out = self.e
        del self.e
        return x, e_out
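
# Per edge (i, j), a CartNet_layer computes (see message/aggregate/update above):
#   sigma_ij = envelope(d_ij) * sigmoid(BN(MLP_gate([x_i, x_j, e_ij])))   (gate)
#   m_ij     = sigma_ij * MLP_aggr([x_i, x_j, e_ij])                      (message)
#   x_i'     = x_i + SiLU(BN(sum_j m_ij))                                 (node update)
#   e_ij'    = e_ij + sigma_ij                                            (edge update)
# with the envelope factor dropped when use_envelope=False.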

class Cholesky_head(torch.nn.Module):
    """
    The Cholesky head used in the CartNet model.

    It enforces the positive definiteness of the predicted covariance matrix.

    Args:
        dim_in (int): The input dimension of the features.
    """

    def __init__(self,
                 dim_in: int
                 ):
        super(Cholesky_head, self).__init__()
        self.MLP = nn.Sequential(pyg_nn.Linear(dim_in, dim_in//2),
                                 nn.SiLU(inplace=True),
                                 pyg_nn.Linear(dim_in//2, 6))

    def forward(self, batch):
        # Predict 6 numbers per non-hydrogen atom: 3 diagonal and 3 off-diagonal
        # entries of an upper-triangular Cholesky factor.
        pred = self.MLP(batch.x[batch.non_H_mask])
        diag_elements = F.softplus(pred[:, :3])  # keep the diagonal strictly positive

        i, j = torch.tensor([0, 1, 2, 0, 0, 1]), torch.tensor([0, 1, 2, 1, 2, 2])
        L_matrix = torch.zeros(pred.size(0), 3, 3, device=pred.device, dtype=pred.dtype)
        L_matrix[:, i[:3], i[:3]] = diag_elements
        L_matrix[:, i[3:], j[3:]] = pred[:, 3:]

        # U = L^T L is symmetric positive definite by construction.
        U = torch.bmm(L_matrix.transpose(1, 2), L_matrix)
        return U
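

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original model code): builds a toy
    # batch with the fields listed near the top of this file and runs a forward pass.
    # The random "geometry" is purely illustrative and does not correspond to a real
    # crystal structure.
    from torch_geometric.data import Data

    num_nodes, num_edges = 8, 24
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    cart_vec = torch.randn(num_edges, 3)

    batch = Data(
        x=torch.randint(1, 119, (num_nodes,)),            # atomic numbers
        edge_index=edge_index,
        cart_dist=cart_vec.norm(dim=-1),                   # edge lengths
        cart_dir=F.normalize(cart_vec, dim=-1),            # edge directions
        temperature=torch.tensor([300.0]),                 # one temperature per crystal
        batch=torch.zeros(num_nodes, dtype=torch.long),    # all nodes in one graph
        non_H_mask=torch.ones(num_nodes, dtype=torch.bool),
    )

    model = CartNet(dim_in=32, dim_rbf=16, num_layers=2)
    U = model(batch)
    print(U.shape)  # -> torch.Size([8, 3, 3]); symmetric positive definite matrices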