cartnet-demo / models /cartnet.py
Àlex Solé
merged from streamlit
744c6a1
raw
history blame
10.7 kB
# Copyright Universitat Politècnica de Catalunya 2024 https://imatge.upc.edu
# Distributed under the MIT License.
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)
import torch
import torch_geometric.nn as pyg_nn
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter
from models.utils import ExpNormalSmearing, CosineCutoff
class CartNet(torch.nn.Module):
    """
    CartNet model from "Cartesian Encoding Graph Neural Network for Crystal Structures
    Property Prediction: Application to Thermal Ellipsoid Estimation".

    The model is an ``Encoder`` followed by ``num_layers`` ``CartNet_layer`` blocks and an
    output head (``Cholesky_head`` or ``Scalar_head``).

    Args:
        dim_in (int): Dimensionality of the hidden node/edge features.
        dim_rbf (int): Dimensionality of the radial basis function embeddings.
        num_layers (int): Number of CartNet layers in the model.
        radius (float, optional): Cutoff radius forwarded to the Encoder's radial basis
            expansion. Default is 5.0.
            NOTE(review): the CartNet layers are built with only ``dim_in`` and
            ``use_envelope``, so this value does NOT reach their envelope cutoff —
            confirm this is intended when radius != 5.0.
        invariant (bool, optional): If ``True``, enforces rotational invariance in the
            encoder. Default is ``False``.
        temperature (bool, optional): If ``True``, includes temperature information in
            the encoder. Default is ``True``.
        use_envelope (bool, optional): If ``True``, applies an envelope function to the
            interactions. Default is ``True``.
        cholesky (bool, optional): If ``True``, uses a Cholesky head for the output;
            if ``False``, uses a scalar head. Default is ``True``.

    Methods:
        forward(batch):
            Runs the encoder, every layer in order, then the head.
            Args:
                batch: A batch of input graph data.
            Returns:
                pred: The output of the head module (only the prediction; no ground
                    truth is returned).
    """

    def __init__(self,
                 dim_in: int,
                 dim_rbf: int,
                 num_layers: int,
                 radius: float = 5.0,
                 invariant: bool = False,
                 temperature: bool = True,
                 use_envelope: bool = True,
                 cholesky: bool = True):
        super().__init__()
        # Submodules are created in the same order as before (encoder, layers, head)
        # so parameter initialisation and state_dict keys are unchanged.
        self.encoder = Encoder(dim_in, dim_rbf=dim_rbf, radius=radius,
                               invariant=invariant, temperature=temperature)
        self.dim_in = dim_in
        self.layers = torch.nn.Sequential(*[
            CartNet_layer(dim_in=dim_in, use_envelope=use_envelope)
            for _ in range(num_layers)
        ])
        if cholesky:
            self.head = Cholesky_head(dim_in)
        else:
            self.head = Scalar_head(dim_in)

    def forward(self, batch):
        """Encode ``batch``, apply every CartNet layer in order, and return the head output."""
        batch = self.encoder(batch)
        for layer in self.layers:
            batch = layer(batch)
        pred = self.head(batch)
        return pred
