Spaces:
Sleeping
Sleeping
File size: 4,372 Bytes
744c6a1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import torch
import math
import numpy as np
from torch import nn, Tensor
import torch.nn.functional as F
from typing import Optional
# Implementation from TensorNet
# https://github.com/torchmd/torchmd-net
class ExpNormalSmearing(nn.Module):
    """Exponential-normal radial basis expansion of distances.

    Each input distance is mapped to ``num_rbf`` features of the form
    ``exp(-beta * (exp(alpha * (cutoff_lower - d)) - mean) ** 2)`` and the
    result is multiplied by a cosine cutoff envelope.

    Args:
        cutoff_lower: Lower distance bound used for ``alpha`` and for the
            initial ``means`` / ``betas``.
        cutoff_upper: Upper distance bound; also the radius of the cosine
            cutoff envelope.
        num_rbf: Number of basis functions (length of ``means``/``betas``).
        trainable: If true, ``means`` and ``betas`` are registered as
            parameters; otherwise they are registered as buffers.
        dtype: dtype used when creating ``means`` and ``betas``.
    """

    def __init__(
        self,
        cutoff_lower=0.0,
        cutoff_upper=5.0,
        num_rbf=50,
        trainable=True,
        dtype=torch.float32,
    ):
        super().__init__()
        self.cutoff_lower = cutoff_lower
        self.cutoff_upper = cutoff_upper
        self.num_rbf = num_rbf
        self.trainable = trainable
        self.dtype = dtype

        # The envelope always starts at 0, independent of ``cutoff_lower``.
        self.cutoff_fn = CosineCutoff(0, cutoff_upper)
        self.alpha = 5.0 / (cutoff_upper - cutoff_lower)

        initial_means, initial_betas = self._initial_params()
        for name, value in (("means", initial_means), ("betas", initial_betas)):
            if trainable:
                self.register_parameter(name, nn.Parameter(value))
            else:
                self.register_buffer(name, value)

    def _initial_params(self):
        """Return the initial ``(means, betas)`` tensors.

        Initialisation follows the default values in PhysNet
        (https://pubs.acs.org/doi/10.1021/acs.jctc.9b00181).
        """
        lowest_mean = torch.exp(
            torch.scalar_tensor(
                self.cutoff_lower - self.cutoff_upper, dtype=self.dtype
            )
        )
        means = torch.linspace(lowest_mean, 1, self.num_rbf, dtype=self.dtype)

        # Every basis function shares the same initial beta.
        shared_beta = (2 / self.num_rbf * (1 - lowest_mean)) ** -2
        betas = torch.full(
            (self.num_rbf,), shared_beta.item(), dtype=self.dtype
        )
        return means, betas

    def reset_parameters(self):
        """Overwrite ``means`` and ``betas`` in place with their initial values."""
        fresh_means, fresh_betas = self._initial_params()
        self.means.data.copy_(fresh_means)
        self.betas.data.copy_(fresh_betas)

    def forward(self, dist):
        """Expand ``dist`` along a new trailing axis of size ``num_rbf``."""
        # Trailing singleton axis broadcasts against ``means`` / ``betas``.
        dist = dist.unsqueeze(-1)
        envelope = self.cutoff_fn(dist)
        decayed = torch.exp(self.alpha * (-dist + self.cutoff_lower))
        deviation = decayed - self.means
        return envelope * torch.exp(-self.betas * deviation ** 2)
class CosineCutoff(nn.Module):
    """Cosine-shaped cutoff envelope over distances.

    Args:
        cutoff_lower: Lower cutoff radius. When it is greater than zero the
            envelope is also zeroed at and below this radius.
        cutoff_upper: Upper cutoff radius; the envelope is zero at and
            beyond this radius.
    """

    def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0):
        super().__init__()
        self.cutoff_lower = cutoff_lower
        self.cutoff_upper = cutoff_upper

    def forward(self, distances: Tensor) -> Tensor:
        """Return the envelope value for every entry of ``distances``."""
        lower = self.cutoff_lower
        upper = self.cutoff_upper
        below_upper = distances < upper

        if lower > 0:
            # Cosine window spanning (lower, upper).
            phase = math.pi * (2 * (distances - lower) / (upper - lower) + 1.0)
            envelope = 0.5 * (torch.cos(phase) + 1.0)
            # Zero out everything outside the open interval (lower, upper).
            envelope = envelope * below_upper
            return envelope * (distances > lower)

        # Half-cosine decay from 1 at distance 0 to 0 at the upper cutoff.
        envelope = 0.5 * (torch.cos(distances * math.pi / upper) + 1.0)
        # Zero out contributions at or beyond the upper cutoff radius.
        return envelope * below_upper
# Implementation from ComFormer
# https://github.com/divelab/AIRS/tree/main/OpenMat/ComFormer
class RBFExpansion(nn.Module):
    """Expand interatomic distances with radial basis functions.

    The basis consists of ``bins`` Gaussians whose centers are evenly spaced
    between ``vmin`` and ``vmax``.

    Args:
        vmin: Position of the first center.
        vmax: Position of the last center.
        bins: Number of centers (size of the expanded axis).
        lengthscale: Optional Gaussian length scale. When ``None`` the mean
            spacing between centers is used and ``gamma = 1 / lengthscale``
            (SchNet-style); otherwise ``gamma = 1 / lengthscale ** 2``.
    """

    def __init__(
        self,
        vmin: float = 0,
        vmax: float = 8,
        bins: int = 40,
        lengthscale: Optional[float] = None,
    ):
        """Register the centers buffer and derive ``lengthscale``/``gamma``."""
        super().__init__()
        self.vmin = vmin
        self.vmax = vmax
        self.bins = bins
        self.register_buffer(
            "centers", torch.linspace(self.vmin, self.vmax, self.bins)
        )

        if lengthscale is None:
            # SchNet-style: set the length scale relative to the granularity
            # of the RBF expansion.
            # Convert explicitly so this also works when the buffer was
            # created on a non-CPU default device.
            centers_np = self.centers.detach().cpu().numpy()
            mean_spacing = np.diff(centers_np).mean()
            # Store plain Python floats rather than numpy scalars; the
            # division stays in numpy so the value is unchanged.
            self.lengthscale = float(mean_spacing)
            self.gamma = float(1 / mean_spacing)
        else:
            self.lengthscale = lengthscale
            self.gamma = 1 / (lengthscale ** 2)

    def forward(self, distance: torch.Tensor) -> torch.Tensor:
        """Apply RBF expansion to interatomic distance tensor.

        A new axis is inserted at position 1 and broadcast against the
        centers, so a 1-D input of length ``N`` yields shape ``(N, bins)``.
        """
        offsets = distance.unsqueeze(1) - self.centers
        return torch.exp(-self.gamma * offsets ** 2)