# cartnet-demo / models/utils.py
# Author: Àlex Solé — commit 744c6a1 ("merged from streamlit"), 4.37 kB
import torch
import math
import numpy as np
from torch import nn, Tensor
import torch.nn.functional as F
from typing import Optional
# Implementation from TensorNet
# https://github.com/torchmd/torchmd-net
class ExpNormalSmearing(nn.Module):
    """Exponential-normal radial basis expansion of distances.

    Every distance ``d`` is mapped to ``num_rbf`` features::

        cutoff_fn(d) * exp(-beta_k * (exp(alpha * (cutoff_lower - d)) - mu_k) ** 2)

    where ``alpha = 5 / (cutoff_upper - cutoff_lower)``. ``cutoff_fn`` is a
    ``CosineCutoff`` built with a lower bound of ``0`` (independent of
    ``cutoff_lower``) and an upper bound of ``cutoff_upper``.

    Args:
        cutoff_lower: Lower distance bound; enters ``alpha`` and the
            exponential rescaling of the distance.
        cutoff_upper: Upper distance bound; enters ``alpha``, the initial
            means and the cosine envelope.
        num_rbf: Number of basis functions (size of the trailing output dim).
        trainable: If true, ``means``/``betas`` are ``nn.Parameter``s;
            otherwise they are registered as buffers.
        dtype: dtype used to build the initial ``means``/``betas``.
    """

    def __init__(
        self,
        cutoff_lower=0.0,
        cutoff_upper=5.0,
        num_rbf=50,
        trainable=True,
        dtype=torch.float32,
    ):
        super().__init__()
        self.cutoff_lower = cutoff_lower
        self.cutoff_upper = cutoff_upper
        self.num_rbf = num_rbf
        self.trainable = trainable
        self.dtype = dtype
        # The envelope always starts at 0, regardless of cutoff_lower.
        self.cutoff_fn = CosineCutoff(0, cutoff_upper)
        self.alpha = 5.0 / (cutoff_upper - cutoff_lower)

        initial_means, initial_betas = self._initial_params()
        for attr_name, tensor in (("means", initial_means), ("betas", initial_betas)):
            if trainable:
                self.register_parameter(attr_name, nn.Parameter(tensor))
            else:
                self.register_buffer(attr_name, tensor)

    def _initial_params(self):
        """Build the default ``(means, betas)`` pair used at init and reset.

        ``means`` holds ``num_rbf`` values spaced linearly from
        ``exp(cutoff_lower - cutoff_upper)`` up to ``1``. ``betas`` repeats the
        single value ``(2 / num_rbf * (1 - exp(cutoff_lower - cutoff_upper))) ** -2``.
        Both tensors use ``self.dtype``.
        """
        # Defaults follow PhysNet: https://pubs.acs.org/doi/10.1021/acs.jctc.9b00181
        lowest_mean = torch.exp(
            torch.scalar_tensor(-self.cutoff_upper + self.cutoff_lower, dtype=self.dtype)
        )
        means = torch.linspace(lowest_mean, 1, self.num_rbf, dtype=self.dtype)
        shared_beta = (2 / self.num_rbf * (1 - lowest_mean)) ** -2
        betas = shared_beta.repeat(self.num_rbf)
        return means, betas

    def reset_parameters(self):
        """Overwrite ``means`` and ``betas`` in place with their initial values."""
        fresh_means, fresh_betas = self._initial_params()
        self.means.data.copy_(fresh_means)
        self.betas.data.copy_(fresh_betas)

    def forward(self, dist):
        """Expand ``dist`` into features of shape ``dist.shape + (num_rbf,)``."""
        dist = dist.unsqueeze(-1)
        envelope = self.cutoff_fn(dist)
        # Exponentially rescaled distance: exp(alpha * (cutoff_lower - d)).
        rescaled = torch.exp(self.alpha * (-dist + self.cutoff_lower))
        return envelope * torch.exp(-self.betas * (rescaled - self.means) ** 2)
