import pyro
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict
from pyro.distributions.conditional import (
ConditionalTransformModule,
ConditionalTransformedDistribution,
TransformedDistribution,
)
from pyro.distributions.torch_distribution import TorchDistributionMixin
from torch.distributions import constraints
from torch.distributions.utils import _sum_rightmost
from torch.distributions.transforms import Transform
class ConditionalAffineTransform(ConditionalTransformModule):
def __init__(self, context_nn, event_dim=0, **kwargs):
super().__init__(**kwargs)
self.event_dim = event_dim
self.context_nn = context_nn
    def condition(self, context):
        # context_nn outputs the location and log-scale of the affine transform
        loc, log_scale = self.context_nn(context)
return torch.distributions.transforms.AffineTransform(
loc, log_scale.exp(), event_dim=self.event_dim
)
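# Usage sketch (illustrative, not part of the original file): conditioning yields
# a plain AffineTransform whose loc/scale come from the context network; the
# lambda below is a hypothetical stand-in for a two-headed network.
#
#   ctx_nn = lambda c: (c.new_zeros(c.shape), c.new_zeros(c.shape))  # loc=0, log_scale=0
#   affine = ConditionalAffineTransform(context_nn=ctx_nn, event_dim=0)
#   t = affine.condition(torch.ones(4, 1))  # AffineTransform(loc=0, scale=1)
#   x = torch.randn(4, 1)
#   assert torch.allclose(t.inv(t(x)), x)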
class MLP(nn.Module):
def __init__(self, num_inputs=1, width=32, num_outputs=1):
super().__init__()
activation = nn.LeakyReLU()
self.mlp = nn.Sequential(
nn.Linear(num_inputs, width, bias=False),
nn.BatchNorm1d(width),
activation,
nn.Linear(width, width, bias=False),
nn.BatchNorm1d(width),
activation,
nn.Linear(width, num_outputs),
)
def forward(self, x):
return self.mlp(x)
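# Note (sketch): MLP(num_inputs=1, num_outputs=2) maps (N, 1) -> (N, 2); the
# BatchNorm1d layers require N > 1 in training mode, e.g.:
#
#   mlp = MLP(num_inputs=1, num_outputs=2)
#   loc, log_scale = mlp(torch.randn(16, 1)).chunk(2, dim=-1)  # two (16, 1) heads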
class CNN(nn.Module):
def __init__(self, in_shape=(1, 192, 192), width=16, num_outputs=1, context_dim=0):
super().__init__()
        in_channels = in_shape[0]
        res = in_shape[1]  # input resolution
        s = 2 if res > 64 else 1  # stem stride: downsample early for large inputs
activation = nn.LeakyReLU()
self.cnn = nn.Sequential(
nn.Conv2d(in_channels, width, 7, s, 3, bias=False),
nn.BatchNorm2d(width),
activation,
(nn.MaxPool2d(2, 2) if res > 32 else nn.Identity()),
nn.Conv2d(width, 2 * width, 3, 2, 1, bias=False),
nn.BatchNorm2d(2 * width),
activation,
nn.Conv2d(2 * width, 2 * width, 3, 1, 1, bias=False),
nn.BatchNorm2d(2 * width),
activation,
nn.Conv2d(2 * width, 4 * width, 3, 2, 1, bias=False),
nn.BatchNorm2d(4 * width),
activation,
nn.Conv2d(4 * width, 4 * width, 3, 1, 1, bias=False),
nn.BatchNorm2d(4 * width),
activation,
nn.Conv2d(4 * width, 8 * width, 3, 2, 1, bias=False),
nn.BatchNorm2d(8 * width),
activation,
)
self.fc = nn.Sequential(
nn.Linear(8 * width + context_dim, 8 * width, bias=False),
nn.BatchNorm1d(8 * width),
activation,
nn.Linear(8 * width, num_outputs),
)
def forward(self, x, y=None):
x = self.cnn(x)
x = x.mean(dim=(-2, -1)) # avg pool
if y is not None:
x = torch.cat([x, y], dim=-1)
return self.fc(x)
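# Shape sketch (assuming the default in_shape): the conv stack downsamples
# 192 -> 96 -> 48 -> 24 -> 12 -> 6, global average pooling gives (N, 8 * width),
# an optional context y of shape (N, context_dim) is concatenated, and the fc
# head maps to (N, num_outputs).
#
#   net = CNN(in_shape=(1, 192, 192), width=16, num_outputs=1)
#   out = net(torch.randn(2, 1, 192, 192))  # -> torch.Size([2, 1])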
class ArgMaxGumbelMax(Transform):
r"""ArgMax as Transform, but inv conditioned on logits"""
def __init__(self, logits, event_dim=0, cache_size=0):
        super().__init__(cache_size=cache_size)
self.logits = logits
self._event_dim = event_dim
self._categorical = pyro.distributions.torch.Categorical(
logits=self.logits
).to_event(0)
@property
def event_dim(self):
return self._event_dim
    def __call__(self, gumbels):
        """Compute the forward transform, argmax(gumbels + logits)."""
        assert self.logits is not None, "Logits not defined."
        return self._call(gumbels)
    def _call(self, gumbels):
        """Compute the forward transformation: adding the logits to Gumbel noise
        and taking the argmax yields a categorical sample (Gumbel-Max trick)."""
        assert self.logits is not None, "Logits not defined."
        y = gumbels + self.logits
        return y.argmax(-1, keepdim=True)
@property
def domain(self):
""" "
Domain of input(gumbel variables), Real
"""
if self.event_dim == 0:
return constraints.real
return constraints.independent(constraints.real, self.event_dim)
@property
def codomain(self):
""" "
Domain of output(categorical variables), should be natural numbers, but set to Real for now
"""
if self.event_dim == 0:
return constraints.real
return constraints.independent(constraints.real, self.event_dim)
    def inv(self, k):
        """Infer the Gumbel noise given the observed class k and the logits."""
        assert self.logits is not None, "Logits not defined."
        uniforms = torch.rand(
            self.logits.shape, dtype=self.logits.dtype, device=self.logits.device
        )
        gumbels = -((-(uniforms.log())).log())
        # (batch_size, num_classes) one-hot mask selecting the k-th class
        mask = F.one_hot(
            k.squeeze(-1).to(torch.int64), num_classes=self.logits.shape[-1]
        )
        # (batch_size, 1) Gumbel value of the argmax class, used to truncate the other classes
        topgumbel = (mask * gumbels).sum(dim=-1, keepdim=True) - (
            mask * self.logits
        ).sum(dim=-1, keepdim=True)
        mask = 1 - mask  # invert mask to select the other (!= k) classes
        g = gumbels + self.logits
        # (batch_size, num_classes) truncated Gumbel noise consistent with argmax == k
        epsilons = -torch.log(mask * torch.exp(-g) + torch.exp(-topgumbel)) - (
            mask * self.logits
        )
        return epsilons
    def log_abs_det_jacobian(self, x, y):
        """Use the log abs det Jacobian term to account for the categorical
        probability: x are the Gumbels, y = argmax(x + logits), and P(y) is the
        softmax of the logits.
        """
        return -self._categorical.log_prob(y.squeeze(-1)).unsqueeze(-1)
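# Round-trip sketch (illustrative): the forward map draws a class as
# argmax(gumbels + logits); inv(k) draws truncated Gumbel noise consistent with
# the observed class k, so re-applying the forward map recovers k exactly.
#
#   logits = torch.randn(4, 3)
#   t = ArgMaxGumbelMax(logits)
#   k = t(-torch.log(-torch.log(torch.rand(4, 3))))  # (4, 1) class indices
#   eps = t.inv(k)                                   # (4, 3) abducted noise
#   assert torch.equal((eps + logits).argmax(-1, keepdim=True), k)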
class ConditionalGumbelMax(ConditionalTransformModule):
r"""Given gumbels+logits, output the OneHot Categorical"""
    def __init__(self, context_nn, event_dim=0, **kwargs):
        super().__init__(**kwargs)
        # context_nn predicts the logits given the context (e.g. age)
        self.context_nn = context_nn
        self.event_dim = event_dim
    def condition(self, context):
        """Given the context (e.g. age), return an ArgMaxGumbelMax transform."""
        # logits used for computing argmax(Gumbels + logits)
        logits = self.context_nn(context)
        return ArgMaxGumbelMax(logits)
def _logits(self, context):
"""Return logits given context"""
return self.context_nn(context)
@property
def domain(self):
""" "
Domain of input(gumbel variables), Real
"""
if self.event_dim == 0:
return constraints.real
return constraints.independent(constraints.real, self.event_dim)
@property
def codomain(self):
""" "
Domain of output(categorical variables), should be natural numbers, but set to Real for now
"""
if self.event_dim == 0:
return constraints.real
return constraints.independent(constraints.real, self.event_dim)
class TransformedDistributionGumbelMax(TransformedDistribution, TorchDistributionMixin):
r"""Define a TransformedDistribution class for Gumbel max"""
arg_constraints: Dict[str, constraints.Constraint] = {}
def log_prob(self, value):
"""
We do not use the log_prob() of the base Gumbel distribution, because the likelihood for
each class for Gumbel Max sampling is determined by the logits.
"""
# print("This happens")
if self._validate_args:
self._validate_sample(value)
event_dim = len(self.event_shape)
log_prob = 0.0
y = value
for transform in reversed(self.transforms):
x = transform.inv(y)
event_dim += transform.domain.event_dim - transform.codomain.event_dim
log_prob = log_prob - _sum_rightmost(
transform.log_abs_det_jacobian(x, y),
event_dim - transform.domain.event_dim,
)
y = x
# print(f"log_prob: {log_prob.size()}")
return log_prob
class ConditionalTransformedDistributionGumbelMax(ConditionalTransformedDistribution):
def condition(self, context):
base_dist = self.base_dist.condition(context)
transforms = [t.condition(context) for t in self.transforms]
return TransformedDistributionGumbelMax(base_dist, transforms)
def clear_cache(self):
pass
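# End-to-end sketch (hypothetical wiring, names are illustrative): a categorical
# variable whose logits depend on a continuous context, built from a Gumbel base
# distribution pushed through ConditionalGumbelMax; log_prob() then reduces to
# the categorical log-likelihood via ArgMaxGumbelMax.log_abs_det_jacobian.
#
#   num_classes = 3
#   logits_nn = MLP(num_inputs=1, num_outputs=num_classes)  # context -> logits
#   base = pyro.distributions.Gumbel(
#       torch.zeros(num_classes), torch.ones(num_classes)
#   ).to_event(1)
#   cond_dist = ConditionalTransformedDistributionGumbelMax(
#       base, [ConditionalGumbelMax(context_nn=logits_nn)]
#   )
#   dist = cond_dist.condition(torch.randn(8, 1))  # TransformedDistributionGumbelMax
#   k = dist.sample()                              # class indices
#   lp = dist.log_prob(k)                          # categorical log-probabilities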