class Encoder(torch.nn.Module):
    """
    Encoder module for the CartNet model.

    Encodes node and edge features for input into the CartNet layers, optionally
    incorporating temperature information and optionally enforcing rotational invariance.

    Args:
        dim_in (int): Dimension of the features after encoding.
        dim_rbf (int): Dimension of the radial basis function used for edge attributes.
        radius (float, optional): Upper cutoff of the radial basis expansion. Defaults to 5.0.
        invariant (bool, optional): If True, the encoder enforces rotational invariance by
            excluding directional information (``batch.cart_dir``) from edge attributes.
            Defaults to False.
        temperature (bool, optional): If True, adds a projection of ``batch.temperature``
            to the node embeddings; if False, adds a learned bias instead. Defaults to True.

    Attributes:
        dim_in (int): Dimension of the encoded features.
        invariant (bool): Indicates if rotational invariance is enforced.
        temperature (bool): Indicates if temperature information is included.
        embedding (nn.Embedding): Embedding table with 119 entries (indexed by ``batch.x``,
            described by the original author as atomic numbers).
        temperature_proj_atom (pyg_nn.Linear): Projects the scalar temperature to the
            embedding size (only exists if ``temperature`` is True).
        bias (nn.Parameter): Learned bias added to the embeddings (only exists if
            ``temperature`` is False).
        activation (nn.Module): Shared in-place SiLU activation.
        encoder_atom (nn.Sequential): Network encoding node features.
        encoder_edge (nn.Sequential): Network encoding edge features.
        rbf (ExpNormalSmearing): Radial basis expansion of edge distances.
    """

    def __init__(
        self,
        dim_in: int,
        dim_rbf: int,
        radius: float = 5.0,
        invariant: bool = False,
        temperature: bool = True,
    ):
        super(Encoder, self).__init__()
        self.dim_in = dim_in
        self.invariant = invariant
        self.temperature = temperature
        self.embedding = nn.Embedding(119, self.dim_in*2)
        if self.temperature:
            self.temperature_proj_atom = pyg_nn.Linear(1, self.dim_in*2, bias=True)
        else:
            self.bias = nn.Parameter(torch.zeros(self.dim_in*2))
        self.activation = nn.SiLU(inplace=True)
        self.encoder_atom = nn.Sequential(self.activation,
                                          pyg_nn.Linear(self.dim_in*2, self.dim_in),
                                          self.activation)
        # Invariant mode drops the 3 direction components from the edge input.
        if self.invariant:
            dim_edge = dim_rbf
        else:
            dim_edge = dim_rbf + 3
        self.encoder_edge = nn.Sequential(pyg_nn.Linear(dim_edge, self.dim_in*2),
                                          self.activation,
                                          pyg_nn.Linear(self.dim_in*2, self.dim_in),
                                          self.activation)
        self.rbf = ExpNormalSmearing(0.0, radius, dim_rbf, False)
        torch.nn.init.xavier_uniform_(self.embedding.weight.data)

    def forward(self, batch):
        """
        Encode node and edge features of ``batch`` in place and return it.

        Reads ``batch.x`` (embedding indices), ``batch.cart_dist`` and, unless
        ``invariant``, ``batch.cart_dir``. When ``temperature`` is True it also reads
        ``batch.temperature``, which is indexed with ``batch.batch`` (presumably one
        temperature per graph, broadcast to that graph's nodes — confirm against the
        data pipeline). Writes ``batch.x`` and ``batch.edge_attr``.
        """
        x = self.embedding(batch.x)
        if self.temperature:
            x = x + self.temperature_proj_atom(batch.temperature.unsqueeze(-1))[batch.batch]
        else:
            # Without temperature the learned bias takes its place; previously this
            # branch raised AttributeError because temperature_proj_atom did not exist.
            x = x + self.bias
        batch.x = self.encoder_atom(x)

        edge_features = self.rbf(batch.cart_dist)
        if not self.invariant:
            # Directional components are only included in the non-invariant mode,
            # matching the dim_edge chosen in __init__.
            edge_features = torch.cat([edge_features, batch.cart_dir], dim=-1)
        batch.edge_attr = self.encoder_edge(edge_features)
        return batch
