import math
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, Tensor


# Implementation from TensorNet
# https://github.com/torchmd/torchmd-net


class ExpNormalSmearing(nn.Module):
    """Expand distances with exponential-normal radial basis functions.

    Each distance ``d`` is mapped to ``num_rbf`` features::

        cutoff(d) * exp(-beta_k * (exp(alpha * (cutoff_lower - d)) - mu_k) ** 2)

    with ``alpha = 5 / (cutoff_upper - cutoff_lower)`` and ``cutoff`` a
    :class:`CosineCutoff` envelope that is zero from ``cutoff_upper`` onward.

    Args:
        cutoff_lower: Lower distance bound; shifts the inner exponential and
            sets the first initial mean to ``exp(cutoff_lower - cutoff_upper)``.
        cutoff_upper: Distance at which the features are cut off to zero.
        num_rbf: Number of basis functions (size of the last output dim).
        trainable: If True, ``means`` and ``betas`` are registered as
            ``nn.Parameter``s; otherwise as (non-trainable) buffers.
        dtype: Floating dtype used for ``means`` and ``betas``.
    """

    def __init__(
        self,
        cutoff_lower=0.0,
        cutoff_upper=5.0,
        num_rbf=50,
        trainable=True,
        dtype=torch.float32,
    ):
        super().__init__()
        self.cutoff_lower = cutoff_lower
        self.cutoff_upper = cutoff_upper
        self.num_rbf = num_rbf
        self.trainable = trainable
        self.dtype = dtype
        # NOTE: the envelope is built with a lower bound of 0, not
        # cutoff_lower; kept unchanged from the source implementation.
        self.cutoff_fn = CosineCutoff(0, cutoff_upper)
        self.alpha = 5.0 / (cutoff_upper - cutoff_lower)

        means, betas = self._initial_params()
        if trainable:
            self.register_parameter("means", nn.Parameter(means))
            self.register_parameter("betas", nn.Parameter(betas))
        else:
            self.register_buffer("means", means)
            self.register_buffer("betas", betas)

    def _initial_params(self):
        """Return the initial ``(means, betas)``, each of shape ``(num_rbf,)``.

        Means are spaced linearly from ``exp(cutoff_lower - cutoff_upper)``
        to 1; every beta equals ``(2 / num_rbf * (1 - start)) ** -2``.
        """
        # initialize means and betas according to the default values in PhysNet
        # https://pubs.acs.org/doi/10.1021/acs.jctc.9b00181
        start_value = torch.exp(
            torch.scalar_tensor(
                -self.cutoff_upper + self.cutoff_lower, dtype=self.dtype
            )
        )
        means = torch.linspace(start_value, 1, self.num_rbf, dtype=self.dtype)
        beta = (2 / self.num_rbf * (1 - start_value)) ** -2
        betas = torch.full((self.num_rbf,), beta.item(), dtype=self.dtype)
        return means, betas

    def reset_parameters(self):
        """Restore ``means`` and ``betas`` to their initial values in place."""
        means, betas = self._initial_params()
        self.means.data.copy_(means)
        self.betas.data.copy_(betas)

    def forward(self, dist):
        """Expand ``dist`` of any shape ``(...)`` into ``(..., num_rbf)``."""
        dist = dist.unsqueeze(-1)
        return self.cutoff_fn(dist) * torch.exp(
            -self.betas
            * (torch.exp(self.alpha * (-dist + self.cutoff_lower)) - self.means) ** 2
        )


class CosineCutoff(nn.Module):
    """Smooth cosine envelope that is zero outside the cutoff window.

    * ``cutoff_lower == 0`` (or negative): ``0.5 * (cos(pi * d / upper) + 1)``,
      i.e. 1 at ``d = 0`` decaying to 0 at ``upper``; zero for
      ``d >= upper``.
    * ``cutoff_lower > 0``: a bump that is 0 at ``lower``, 1 at the midpoint
      of ``[lower, upper]`` and 0 at ``upper``; zero outside
      ``(lower, upper)``.
    """

    def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0):
        super().__init__()
        self.cutoff_lower = cutoff_lower
        self.cutoff_upper = cutoff_upper

    def forward(self, distances: Tensor) -> Tensor:
        """Return the elementwise envelope value for ``distances``."""
        if self.cutoff_lower > 0:
            cutoffs = 0.5 * (
                torch.cos(
                    math.pi
                    * (
                        2
                        * (distances - self.cutoff_lower)
                        / (self.cutoff_upper - self.cutoff_lower)
                        + 1.0
                    )
                )
                + 1.0
            )
            # remove contributions outside the (lower, upper) window
            cutoffs = cutoffs * (distances < self.cutoff_upper)
            cutoffs = cutoffs * (distances > self.cutoff_lower)
            return cutoffs
        cutoffs = 0.5 * (torch.cos(distances * math.pi / self.cutoff_upper) + 1.0)
        # remove contributions beyond the cutoff radius
        cutoffs = cutoffs * (distances < self.cutoff_upper)
        return cutoffs


# Implementation from Comformer
# https://github.com/divelab/AIRS/tree/main/OpenMat/ComFormer


class RBFExpansion(nn.Module):
    """Expand interatomic distances with Gaussian radial basis functions.

    ``bins`` centers are spaced uniformly on ``[vmin, vmax]`` and each
    distance ``d`` is mapped to ``exp(-gamma * (d - c_k) ** 2)``.
    """

    def __init__(
        self,
        vmin: float = 0,
        vmax: float = 8,
        bins: int = 40,
        lengthscale: Optional[float] = None,
    ):
        """Register the RBF centers and derive ``gamma``.

        Args:
            vmin: First center.
            vmax: Last center.
            bins: Number of centers (size of the last output dim).
            lengthscale: When given, ``gamma = 1 / lengthscale ** 2``. When
                None, the lengthscale is the mean spacing between adjacent
                centers and ``gamma = 1 / lengthscale`` (SchNet style).

        Raises:
            ValueError: if ``lengthscale`` is None and ``bins < 2``, since
                the spacing between centers is then undefined.
        """
        super().__init__()
        self.vmin = vmin
        self.vmax = vmax
        self.bins = bins
        self.register_buffer(
            "centers", torch.linspace(self.vmin, self.vmax, self.bins)
        )

        if lengthscale is None:
            if bins < 2:
                raise ValueError(
                    "RBFExpansion needs bins >= 2 to infer a lengthscale; "
                    "pass lengthscale explicitly"
                )
            # SchNet-style: set lengthscale relative to granularity of RBF
            # expansion. Computed with torch (not numpy) so it works on any
            # device. NOTE(review): gamma is not squared in this branch,
            # unlike the explicit one; preserved from the original code.
            self.lengthscale = torch.diff(self.centers).mean().item()
            self.gamma = 1 / self.lengthscale
        else:
            self.lengthscale = lengthscale
            self.gamma = 1 / (lengthscale ** 2)

    def forward(self, distance: torch.Tensor) -> torch.Tensor:
        """Expand ``distance`` of shape ``(...)`` into ``(..., bins)``."""
        return torch.exp(
            -self.gamma * (distance.unsqueeze(-1) - self.centers) ** 2
        )