class CosineCutoff(nn.Module):
    """Smooth cosine envelope applied element-wise to a distance tensor.

    * ``cutoff_lower > 0``: a two-sided bump
      ``0.5 * (cos(pi * (2 * (d - lower) / (upper - lower) + 1)) + 1)``,
      zeroed outside the open interval ``(lower, upper)``.
    * otherwise: ``0.5 * (cos(d * pi / upper) + 1)``, zeroed where
      ``d >= upper``.
    """

    def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0):
        super().__init__()
        self.cutoff_lower = cutoff_lower
        self.cutoff_upper = cutoff_upper

    def forward(self, distances: Tensor) -> Tensor:
        lower = self.cutoff_lower
        upper = self.cutoff_upper

        if lower > 0:
            phase = math.pi * (2 * (distances - lower) / (upper - lower) + 1.0)
            envelope = 0.5 * (torch.cos(phase) + 1.0)
            # Keep only contributions strictly between the two cutoffs.
            inside_band = (distances < upper) & (distances > lower)
            return envelope * inside_band

        envelope = 0.5 * (torch.cos(distances * math.pi / upper) + 1.0)
        # Drop contributions at or beyond the upper cutoff.
        return envelope * (distances < upper)
# Implementation from Comformer
# https://github.com/divelab/AIRS/tree/main/OpenMat/ComFormer
class RBFExpansion(nn.Module):
    """Expand interatomic distances with Gaussian radial basis functions.

    ``bins`` centres are placed uniformly on ``[vmin, vmax]`` (inclusive) and
    every distance ``d`` is mapped to ``exp(-gamma * (d - centre) ** 2)`` for
    each centre.

    Args:
        vmin: Position of the first centre.
        vmax: Position of the last centre.
        bins: Number of centres, i.e. features produced per distance.
        lengthscale: Gaussian width. When given, ``gamma = 1 / lengthscale**2``.
            When ``None``, the spacing between neighbouring centres is used as
            the lengthscale and ``gamma = 1 / lengthscale`` (see NOTE below).
    """

    def __init__(
        self,
        vmin: float = 0,
        vmax: float = 8,
        bins: int = 40,
        lengthscale: Optional[float] = None,
    ):
        """Register the ``centers`` buffer and derive ``lengthscale``/``gamma``.

        Raises:
            ValueError: if ``lengthscale`` is ``None`` and ``bins < 2``. The
                default lengthscale is the centre spacing, which is undefined
                with fewer than two centres (it would silently become NaN).
        """
        super().__init__()
        self.vmin = vmin
        self.vmax = vmax
        self.bins = bins
        self.register_buffer(
            "centers", torch.linspace(self.vmin, self.vmax, self.bins)
        )

        if lengthscale is None:
            if bins < 2:
                raise ValueError(
                    "RBFExpansion needs bins >= 2 to derive a default "
                    f"lengthscale, got bins={bins}; pass lengthscale explicitly"
                )
            # SchNet-style: lengthscale = granularity of the RBF grid.
            self.lengthscale = np.diff(self.centers).mean()
            # NOTE: this branch uses 1 / lengthscale while the explicit branch
            # uses 1 / lengthscale**2. Kept unchanged so outputs stay identical
            # for existing users; verify against upstream before "fixing".
            self.gamma = 1 / self.lengthscale
        else:
            self.lengthscale = lengthscale
            self.gamma = 1 / (lengthscale ** 2)

    def forward(self, distance: torch.Tensor) -> torch.Tensor:
        """Apply the RBF expansion to a distance tensor of any shape.

        Returns a tensor of shape ``distance.shape + (bins,)``. For 1-D input
        of length ``E`` this is ``(E, bins)``.
        """
        # Trailing axis (not axis 1) so N-D and 0-d inputs broadcast against
        # the (bins,) centres, matching ExpNormalSmearing.forward.
        return torch.exp(
            -self.gamma * (distance.unsqueeze(-1) - self.centers) ** 2
        )