class CartNet_layer(pyg_nn.conv.MessagePassing):
    """
    Gated message-passing layer used in the CartNet architecture.

    For every edge, the endpoint features ``x_i``, ``x_j`` and the edge feature ``e`` are
    concatenated. From that concatenation the layer computes:
      * a gate ``sigma_ij = sigmoid(BatchNorm(MLP_gate(...)))``, optionally multiplied
        by a cosine envelope of the edge distance, and
      * a message ``MLP_aggr(...)``.
    The gated messages ``sigma_ij * message`` are sum-scattered by ``index`` (the ``_i``
    endpoint under PyG's conventions) and batch-normalised. Nodes are updated as
    ``x <- SiLU(aggregated) + x`` and edges as ``e <- e + sigma_ij``.

    Parameters:
        dim_in (int): Dimension of the input node and edge features.
        use_envelope (bool, optional): If True, scales the gates with an envelope
            function of the edge distance. Defaults to True.
        radius (float, optional): Upper cutoff of the cosine envelope. Defaults to 5.0,
            the value that was previously hard-coded, so existing callers are unaffected.

    Attributes:
        dim_in (int): Dimension of the input features.
        activation (nn.Module): Shared in-place SiLU activation.
        MLP_aggr (nn.Sequential): MLP producing the per-edge messages.
        MLP_gate (nn.Sequential): MLP producing the per-edge gating logits.
        norm (nn.BatchNorm1d): Batch normalization applied to the gating logits.
        norm2 (nn.BatchNorm1d): Batch normalization applied to the aggregated messages.
        use_envelope (bool): Indicates if the envelope function is used.
        envelope (CosineCutoff): Envelope applied to edge distances (cutoff 0 to ``radius``).
    """

    def __init__(self,
                 dim_in: int,
                 use_envelope: bool = True,
                 radius: float = 5.0,
                 ):
        super().__init__()
        self.dim_in = dim_in
        self.activation = nn.SiLU(inplace=True)
        self.MLP_aggr = nn.Sequential(
            pyg_nn.Linear(dim_in*3, dim_in, bias=True),
            self.activation,
            pyg_nn.Linear(dim_in, dim_in, bias=True),
        )
        self.MLP_gate = nn.Sequential(
            pyg_nn.Linear(dim_in*3, dim_in, bias=True),
            self.activation,
            pyg_nn.Linear(dim_in, dim_in, bias=True),
        )
        self.norm = nn.BatchNorm1d(dim_in)
        self.norm2 = nn.BatchNorm1d(dim_in)
        self.use_envelope = use_envelope
        self.envelope = CosineCutoff(0, radius)

    def forward(self, batch):
        """
        Run one round of message passing and update ``batch`` in place.

        Reads:
            batch.x          : [n_nodes, dim_in]
            batch.edge_attr  : [n_edges, dim_in]
            batch.edge_index : [2, n_edges]
            batch.cart_dist  : [n_edges]
        Writes the residually updated ``batch.x`` and ``batch.edge_attr`` and returns ``batch``.
        """
        x, e, edge_index, dist = batch.x, batch.edge_attr, batch.edge_index, batch.cart_dist
        x_in = x
        e_in = e
        # Keyword names Xx/Ee/He must match the parameter names of message()/aggregate():
        # PyG routes propagate() kwargs to those methods by name (with _i/_j lifting).
        x, e = self.propagate(edge_index,
                              Xx=x, Ee=e,
                              He=dist,
                              )
        batch.x = self.activation(x) + x_in
        batch.edge_attr = e_in + e
        return batch

    def message(self, Xx_i, Ee, Xx_j, He):
        """
        Compute the per-edge gate.

        Xx_i, Xx_j : [n_edges, dim_in] endpoint features
        Ee         : [n_edges, dim_in] edge features
        He         : [n_edges] edge distances
        Returns sigma_ij : [n_edges, dim_in].
        """
        e_ij = self.MLP_gate(torch.cat([Xx_i, Xx_j, Ee], dim=-1))
        e_ij = F.sigmoid(self.norm(e_ij))
        if self.use_envelope:
            sigma_ij = self.envelope(He).unsqueeze(-1)*e_ij
        else:
            sigma_ij = e_ij
        # Stash the gate so update() can also return it as the edge-feature residual;
        # update() deletes it again.
        self.e = sigma_ij
        return sigma_ij

    def aggregate(self, sigma_ij, index, Xx_i, Xx_j, Ee, Xx):
        """
        Sum the gated messages per node.

        sigma_ij : [n_edges, dim_in] output of message()
        index    : [n_edges] node index each edge is scattered to
        Returns  : [n_nodes, dim_in], with n_nodes taken from Xx.
        """
        dim_size = Xx.shape[0]
        sender = self.MLP_aggr(torch.cat([Xx_i, Xx_j, Ee], dim=-1))
        out = scatter(sigma_ij*sender, index, 0, None, dim_size,
                      reduce='sum')
        return out

    def update(self, aggr_out):
        """
        Normalise the aggregated messages and hand back the stashed edge gate.

        aggr_out : [n_nodes, dim_in] output of aggregate()
        Returns (x, e_out): normalised node update and the gate from message().
        """
        x = self.norm2(aggr_out)
        e_out = self.e
        del self.e
        return x, e_out
class Cholesky_head(torch.nn.Module):
    """
    Output head of CartNet that predicts one symmetric 3x3 matrix (the output covariance
    matrix) per selected node.

    An MLP maps each node embedding to 6 numbers:
      * the first three, passed through softplus, form the diagonal of an upper-triangular
        factor R;
      * the last three fill its entries (0,1), (0,2) and (1,2).
    The head returns R^T R. This is symmetric positive semi-definite by construction and
    positive-definite whenever the softplus diagonal is non-zero.

    Args:
        dim_in (int): Dimension of the input node features.
    """

    def __init__(self,
                 dim_in: int
                 ):
        super().__init__()
        hidden = dim_in // 2
        self.MLP = nn.Sequential(
            pyg_nn.Linear(dim_in, hidden),
            nn.SiLU(inplace=True),
            pyg_nn.Linear(hidden, 6),
        )

    def forward(self, batch):
        """
        Predict a 3x3 matrix for every node selected by ``batch.non_H_mask``.

        Returns a tensor of shape [n_selected, 3, 3] with the dtype/device of the MLP output.
        """
        raw = self.MLP(batch.x[batch.non_H_mask])
        # Column k of `entries` goes to position (rows[k], cols[k]) of R.
        entries = torch.cat([F.softplus(raw[:, :3]), raw[:, 3:]], dim=-1)
        rows = [0, 1, 2, 0, 0, 1]
        cols = [0, 1, 2, 1, 2, 2]
        factor = raw.new_zeros(raw.size(0), 3, 3)
        factor[:, rows, cols] = entries
        return torch.bmm(factor.transpose(1, 2), factor)