# counterfactuals-demo/pgm/flow_pgm.py
from typing import Dict
import numpy as np
import torch
from torch import nn, Tensor
import torch.nn.functional as F
import pyro
import pyro.distributions as dist
import pyro.distributions.transforms as T
from pyro.nn import DenseNN
from pyro.infer.reparam.transform import TransformReparam
from pyro.distributions.conditional import ConditionalTransformedDistribution
from .layers import (
ConditionalTransformedDistributionGumbelMax,
ConditionalGumbelMax,
ConditionalAffineTransform,
MLP,
CNN,
)
class Hparams:
    def update(self, hparams: Dict):
        for k, v in hparams.items():
            setattr(self, k, v)
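

# Minimal usage sketch (illustrative; the config keys below are assumptions,
# mirroring the attributes read by the classes in this file):
#
#   args = Hparams()
#   args.update({"widths": [32, 32], "std_fixed": 0.0,
#                "input_channels": 1, "input_res": 192})
#   pgm = FlowPGM(args)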
class BasePGM(nn.Module):
def __init__(self):
super().__init__()
    def scm(self, *args, **kwargs):
        # reparameterize every TransformedDistribution site so its exogenous
        # base noise is exposed as a separate "*_base" sample site
        def config(msg):
            if isinstance(msg["fn"], dist.TransformedDistribution):
                return TransformReparam()
            else:
                return None

        return pyro.poutine.reparam(self.model, config=config)(*args, **kwargs)
def sample_scm(self, n_samples: int = 1):
with pyro.plate("obs", n_samples):
samples = self.scm()
return samples
def sample(self, n_samples: int = 1):
with pyro.plate("obs", n_samples):
samples = self.model() # NOTE: not ideal as model is defined in child class
return samples
def infer_exogeneous(self, obs: Dict[str, Tensor]) -> Dict[str, Tensor]:
batch_size = list(obs.values())[0].shape[0]
# assuming that we use transformed distributions for everything:
cond_model = pyro.condition(self.sample, data=obs)
cond_trace = pyro.poutine.trace(cond_model).get_trace(batch_size)
output = {}
        for name, node in cond_trace.nodes.items():
            # skip latent "z" sites and nodes without a distribution
            if "z" in name or "fn" not in node.keys():
                continue
fn = node["fn"]
if isinstance(fn, dist.Independent):
fn = fn.base_dist
if isinstance(fn, dist.TransformedDistribution):
# compute exogenous base dist (created with TransformReparam) at all sites
output[name + "_base"] = T.ComposeTransform(fn.transforms).inv(
node["value"]
)
return output
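
    # Example (sketch): for a site "age" ~ TransformedDistribution(N(0, 1), f),
    # the trace above yields {"age_base": f.inv(age)}, i.e. the abducted
    # exogenous noise consumed by counterfactual() below.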
def counterfactual(
self,
obs: Dict[str, Tensor],
intervention: Dict[str, Tensor],
num_particles: int = 1,
detach: bool = True,
) -> Dict[str, Tensor]:
# NOTE: not ideal as "variables" is defined in child class
dag_variables = self.variables.keys()
assert set(obs.keys()) == set(dag_variables)
avg_cfs = {k: torch.zeros_like(obs[k]) for k in obs.keys()}
batch_size = list(obs.values())[0].shape[0]
for _ in range(num_particles):
# Abduction
exo_noise = self.infer_exogeneous(obs)
exo_noise = {k: v.detach() if detach else v for k, v in exo_noise.items()}
            # condition on root node variables (no exogenous noise available)
for k in dag_variables:
if k not in intervention.keys():
if k not in [i.split("_base")[0] for i in exo_noise.keys()]:
exo_noise[k] = obs[k]
# Abducted SCM
abducted_scm = pyro.poutine.condition(self.sample_scm, data=exo_noise)
# Action
counterfactual_scm = pyro.poutine.do(abducted_scm, data=intervention)
# Prediction
counterfactuals = counterfactual_scm(batch_size)
            if hasattr(self, "discrete_variables"):  # hack for MIMIC
                # Keep the observed value of "finding" unless it or its parent
                # ("age") is intervened on; this avoids noise from the
                # stochastic abduction of discrete variables.
if (
"age" not in intervention.keys()
and "finding" not in intervention.keys()
):
counterfactuals["finding"] = obs["finding"]
for k, v in counterfactuals.items():
avg_cfs[k] += v / num_particles
return avg_cfs
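

def _demo_counterfactual(pgm: BasePGM, obs: Dict[str, Tensor]) -> Dict[str, Tensor]:
    # Usage sketch (ours, not part of the original API): one counterfactual
    # query via the abduction-action-prediction steps implemented above.
    # Assumes `obs` holds every variable in `pgm.variables` as [N, 1] tensors
    # and that "age" is one of them (true for FlowPGM and ChestPGM below).
    intervention = {"age": obs["age"] + 1.0}  # hypothetical do(age := age + 1)
    return pgm.counterfactual(obs=obs, intervention=intervention, num_particles=4)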
class FlowPGM(BasePGM):
def __init__(self, args: Hparams):
super().__init__()
self.variables = {
"sex": "binary",
"mri_seq": "binary",
"age": "continuous",
"brain_volume": "continuous",
"ventricle_volume": "continuous",
}
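        # causal graph encoded by model() below:
        #   sex -> brain_volume <- age -> ventricle_volume <- brain_volume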
# priors: s, m, a, b and v
self.s_logit = nn.Parameter(torch.zeros(1))
self.m_logit = nn.Parameter(torch.zeros(1))
for k in ["a", "b", "v"]:
self.register_buffer(f"{k}_base_loc", torch.zeros(1))
self.register_buffer(f"{k}_base_scale", torch.ones(1))
# constraint, assumes data is [-1,1] normalized
# normalize_transform = T.ComposeTransform([
# T.AffineTransform(loc=0, scale=2), T.SigmoidTransform(), T.AffineTransform(loc=-1, scale=2)])
# normalize_transform = T.ComposeTransform([T.TanhTransform(cache_size=1)])
# normalize_transform = T.ComposeTransform([T.AffineTransform(loc=0, scale=1)])
# age flow
self.age_module = T.ComposeTransformModule(
[T.Spline(1, count_bins=4, order="linear")]
)
self.age_flow = T.ComposeTransform([self.age_module])
# self.age_module, normalize_transform])
# brain volume (conditional) flow: (sex, age) -> brain_vol
bvol_net = DenseNN(2, args.widths, [1, 1], nonlinearity=nn.LeakyReLU(0.1))
self.bvol_flow = ConditionalAffineTransform(context_nn=bvol_net, event_dim=0)
# self.bvol_flow = [self.bvol_flow, normalize_transform]
# ventricle volume (conditional) flow: (brain_vol, age) -> ventricle_vol
vvol_net = DenseNN(2, args.widths, [1, 1], nonlinearity=nn.LeakyReLU(0.1))
self.vvol_flow = ConditionalAffineTransform(context_nn=vvol_net, event_dim=0)
# self.vvol_flow = [self.vvol_transf, normalize_transform]
# if args.setup != 'sup_pgm':
# anticausal predictors
input_shape = (args.input_channels, args.input_res, args.input_res)
# q(s | x, b) = Bernoulli(f(x,b))
self.encoder_s = CNN(input_shape, num_outputs=1, context_dim=1)
# q(m | x) = Bernoulli(f(x))
self.encoder_m = CNN(input_shape, num_outputs=1)
# q(a | b, v) = Normal(mu(b, v), sigma(b, v))
self.encoder_a = MLP(num_inputs=2, num_outputs=2)
# q(b | x, v) = Normal(mu(x, v), sigma(x, v))
self.encoder_b = CNN(input_shape, num_outputs=2, context_dim=1)
# q(v | x) = Normal(mu(x), sigma(x))
self.encoder_v = CNN(input_shape, num_outputs=2)
self.f = (
lambda x: args.std_fixed * torch.ones_like(x)
if args.std_fixed > 0
else F.softplus(x)
)
def model(self) -> Dict[str, Tensor]:
# p(s), sex dist
ps = dist.Bernoulli(logits=self.s_logit).to_event(1)
sex = pyro.sample("sex", ps)
# p(m), mri_seq dist
pm = dist.Bernoulli(logits=self.m_logit).to_event(1)
mri_seq = pyro.sample("mri_seq", pm)
# p(a), age flow
pa_base = dist.Normal(self.a_base_loc, self.a_base_scale).to_event(1)
pa = dist.TransformedDistribution(pa_base, self.age_flow)
age = pyro.sample("age", pa)
# p(b | s, a), brain volume flow
pb_sa_base = dist.Normal(self.b_base_loc, self.b_base_scale).to_event(1)
pb_sa = ConditionalTransformedDistribution(
pb_sa_base, [self.bvol_flow]
).condition(torch.cat([sex, age], dim=1))
bvol = pyro.sample("brain_volume", pb_sa)
# _ = self.bvol_transf # register with pyro
# p(v | b, a), ventricle volume flow
pv_ba_base = dist.Normal(self.v_base_loc, self.v_base_scale).to_event(1)
pv_ba = ConditionalTransformedDistribution(
pv_ba_base, [self.vvol_flow]
).condition(torch.cat([bvol, age], dim=1))
vvol = pyro.sample("ventricle_volume", pv_ba)
# _ = self.vvol_transf # register with pyro
return {
"sex": sex,
"mri_seq": mri_seq,
"age": age,
"brain_volume": bvol,
"ventricle_volume": vvol,
}
def guide(self, **obs) -> None:
# guide for (optional) semi-supervised learning
pyro.module("FlowPGM", self)
with pyro.plate("observations", obs["x"].shape[0]):
# q(m | x)
if obs["mri_seq"] is None:
m_prob = torch.sigmoid(self.encoder_m(obs["x"]))
m = pyro.sample("mri_seq", dist.Bernoulli(probs=m_prob).to_event(1))
# q(v | x)
if obs["ventricle_volume"] is None:
v_loc, v_logscale = self.encoder_v(obs["x"]).chunk(2, dim=-1)
qv_x = dist.Normal(v_loc, self.f(v_logscale)).to_event(1)
obs["ventricle_volume"] = pyro.sample("ventricle_volume", qv_x)
# q(b | x, v)
if obs["brain_volume"] is None:
b_loc, b_logscale = self.encoder_b(
obs["x"], y=obs["ventricle_volume"]
).chunk(2, dim=-1)
qb_xv = dist.Normal(b_loc, self.f(b_logscale)).to_event(1)
obs["brain_volume"] = pyro.sample("brain_volume", qb_xv)
# q(s | x, b)
if obs["sex"] is None:
s_prob = torch.sigmoid(
self.encoder_s(obs["x"], y=obs["brain_volume"])
) # .squeeze()
pyro.sample("sex", dist.Bernoulli(probs=s_prob).to_event(1))
# q(a | b, v)
if obs["age"] is None:
ctx = torch.cat([obs["brain_volume"], obs["ventricle_volume"]], dim=-1)
a_loc, a_logscale = self.encoder_a(ctx).chunk(2, dim=-1)
pyro.sample("age", dist.Normal(a_loc, self.f(a_logscale)).to_event(1))
def model_anticausal(self, **obs) -> None:
# assumes all variables are observed
pyro.module("FlowPGM", self)
with pyro.plate("observations", obs["x"].shape[0]):
# q(v | x)
v_loc, v_logscale = self.encoder_v(obs["x"]).chunk(2, dim=-1)
qv_x = dist.Normal(v_loc, self.f(v_logscale)).to_event(1)
pyro.sample("ventricle_volume_aux", qv_x, obs=obs["ventricle_volume"])
# q(b | x, v)
b_loc, b_logscale = self.encoder_b(
obs["x"], y=obs["ventricle_volume"]
).chunk(2, dim=-1)
qb_xv = dist.Normal(b_loc, self.f(b_logscale)).to_event(1)
pyro.sample("brain_volume_aux", qb_xv, obs=obs["brain_volume"])
# q(a | b, v)
ctx = torch.cat([obs["brain_volume"], obs["ventricle_volume"]], dim=-1)
a_loc, a_logscale = self.encoder_a(ctx).chunk(2, dim=-1)
pyro.sample(
"age_aux",
dist.Normal(a_loc, self.f(a_logscale)).to_event(1),
obs=obs["age"],
)
# q(s | x, b)
s_prob = torch.sigmoid(self.encoder_s(obs["x"], y=obs["brain_volume"]))
qs_xb = dist.Bernoulli(probs=s_prob).to_event(1)
pyro.sample("sex_aux", qs_xb, obs=obs["sex"])
# q(m | x)
m_prob = torch.sigmoid(self.encoder_m(obs["x"]))
qm_x = dist.Bernoulli(probs=m_prob).to_event(1)
pyro.sample("mri_seq_aux", qm_x, obs=obs["mri_seq"])
def predict(self, **obs) -> Dict[str, Tensor]:
# q(v | x)
v_loc, v_logscale = self.encoder_v(obs["x"]).chunk(2, dim=-1)
# v_loc = torch.tanh(v_loc)
# q(b | x, v)
b_loc, b_logscale = self.encoder_b(obs["x"], y=obs["ventricle_volume"]).chunk(
2, dim=-1
)
# b_loc = torch.tanh(b_loc)
# q(a | b, v)
ctx = torch.cat([obs["brain_volume"], obs["ventricle_volume"]], dim=-1)
a_loc, a_logscale = self.encoder_a(ctx).chunk(2, dim=-1)
        # a_loc = torch.tanh(a_loc)
# q(s | x, b)
s_prob = torch.sigmoid(self.encoder_s(obs["x"], y=obs["brain_volume"]))
# q(m | x)
m_prob = torch.sigmoid(self.encoder_m(obs["x"]))
return {
"sex": s_prob,
"mri_seq": m_prob,
"age": a_loc,
"brain_volume": b_loc,
"ventricle_volume": v_loc,
}
def svi_model(self, **obs) -> None:
with pyro.plate("observations", obs["x"].shape[0]):
pyro.condition(self.model, data=obs)()
def guide_pass(self, **obs) -> None:
pass
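

def _demo_train_flow_pgm(pgm: FlowPGM, batch: Dict[str, Tensor]) -> float:
    # Training sketch (a minimal example, not the original training loop): one
    # SVI step on the fully supervised PGM, pairing svi_model (all variables
    # observed) with the no-op guide_pass. Assumes `batch` has keys "x", "sex",
    # "mri_seq", "age", "brain_volume", "ventricle_volume".
    from pyro.infer import SVI, Trace_ELBO
    from pyro.optim import Adam

    svi = SVI(
        model=pgm.svi_model,
        guide=pgm.guide_pass,
        optim=Adam({"lr": 1e-3}),
        loss=Trace_ELBO(),
    )
    return svi.step(**batch)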
class MorphoMNISTPGM(BasePGM):
def __init__(self, args):
super().__init__()
self.variables = {
"thickness": "continuous",
"intensity": "continuous",
"digit": "categorical",
}
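        # causal graph: thickness -> intensity; digit is an independent root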
# priors
self.digit_logits = nn.Parameter(torch.zeros(1, 10)) # uniform prior
for k in ["t", "i"]: # thickness, intensity, standard Gaussian
self.register_buffer(f"{k}_base_loc", torch.zeros(1))
self.register_buffer(f"{k}_base_scale", torch.ones(1))
# constraint, assumes data is [-1,1] normalized
normalize_transform = T.ComposeTransform(
[T.SigmoidTransform(), T.AffineTransform(loc=-1, scale=2)]
)
# thickness flow
self.thickness_module = T.ComposeTransformModule(
[T.Spline(1, count_bins=4, order="linear")]
)
self.thickness_flow = T.ComposeTransform(
[self.thickness_module, normalize_transform]
)
# intensity (conditional) flow: thickness -> intensity
intensity_net = DenseNN(1, args.widths, [1, 1], nonlinearity=nn.GELU())
self.context_nn = ConditionalAffineTransform(
context_nn=intensity_net, event_dim=0
)
self.intensity_flow = [self.context_nn, normalize_transform]
if args.setup != "sup_pgm":
# anticausal predictors
input_shape = (args.input_channels, args.input_res, args.input_res)
# q(t | x, i) = Normal(mu(x, i), sigma(x, i)), 2 outputs: loc & scale
self.encoder_t = CNN(input_shape, num_outputs=2, context_dim=1, width=8)
# q(i | x) = Normal(mu(x), sigma(x))
self.encoder_i = CNN(input_shape, num_outputs=2, width=8)
# q(y | x) = Categorical(pi(x))
self.encoder_y = CNN(input_shape, num_outputs=10, width=8)
self.f = (
lambda x: args.std_fixed * torch.ones_like(x)
if args.std_fixed > 0
else F.softplus(x)
)
def model(self) -> Dict[str, Tensor]:
pyro.module("MorphoMNISTPGM", self)
# p(y), digit label prior dist
py = dist.OneHotCategorical(
probs=F.softmax(self.digit_logits, dim=-1)
) # .to_event(1)
# with pyro.poutine.scale(scale=0.05):
digit = pyro.sample("digit", py)
# p(t), thickness flow
pt_base = dist.Normal(self.t_base_loc, self.t_base_scale).to_event(1)
pt = dist.TransformedDistribution(pt_base, self.thickness_flow)
thickness = pyro.sample("thickness", pt)
# p(i | t), intensity conditional flow
pi_t_base = dist.Normal(self.i_base_loc, self.i_base_scale).to_event(1)
pi_t = ConditionalTransformedDistribution(
pi_t_base, self.intensity_flow
).condition(thickness)
intensity = pyro.sample("intensity", pi_t)
        _ = self.context_nn  # register with pyro
return {"thickness": thickness, "intensity": intensity, "digit": digit}
def guide(self, **obs) -> None:
# guide for (optional) semi-supervised learning
with pyro.plate("observations", obs["x"].shape[0]):
# q(i | x)
if obs["intensity"] is None:
i_loc, i_logscale = self.encoder_i(obs["x"]).chunk(2, dim=-1)
                qi_x = dist.Normal(torch.tanh(i_loc), self.f(i_logscale)).to_event(1)
                obs["intensity"] = pyro.sample("intensity", qi_x)
# q(t | x, i)
if obs["thickness"] is None:
t_loc, t_logscale = self.encoder_t(obs["x"], y=obs["intensity"]).chunk(
2, dim=-1
)
qt_x = dist.Normal(torch.tanh(t_loc), self.f(t_logscale)).to_event(1)
obs["thickness"] = pyro.sample("thickness", qt_x)
# q(y | x)
if obs["digit"] is None:
y_prob = F.softmax(self.encoder_y(obs["x"]), dim=-1)
qy_x = dist.OneHotCategorical(probs=y_prob) # .to_event(1)
pyro.sample("digit", qy_x)
def model_anticausal(self, **obs) -> None:
# assumes all variables are observed & continuous ones are in [-1,1]
pyro.module("MorphoMNISTPGM", self)
with pyro.plate("observations", obs["x"].shape[0]):
# q(t | x, i)
t_loc, t_logscale = self.encoder_t(obs["x"], y=obs["intensity"]).chunk(
2, dim=-1
)
qt_x = dist.Normal(torch.tanh(t_loc), self.f(t_logscale)).to_event(1)
pyro.sample("thickness_aux", qt_x, obs=obs["thickness"])
# q(i | x)
i_loc, i_logscale = self.encoder_i(obs["x"]).chunk(2, dim=-1)
            qi_x = dist.Normal(torch.tanh(i_loc), self.f(i_logscale)).to_event(1)
            pyro.sample("intensity_aux", qi_x, obs=obs["intensity"])
# q(y | x)
y_prob = F.softmax(self.encoder_y(obs["x"]), dim=-1)
qy_x = dist.OneHotCategorical(probs=y_prob) # .to_event(1)
pyro.sample("digit_aux", qy_x, obs=obs["digit"])
def predict(self, **obs) -> Dict[str, Tensor]:
# q(t | x, i)
t_loc, t_logscale = self.encoder_t(obs["x"], y=obs["intensity"]).chunk(
2, dim=-1
)
t_loc = torch.tanh(t_loc)
# q(i | x)
i_loc, i_logscale = self.encoder_i(obs["x"]).chunk(2, dim=-1)
i_loc = torch.tanh(i_loc)
# q(y | x)
y_prob = F.softmax(self.encoder_y(obs["x"]), dim=-1)
return {"thickness": t_loc, "intensity": i_loc, "digit": y_prob}
def svi_model(self, **obs) -> None:
with pyro.plate("observations", obs["x"].shape[0]):
pyro.condition(self.model, data=obs)()
def guide_pass(self, **obs) -> None:
pass
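

def _demo_morphomnist_sample(pgm: MorphoMNISTPGM, n: int = 8) -> Dict[str, Tensor]:
    # Sampling sketch (illustrative): draw (thickness, intensity, digit)
    # ancestrally from the learned SCM. sample_scm() exposes the exogenous
    # "*_base" noise sites via TransformReparam; sample() uses the plain model.
    with torch.no_grad():
        return pgm.sample_scm(n_samples=n)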
class ChestPGM(BasePGM):
def __init__(self, args: Hparams):
super().__init__()
self.variables = {
"race": "categorical",
"sex": "binary",
"finding": "binary",
"age": "continuous",
}
# Discrete variables that are not root nodes
self.discrete_variables = {"finding": "binary"}
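        # causal graph: age -> finding; sex and race are independent roots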
# define base distributions
for k in ["a"]: # , "f"]:
self.register_buffer(f"{k}_base_loc", torch.zeros(1))
self.register_buffer(f"{k}_base_scale", torch.ones(1))
# age spline flow
self.age_flow_components = T.ComposeTransformModule([T.Spline(1)])
# self.age_constraints = T.ComposeTransform([
# T.AffineTransform(loc=4.09541458484, scale=0.32548387126),
# T.ExpTransform()])
self.age_flow = T.ComposeTransform(
[
self.age_flow_components,
# self.age_constraints,
]
)
# Finding (conditional) via MLP, a -> f
finding_net = DenseNN(1, [8, 16], param_dims=[2], nonlinearity=nn.Sigmoid())
self.finding_transform_GumbelMax = ConditionalGumbelMax(
context_nn=finding_net, event_dim=0
)
        # sex and race priors: logits initialized to uniform (in log space)
self.sex_logit = nn.Parameter(np.log(1 / 2) * torch.ones(1))
self.race_logits = nn.Parameter(np.log(1 / 3) * torch.ones(1, 3))
input_shape = (args.input_channels, args.input_res, args.input_res)
if args.enc_net == "cnn":
# q(s | x) ~ Bernoulli(f(x))
self.encoder_s = CNN(input_shape, num_outputs=1)
# q(r | x) ~ OneHotCategorical(logits=f(x))
self.encoder_r = CNN(input_shape, num_outputs=3)
# q(f | x) ~ Bernoulli(f(x))
self.encoder_f = CNN(input_shape, num_outputs=1)
            # q(a | x, f) ~ Normal(mu(x, f), sigma(x, f)), 2 outputs: loc & scale
            self.encoder_a = CNN(input_shape, num_outputs=2, context_dim=1)
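        # self.f is called by guide() and model_anticausal() below but was not
        # defined in this class as given; we add the same fixed-or-learned std
        # used by FlowPGM/MorphoMNISTPGM above (assumes args.std_fixed exists)
        self.f = (
            lambda x: args.std_fixed * torch.ones_like(x)
            if args.std_fixed > 0
            else F.softplus(x)
        )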
def model(self) -> Dict[str, Tensor]:
pyro.module("ChestPGM", self)
# p(s), sex dist
ps = dist.Bernoulli(logits=self.sex_logit).to_event(1)
sex = pyro.sample("sex", ps)
# p(a), age flow
pa_base = dist.Normal(self.a_base_loc, self.a_base_scale).to_event(1)
pa = dist.TransformedDistribution(pa_base, self.age_flow)
age = pyro.sample("age", pa)
# age_ = self.age_constraints.inv(age)
_ = self.age_flow_components # register with pyro
# p(r), race dist
pr = dist.OneHotCategorical(logits=self.race_logits) # .to_event(1)
race = pyro.sample("race", pr)
# p(f | a), finding as OneHotCategorical conditioned on age
# finding_dist_base = dist.Gumbel(self.f_base_loc, self.f_base_scale).to_event(1)
finding_dist_base = dist.Gumbel(torch.zeros(1), torch.ones(1)).to_event(1)
finding_dist = ConditionalTransformedDistributionGumbelMax(
finding_dist_base, [self.finding_transform_GumbelMax]
).condition(age)
finding = pyro.sample("finding", finding_dist)
return {
"sex": sex,
"race": race,
"age": age,
"finding": finding,
}
def guide(self, **obs) -> None:
with pyro.plate("observations", obs["x"].shape[0]):
# q(s | x)
if obs["sex"] is None:
s_prob = torch.sigmoid(self.encoder_s(obs["x"]))
pyro.sample("sex", dist.Bernoulli(probs=s_prob).to_event(1))
# q(r | x)
if obs["race"] is None:
r_probs = F.softmax(self.encoder_r(obs["x"]), dim=-1)
qr_x = dist.OneHotCategorical(probs=r_probs) # .to_event(1)
pyro.sample("race", qr_x)
# q(f | x)
if obs["finding"] is None:
f_prob = torch.sigmoid(self.encoder_f(obs["x"]))
qf_x = dist.Bernoulli(probs=f_prob).to_event(1)
obs["finding"] = pyro.sample("finding", qf_x)
# q(a | x, f)
if obs["age"] is None:
a_loc, a_logscale = self.encoder_a(obs["x"], y=obs["finding"]).chunk(
2, dim=-1
)
                qa_xf = dist.Normal(a_loc, self.f(a_logscale)).to_event(1)
                pyro.sample("age", qa_xf)
    def model_anticausal(self, **obs) -> None:
        # assumes all variables are observed; trains the anticausal classifiers
pyro.module("ChestPGM", self)
with pyro.plate("observations", obs["x"].shape[0]):
# q(s | x)
s_prob = torch.sigmoid(self.encoder_s(obs["x"]))
qs_x = dist.Bernoulli(probs=s_prob).to_event(1)
# with pyro.poutine.scale(scale=0.8):
pyro.sample("sex_aux", qs_x, obs=obs["sex"])
# q(r | x)
r_probs = F.softmax(self.encoder_r(obs["x"]), dim=-1)
qr_x = dist.OneHotCategorical(probs=r_probs) # .to_event(1)
# with pyro.poutine.scale(scale=0.5):
pyro.sample("race_aux", qr_x, obs=obs["race"])
# q(f | x)
f_prob = torch.sigmoid(self.encoder_f(obs["x"]))
qf_x = dist.Bernoulli(probs=f_prob).to_event(1)
pyro.sample("finding_aux", qf_x, obs=obs["finding"])
# q(a | x, f)
a_loc, a_logscale = self.encoder_a(obs["x"], y=obs["finding"]).chunk(
2, dim=-1
)
qa_xf = dist.Normal(a_loc, self.f(a_logscale)).to_event(1)
# with pyro.poutine.scale(scale=2):
pyro.sample("age_aux", qa_xf, obs=obs["age"])
def predict(self, **obs) -> Dict[str, Tensor]:
# q(s | x)
s_prob = torch.sigmoid(self.encoder_s(obs["x"]))
# q(r | x)
r_probs = F.softmax(self.encoder_r(obs["x"]), dim=-1)
# q(f | x)
f_prob = torch.sigmoid(self.encoder_f(obs["x"]))
# q(a | x, f)
a_loc, _ = self.encoder_a(obs["x"], y=obs["finding"]).chunk(2, dim=-1)
return {
"sex": s_prob,
"race": r_probs,
"finding": f_prob,
"age": a_loc,
}
def svi_model(self, **obs) -> None:
with pyro.plate("observations", obs["x"].shape[0]):
pyro.condition(self.model, data=obs)()
def guide_pass(self, **obs) -> None:
pass
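

def _demo_chest_counterfactual(pgm: ChestPGM, obs: Dict[str, Tensor]) -> Dict[str, Tensor]:
    # End-to-end sketch (ours, under assumed inputs): intervene on "sex" and
    # predict counterfactual chest variables. Per BasePGM.counterfactual, the
    # observed "finding" is kept since neither it nor its parent "age" is
    # intervened on. Assumes `obs` holds "race", "sex", "finding", "age".
    intervention = {"sex": 1.0 - obs["sex"]}  # flip the binary sex variable
    return pgm.counterfactual(obs=obs, intervention=intervention, num_particles=2)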