import math
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, Tensor

# Implementation from TensorNet
# https://github.com/torchmd/torchmd-net
class ExpNormalSmearing(nn.Module):
    def __init__(
        self,
        cutoff_lower=0.0,
        cutoff_upper=5.0,
        num_rbf=50,
        trainable=True,
        dtype=torch.float32,
    ):
        super(ExpNormalSmearing, self).__init__()
        self.cutoff_lower = cutoff_lower
        self.cutoff_upper = cutoff_upper
        self.num_rbf = num_rbf
        self.trainable = trainable
        self.dtype = dtype
        self.cutoff_fn = CosineCutoff(0, cutoff_upper)
        self.alpha = 5.0 / (cutoff_upper - cutoff_lower)

        means, betas = self._initial_params()
        if trainable:
            self.register_parameter("means", nn.Parameter(means))
            self.register_parameter("betas", nn.Parameter(betas))
        else:
            self.register_buffer("means", means)
            self.register_buffer("betas", betas)

    def _initial_params(self):
        # initialize means and betas according to the default values in PhysNet
        # https://pubs.acs.org/doi/10.1021/acs.jctc.9b00181
        start_value = torch.exp(
            torch.scalar_tensor(
                -self.cutoff_upper + self.cutoff_lower, dtype=self.dtype
            )
        )
        means = torch.linspace(start_value, 1, self.num_rbf, dtype=self.dtype)
        betas = torch.tensor(
            [(2 / self.num_rbf * (1 - start_value)) ** -2] * self.num_rbf,
            dtype=self.dtype,
        )
        return means, betas

    def reset_parameters(self):
        means, betas = self._initial_params()
        self.means.data.copy_(means)
        self.betas.data.copy_(betas)

    def forward(self, dist):
        dist = dist.unsqueeze(-1)
        return self.cutoff_fn(dist) * torch.exp(
            -self.betas
            * (torch.exp(self.alpha * (-dist + self.cutoff_lower)) - self.means) ** 2
        )
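
# Usage sketch (added for illustration, not part of the torchmd-net source):
# expand a 1D tensor of edge distances into `num_rbf` exponential-normal
# features damped by the cosine cutoff. The variable names and shapes below
# are assumptions for this example only.
def _demo_expnormal_smearing():
    rbf = ExpNormalSmearing(cutoff_lower=0.0, cutoff_upper=5.0, num_rbf=50)
    dist = torch.rand(128) * 5.0   # hypothetical pairwise distances in [0, 5)
    feats = rbf(dist)              # shape: (128, 50)
    return feats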

class CosineCutoff(nn.Module):
    def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0):
        super(CosineCutoff, self).__init__()
        self.cutoff_lower = cutoff_lower
        self.cutoff_upper = cutoff_upper

    def forward(self, distances: Tensor) -> Tensor:
        if self.cutoff_lower > 0:
            cutoffs = 0.5 * (
                torch.cos(
                    math.pi
                    * (
                        2
                        * (distances - self.cutoff_lower)
                        / (self.cutoff_upper - self.cutoff_lower)
                        + 1.0
                    )
                )
                + 1.0
            )
            # zero out contributions outside the [cutoff_lower, cutoff_upper] window
            cutoffs = cutoffs * (distances < self.cutoff_upper)
            cutoffs = cutoffs * (distances > self.cutoff_lower)
            return cutoffs
        else:
            cutoffs = 0.5 * (torch.cos(distances * math.pi / self.cutoff_upper) + 1.0)
            # remove contributions beyond the cutoff radius
            cutoffs = cutoffs * (distances < self.cutoff_upper)
            return cutoffs
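
# Usage sketch (illustrative, not from the original file): the cutoff weight
# decays smoothly from 1 at zero separation to 0 at `cutoff_upper` and is
# exactly 0 beyond it. The sample values are assumptions for the example.
def _demo_cosine_cutoff():
    cutoff_fn = CosineCutoff(cutoff_lower=0.0, cutoff_upper=5.0)
    d = torch.tensor([0.0, 2.5, 5.0, 6.0])
    w = cutoff_fn(d)               # approximately [1.0, 0.5, 0.0, 0.0]
    return w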

# Implementation from ComFormer
# https://github.com/divelab/AIRS/tree/main/OpenMat/ComFormer
class RBFExpansion(nn.Module):
    """Expand interatomic distances with radial basis functions."""

    def __init__(
        self,
        vmin: float = 0,
        vmax: float = 8,
        bins: int = 40,
        lengthscale: Optional[float] = None,
    ):
        """Register torch parameters for RBF expansion."""
        super().__init__()
        self.vmin = vmin
        self.vmax = vmax
        self.bins = bins
        self.register_buffer(
            "centers", torch.linspace(self.vmin, self.vmax, self.bins)
        )

        if lengthscale is None:
            # SchNet-style: set the lengthscale relative to the granularity
            # of the RBF expansion (mean spacing between centers)
            self.lengthscale = np.diff(self.centers).mean()
            self.gamma = 1 / self.lengthscale
        else:
            self.lengthscale = lengthscale
            self.gamma = 1 / (lengthscale ** 2)

    def forward(self, distance: torch.Tensor) -> torch.Tensor:
        """Apply RBF expansion to interatomic distance tensor."""
        return torch.exp(
            -self.gamma * (distance.unsqueeze(1) - self.centers) ** 2
        )
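
# Usage sketch (illustrative, not part of the ComFormer source): project a
# batch of distances onto Gaussian basis functions centred on a uniform grid
# over [vmin, vmax]. Names and shapes below are assumptions for this example.
def _demo_rbf_expansion():
    rbf = RBFExpansion(vmin=0.0, vmax=8.0, bins=40)
    distance = torch.rand(128) * 8.0   # hypothetical interatomic distances
    feats = rbf(distance)              # shape: (128, 40)
    return feats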