import pyro
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Dict

from pyro.distributions.conditional import (
    ConditionalTransformModule,
    ConditionalTransformedDistribution,
    TransformedDistribution,
)
from pyro.distributions.torch_distribution import TorchDistributionMixin
from torch.distributions import constraints
from torch.distributions.utils import _sum_rightmost
from torch.distributions.transforms import Transform

class ConditionalAffineTransform(ConditionalTransformModule):
    def __init__(self, context_nn, event_dim=0, **kwargs):
        super().__init__(**kwargs)
        self.event_dim = event_dim
        self.context_nn = context_nn

    def condition(self, context):
        loc, log_scale = self.context_nn(context)
        return torch.distributions.transforms.AffineTransform(
            loc, log_scale.exp(), event_dim=self.event_dim
        )

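
# Illustrative sketch, not part of the original module: one way ConditionalAffineTransform
# could be wired into pyro's ConditionalTransformedDistribution. The two-headed context
# network below (LocLogScaleNet) and all shapes are assumptions chosen for the demo only.
def _example_conditional_affine():
    class LocLogScaleNet(nn.Module):
        """Toy context network returning (loc, log_scale) for a 1-d variable."""

        def __init__(self, context_dim=1, hidden=16):
            super().__init__()
            self.body = nn.Sequential(nn.Linear(context_dim, hidden), nn.LeakyReLU())
            self.loc = nn.Linear(hidden, 1)
            self.log_scale = nn.Linear(hidden, 1)

        def forward(self, c):
            h = self.body(c)
            return self.loc(h), self.log_scale(h)

    context = torch.randn(8, 1)  # e.g. age
    transform = ConditionalAffineTransform(context_nn=LocLogScaleNet(), event_dim=0)
    base = pyro.distributions.Normal(torch.zeros(8, 1), torch.ones(8, 1))
    cond_dist = ConditionalTransformedDistribution(base, [transform])
    # Conditioning yields a plain TransformedDistribution; sampling applies the affine
    # transform predicted from the context to the base-normal sample.
    return cond_dist.condition(context).sample()
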
class MLP(nn.Module):
    def __init__(self, num_inputs=1, width=32, num_outputs=1):
        super().__init__()
        activation = nn.LeakyReLU()
        self.mlp = nn.Sequential(
            nn.Linear(num_inputs, width, bias=False),
            nn.BatchNorm1d(width),
            activation,
            nn.Linear(width, width, bias=False),
            nn.BatchNorm1d(width),
            activation,
            nn.Linear(width, num_outputs),
        )

    def forward(self, x):
        return self.mlp(x)

class CNN(nn.Module):
    def __init__(self, in_shape=(1, 192, 192), width=16, num_outputs=1, context_dim=0):
        super().__init__()
        in_channels = in_shape[0]
        res = in_shape[1]
        s = 2 if res > 64 else 1
        activation = nn.LeakyReLU()
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels, width, 7, s, 3, bias=False),
            nn.BatchNorm2d(width),
            activation,
            (nn.MaxPool2d(2, 2) if res > 32 else nn.Identity()),
            nn.Conv2d(width, 2 * width, 3, 2, 1, bias=False),
            nn.BatchNorm2d(2 * width),
            activation,
            nn.Conv2d(2 * width, 2 * width, 3, 1, 1, bias=False),
            nn.BatchNorm2d(2 * width),
            activation,
            nn.Conv2d(2 * width, 4 * width, 3, 2, 1, bias=False),
            nn.BatchNorm2d(4 * width),
            activation,
            nn.Conv2d(4 * width, 4 * width, 3, 1, 1, bias=False),
            nn.BatchNorm2d(4 * width),
            activation,
            nn.Conv2d(4 * width, 8 * width, 3, 2, 1, bias=False),
            nn.BatchNorm2d(8 * width),
            activation,
        )
        self.fc = nn.Sequential(
            nn.Linear(8 * width + context_dim, 8 * width, bias=False),
            nn.BatchNorm1d(8 * width),
            activation,
            nn.Linear(8 * width, num_outputs),
        )

    def forward(self, x, y=None):
        x = self.cnn(x)
        x = x.mean(dim=(-2, -1))  # global average pooling
        if y is not None:  # optional context vector, concatenated with the pooled features
            x = torch.cat([x, y], dim=-1)
        return self.fc(x)

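
# Illustrative sketch, not part of the original module: calling the CNN with an optional
# context vector. Batch size, image resolution and context_dim below are assumptions.
def _example_cnn_forward():
    net = CNN(in_shape=(1, 192, 192), width=16, num_outputs=2, context_dim=1)
    net.eval()  # keep BatchNorm running stats untouched in this demo
    x = torch.randn(4, 1, 192, 192)
    c = torch.randn(4, 1)  # e.g. age as a conditioning variable
    # Image features are global-average-pooled, concatenated with the context, then
    # passed through the fully connected head: output shape (4, 2).
    return net(x, c)
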
class ArgMaxGumbelMax(Transform):
    r"""ArgMax as a Transform; the inverse is conditioned on the logits."""

    def __init__(self, logits, event_dim=0, cache_size=0):
        super().__init__(cache_size=cache_size)
        self.logits = logits
        self._event_dim = event_dim
        self._categorical = pyro.distributions.torch.Categorical(
            logits=self.logits
        ).to_event(0)

    @property
    def event_dim(self):
        return self._event_dim

    def __call__(self, gumbels):
        """Compute the forward transform."""
        assert self.logits is not None, "Logits not defined."
        if self._cache_size == 0:
            return self._call(gumbels)
        y = self._call(gumbels)
        return y

    def _call(self, gumbels):
        """Forward transform: argmax over classes of (gumbels + logits)."""
        assert self.logits is not None, "Logits not defined."
        y = gumbels + self.logits
        return y.argmax(-1, keepdim=True)

    @property
    def domain(self):
        """Domain of the input (Gumbel variables): real."""
        if self.event_dim == 0:
            return constraints.real
        return constraints.independent(constraints.real, self.event_dim)

    @property
    def codomain(self):
        """Domain of the output (categorical variables): should be natural numbers,
        but set to real for now."""
        if self.event_dim == 0:
            return constraints.real
        return constraints.independent(constraints.real, self.event_dim)

    def inv(self, k):
        """Infer the Gumbel noise given k and the logits."""
        assert self.logits is not None, "Logits not defined."
        uniforms = torch.rand(
            self.logits.shape, dtype=self.logits.dtype, device=self.logits.device
        )
        gumbels = -((-(uniforms.log())).log())
        # (batch_size, num_classes) mask to select the kth class
        mask = F.one_hot(
            k.squeeze(-1).to(torch.int64), num_classes=self.logits.shape[-1]
        )
        # (batch_size, 1) select the top Gumbel for truncation of the other classes
        topgumbel = (mask * gumbels).sum(dim=-1, keepdim=True) - (
            mask * self.logits
        ).sum(dim=-1, keepdim=True)
        mask = 1 - mask  # invert mask to select the other (!= k) classes
        g = gumbels + self.logits
        # (batch_size, num_classes)
        epsilons = -torch.log(mask * torch.exp(-g) + torch.exp(-topgumbel)) - (
            mask * self.logits
        )
        return epsilons

    def log_abs_det_jacobian(self, x, y):
        """Use log_abs_det_jacobian to account for the categorical probability.
        x: Gumbels; y: argmax(x + logits); P(y) = softmax(logits).
        """
        return -self._categorical.log_prob(y.squeeze(-1)).unsqueeze(-1)

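
# Illustrative sketch, not part of the original module: the Gumbel-max mechanism in
# isolation. Forward: argmax(gumbel + logits) draws a class; inv: given an observed
# class k, draw Gumbel noise intended to reproduce k (abduction of the exogenous noise);
# log_abs_det_jacobian supplies -log softmax(logits)[k]. All shapes are assumptions.
def _example_argmax_gumbel_max():
    logits = torch.randn(4, 3)  # batch of 4, 3 classes
    t = ArgMaxGumbelMax(logits)
    base = pyro.distributions.Gumbel(torch.zeros(4, 3), torch.ones(4, 3))
    u = base.sample()                      # exogenous Gumbel noise
    k = t(u)                               # (4, 1) sampled class indices
    eps = t.inv(k)                         # (4, 3) Gumbel noise inferred from k
    nll = t.log_abs_det_jacobian(eps, k)   # (4, 1) equals -log softmax(logits)[k]
    return k, eps, nll
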
class ConditionalGumbelMax(ConditionalTransformModule):
    r"""Conditional Gumbel-max transform: predicts logits from a context and returns
    an ArgMaxGumbelMax transform (argmax of Gumbels + logits)."""

    def __init__(self, context_nn, event_dim=0, **kwargs):
        super().__init__(**kwargs)
        # context_nn predicts the logits from the context (e.g. age).
        self.context_nn = context_nn
        self.event_dim = event_dim

    def condition(self, context):
        """Given the context (e.g. age), return the ArgMaxGumbelMax transform."""
        # Logits for computing argmax(Gumbel + logits).
        logits = self.context_nn(context)
        return ArgMaxGumbelMax(logits)

    def _logits(self, context):
        """Return the logits given the context."""
        return self.context_nn(context)

    @property
    def domain(self):
        """Domain of the input (Gumbel variables): real."""
        if self.event_dim == 0:
            return constraints.real
        return constraints.independent(constraints.real, self.event_dim)

    @property
    def codomain(self):
        """Domain of the output (categorical variables): should be natural numbers,
        but set to real for now."""
        if self.event_dim == 0:
            return constraints.real
        return constraints.independent(constraints.real, self.event_dim)

class TransformedDistributionGumbelMax(TransformedDistribution, TorchDistributionMixin):
    r"""TransformedDistribution for the Gumbel-max trick."""

    arg_constraints: Dict[str, constraints.Constraint] = {}

    def log_prob(self, value):
        """The log_prob() of the base Gumbel distribution is not used here: under
        Gumbel-max sampling, the likelihood of each class is determined by the logits
        alone (via the transform's log_abs_det_jacobian).
        """
        if self._validate_args:
            self._validate_sample(value)
        event_dim = len(self.event_shape)
        log_prob = 0.0
        y = value
        for transform in reversed(self.transforms):
            x = transform.inv(y)
            event_dim += transform.domain.event_dim - transform.codomain.event_dim
            log_prob = log_prob - _sum_rightmost(
                transform.log_abs_det_jacobian(x, y),
                event_dim - transform.domain.event_dim,
            )
            y = x
        return log_prob

class ConditionalTransformedDistributionGumbelMax(ConditionalTransformedDistribution):
    def condition(self, context):
        base_dist = self.base_dist.condition(context)
        transforms = [t.condition(context) for t in self.transforms]
        return TransformedDistributionGumbelMax(base_dist, transforms)

    def clear_cache(self):
        pass
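
# Illustrative sketch, not part of the original module: an end-to-end conditional
# categorical mechanism built from the classes above. An MLP maps a 1-d context
# (e.g. age) to class logits; conditioning yields a TransformedDistributionGumbelMax
# whose samples are class indices and whose log_prob is the categorical likelihood.
# The 3-class setup and all shapes are assumptions for demonstration.
def _example_conditional_categorical_mechanism(num_classes=3):
    logits_nn = MLP(num_inputs=1, width=32, num_outputs=num_classes)
    logits_nn.eval()  # keep BatchNorm running stats untouched in this demo

    context = torch.randn(8, 1)  # e.g. age
    gumbel_base = pyro.distributions.Gumbel(
        torch.zeros(8, num_classes), torch.ones(8, num_classes)
    )
    cond_dist = ConditionalTransformedDistributionGumbelMax(
        gumbel_base, [ConditionalGumbelMax(context_nn=logits_nn)]
    )
    dist = cond_dist.condition(context)
    k = dist.sample()        # (8, 1) class indices
    logp = dist.log_prob(k)  # (8, 1) log p(k | context), from the logits alone
    return k, logp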