from typing import Dict import numpy as np import torch from torch import nn, Tensor import torch.nn.functional as F import pyro import pyro.distributions as dist import pyro.distributions.transforms as T from pyro.nn import DenseNN from pyro.infer.reparam.transform import TransformReparam from pyro.distributions.conditional import ConditionalTransformedDistribution from .layers import ( ConditionalTransformedDistributionGumbelMax, ConditionalGumbelMax, ConditionalAffineTransform, MLP, CNN, ) class Hparams: def update(self, dict): for k, v in dict.items(): setattr(self, k, v) class BasePGM(nn.Module): def __init__(self): super().__init__() def scm(self, *args, **kwargs): def config(msg): if isinstance(msg["fn"], dist.TransformedDistribution): return TransformReparam() else: return None return pyro.poutine.reparam(self.model, config=config)(*args, **kwargs) def sample_scm(self, n_samples: int = 1): with pyro.plate("obs", n_samples): samples = self.scm() return samples def sample(self, n_samples: int = 1): with pyro.plate("obs", n_samples): samples = self.model() # NOTE: not ideal as model is defined in child class return samples def infer_exogeneous(self, obs: Dict[str, Tensor]) -> Dict[str, Tensor]: batch_size = list(obs.values())[0].shape[0] # assuming that we use transformed distributions for everything: cond_model = pyro.condition(self.sample, data=obs) cond_trace = pyro.poutine.trace(cond_model).get_trace(batch_size) output = {} for name, node in cond_trace.nodes.items(): if "z" in name or "fn" not in node.keys(): continue fn = node["fn"] if isinstance(fn, dist.Independent): fn = fn.base_dist if isinstance(fn, dist.TransformedDistribution): # compute exogenous base dist (created with TransformReparam) at all sites output[name + "_base"] = T.ComposeTransform(fn.transforms).inv( node["value"] ) return output def counterfactual( self, obs: Dict[str, Tensor], intervention: Dict[str, Tensor], num_particles: int = 1, detach: bool = True, ) -> Dict[str, Tensor]: # NOTE: not ideal as "variables" is defined in child class dag_variables = self.variables.keys() assert set(obs.keys()) == set(dag_variables) avg_cfs = {k: torch.zeros_like(obs[k]) for k in obs.keys()} batch_size = list(obs.values())[0].shape[0] for _ in range(num_particles): # Abduction exo_noise = self.infer_exogeneous(obs) exo_noise = {k: v.detach() if detach else v for k, v in exo_noise.items()} # condition on root node variables (no exogeneous noise available) for k in dag_variables: if k not in intervention.keys(): if k not in [i.split("_base")[0] for i in exo_noise.keys()]: exo_noise[k] = obs[k] # Abducted SCM abducted_scm = pyro.poutine.condition(self.sample_scm, data=exo_noise) # Action counterfactual_scm = pyro.poutine.do(abducted_scm, data=intervention) # Prediction counterfactuals = counterfactual_scm(batch_size) if hasattr(self, "discrete_variables"): # hack for MIMIC # Check if we should change "finding", i.e. if its parents and/or # itself are not intervened on, then we use its observed value. # This is used due to stochastic abduction of discrete variables if ( "age" not in intervention.keys() and "finding" not in intervention.keys() ): counterfactuals["finding"] = obs["finding"] for k, v in counterfactuals.items(): avg_cfs[k] += v / num_particles return avg_cfs class FlowPGM(BasePGM): def __init__(self, args: Hparams): super().__init__() self.variables = { "sex": "binary", "mri_seq": "binary", "age": "continuous", "brain_volume": "continuous", "ventricle_volume": "continuous", } # priors: s, m, a, b and v self.s_logit = nn.Parameter(torch.zeros(1)) self.m_logit = nn.Parameter(torch.zeros(1)) for k in ["a", "b", "v"]: self.register_buffer(f"{k}_base_loc", torch.zeros(1)) self.register_buffer(f"{k}_base_scale", torch.ones(1)) # constraint, assumes data is [-1,1] normalized # normalize_transform = T.ComposeTransform([ # T.AffineTransform(loc=0, scale=2), T.SigmoidTransform(), T.AffineTransform(loc=-1, scale=2)]) # normalize_transform = T.ComposeTransform([T.TanhTransform(cache_size=1)]) # normalize_transform = T.ComposeTransform([T.AffineTransform(loc=0, scale=1)]) # age flow self.age_module = T.ComposeTransformModule( [T.Spline(1, count_bins=4, order="linear")] ) self.age_flow = T.ComposeTransform([self.age_module]) # self.age_module, normalize_transform]) # brain volume (conditional) flow: (sex, age) -> brain_vol bvol_net = DenseNN(2, args.widths, [1, 1], nonlinearity=nn.LeakyReLU(0.1)) self.bvol_flow = ConditionalAffineTransform(context_nn=bvol_net, event_dim=0) # self.bvol_flow = [self.bvol_flow, normalize_transform] # ventricle volume (conditional) flow: (brain_vol, age) -> ventricle_vol vvol_net = DenseNN(2, args.widths, [1, 1], nonlinearity=nn.LeakyReLU(0.1)) self.vvol_flow = ConditionalAffineTransform(context_nn=vvol_net, event_dim=0) # self.vvol_flow = [self.vvol_transf, normalize_transform] # if args.setup != 'sup_pgm': # anticausal predictors input_shape = (args.input_channels, args.input_res, args.input_res) # q(s | x, b) = Bernoulli(f(x,b)) self.encoder_s = CNN(input_shape, num_outputs=1, context_dim=1) # q(m | x) = Bernoulli(f(x)) self.encoder_m = CNN(input_shape, num_outputs=1) # q(a | b, v) = Normal(mu(b, v), sigma(b, v)) self.encoder_a = MLP(num_inputs=2, num_outputs=2) # q(b | x, v) = Normal(mu(x, v), sigma(x, v)) self.encoder_b = CNN(input_shape, num_outputs=2, context_dim=1) # q(v | x) = Normal(mu(x), sigma(x)) self.encoder_v = CNN(input_shape, num_outputs=2) self.f = ( lambda x: args.std_fixed * torch.ones_like(x) if args.std_fixed > 0 else F.softplus(x) ) def model(self) -> Dict[str, Tensor]: # p(s), sex dist ps = dist.Bernoulli(logits=self.s_logit).to_event(1) sex = pyro.sample("sex", ps) # p(m), mri_seq dist pm = dist.Bernoulli(logits=self.m_logit).to_event(1) mri_seq = pyro.sample("mri_seq", pm) # p(a), age flow pa_base = dist.Normal(self.a_base_loc, self.a_base_scale).to_event(1) pa = dist.TransformedDistribution(pa_base, self.age_flow) age = pyro.sample("age", pa) # p(b | s, a), brain volume flow pb_sa_base = dist.Normal(self.b_base_loc, self.b_base_scale).to_event(1) pb_sa = ConditionalTransformedDistribution( pb_sa_base, [self.bvol_flow] ).condition(torch.cat([sex, age], dim=1)) bvol = pyro.sample("brain_volume", pb_sa) # _ = self.bvol_transf # register with pyro # p(v | b, a), ventricle volume flow pv_ba_base = dist.Normal(self.v_base_loc, self.v_base_scale).to_event(1) pv_ba = ConditionalTransformedDistribution( pv_ba_base, [self.vvol_flow] ).condition(torch.cat([bvol, age], dim=1)) vvol = pyro.sample("ventricle_volume", pv_ba) # _ = self.vvol_transf # register with pyro return { "sex": sex, "mri_seq": mri_seq, "age": age, "brain_volume": bvol, "ventricle_volume": vvol, } def guide(self, **obs) -> None: # guide for (optional) semi-supervised learning pyro.module("FlowPGM", self) with pyro.plate("observations", obs["x"].shape[0]): # q(m | x) if obs["mri_seq"] is None: m_prob = torch.sigmoid(self.encoder_m(obs["x"])) m = pyro.sample("mri_seq", dist.Bernoulli(probs=m_prob).to_event(1)) # q(v | x) if obs["ventricle_volume"] is None: v_loc, v_logscale = self.encoder_v(obs["x"]).chunk(2, dim=-1) qv_x = dist.Normal(v_loc, self.f(v_logscale)).to_event(1) obs["ventricle_volume"] = pyro.sample("ventricle_volume", qv_x) # q(b | x, v) if obs["brain_volume"] is None: b_loc, b_logscale = self.encoder_b( obs["x"], y=obs["ventricle_volume"] ).chunk(2, dim=-1) qb_xv = dist.Normal(b_loc, self.f(b_logscale)).to_event(1) obs["brain_volume"] = pyro.sample("brain_volume", qb_xv) # q(s | x, b) if obs["sex"] is None: s_prob = torch.sigmoid( self.encoder_s(obs["x"], y=obs["brain_volume"]) ) # .squeeze() pyro.sample("sex", dist.Bernoulli(probs=s_prob).to_event(1)) # q(a | b, v) if obs["age"] is None: ctx = torch.cat([obs["brain_volume"], obs["ventricle_volume"]], dim=-1) a_loc, a_logscale = self.encoder_a(ctx).chunk(2, dim=-1) pyro.sample("age", dist.Normal(a_loc, self.f(a_logscale)).to_event(1)) def model_anticausal(self, **obs) -> None: # assumes all variables are observed pyro.module("FlowPGM", self) with pyro.plate("observations", obs["x"].shape[0]): # q(v | x) v_loc, v_logscale = self.encoder_v(obs["x"]).chunk(2, dim=-1) qv_x = dist.Normal(v_loc, self.f(v_logscale)).to_event(1) pyro.sample("ventricle_volume_aux", qv_x, obs=obs["ventricle_volume"]) # q(b | x, v) b_loc, b_logscale = self.encoder_b( obs["x"], y=obs["ventricle_volume"] ).chunk(2, dim=-1) qb_xv = dist.Normal(b_loc, self.f(b_logscale)).to_event(1) pyro.sample("brain_volume_aux", qb_xv, obs=obs["brain_volume"]) # q(a | b, v) ctx = torch.cat([obs["brain_volume"], obs["ventricle_volume"]], dim=-1) a_loc, a_logscale = self.encoder_a(ctx).chunk(2, dim=-1) pyro.sample( "age_aux", dist.Normal(a_loc, self.f(a_logscale)).to_event(1), obs=obs["age"], ) # q(s | x, b) s_prob = torch.sigmoid(self.encoder_s(obs["x"], y=obs["brain_volume"])) qs_xb = dist.Bernoulli(probs=s_prob).to_event(1) pyro.sample("sex_aux", qs_xb, obs=obs["sex"]) # q(m | x) m_prob = torch.sigmoid(self.encoder_m(obs["x"])) qm_x = dist.Bernoulli(probs=m_prob).to_event(1) pyro.sample("mri_seq_aux", qm_x, obs=obs["mri_seq"]) def predict(self, **obs) -> Dict[str, Tensor]: # q(v | x) v_loc, v_logscale = self.encoder_v(obs["x"]).chunk(2, dim=-1) # v_loc = torch.tanh(v_loc) # q(b | x, v) b_loc, b_logscale = self.encoder_b(obs["x"], y=obs["ventricle_volume"]).chunk( 2, dim=-1 ) # b_loc = torch.tanh(b_loc) # q(a | b, v) ctx = torch.cat([obs["brain_volume"], obs["ventricle_volume"]], dim=-1) a_loc, a_logscale = self.encoder_a(ctx).chunk(2, dim=-1) # a_loc = torch.tanh(b_loc) # q(s | x, b) s_prob = torch.sigmoid(self.encoder_s(obs["x"], y=obs["brain_volume"])) # q(m | x) m_prob = torch.sigmoid(self.encoder_m(obs["x"])) return { "sex": s_prob, "mri_seq": m_prob, "age": a_loc, "brain_volume": b_loc, "ventricle_volume": v_loc, } def svi_model(self, **obs) -> None: with pyro.plate("observations", obs["x"].shape[0]): pyro.condition(self.model, data=obs)() def guide_pass(self, **obs) -> None: pass class MorphoMNISTPGM(BasePGM): def __init__(self, args): super().__init__() self.variables = { "thickness": "continuous", "intensity": "continuous", "digit": "categorical", } # priors self.digit_logits = nn.Parameter(torch.zeros(1, 10)) # uniform prior for k in ["t", "i"]: # thickness, intensity, standard Gaussian self.register_buffer(f"{k}_base_loc", torch.zeros(1)) self.register_buffer(f"{k}_base_scale", torch.ones(1)) # constraint, assumes data is [-1,1] normalized normalize_transform = T.ComposeTransform( [T.SigmoidTransform(), T.AffineTransform(loc=-1, scale=2)] ) # thickness flow self.thickness_module = T.ComposeTransformModule( [T.Spline(1, count_bins=4, order="linear")] ) self.thickness_flow = T.ComposeTransform( [self.thickness_module, normalize_transform] ) # intensity (conditional) flow: thickness -> intensity intensity_net = DenseNN(1, args.widths, [1, 1], nonlinearity=nn.GELU()) self.context_nn = ConditionalAffineTransform( context_nn=intensity_net, event_dim=0 ) self.intensity_flow = [self.context_nn, normalize_transform] if args.setup != "sup_pgm": # anticausal predictors input_shape = (args.input_channels, args.input_res, args.input_res) # q(t | x, i) = Normal(mu(x, i), sigma(x, i)), 2 outputs: loc & scale self.encoder_t = CNN(input_shape, num_outputs=2, context_dim=1, width=8) # q(i | x) = Normal(mu(x), sigma(x)) self.encoder_i = CNN(input_shape, num_outputs=2, width=8) # q(y | x) = Categorical(pi(x)) self.encoder_y = CNN(input_shape, num_outputs=10, width=8) self.f = ( lambda x: args.std_fixed * torch.ones_like(x) if args.std_fixed > 0 else F.softplus(x) ) def model(self) -> Dict[str, Tensor]: pyro.module("MorphoMNISTPGM", self) # p(y), digit label prior dist py = dist.OneHotCategorical( probs=F.softmax(self.digit_logits, dim=-1) ) # .to_event(1) # with pyro.poutine.scale(scale=0.05): digit = pyro.sample("digit", py) # p(t), thickness flow pt_base = dist.Normal(self.t_base_loc, self.t_base_scale).to_event(1) pt = dist.TransformedDistribution(pt_base, self.thickness_flow) thickness = pyro.sample("thickness", pt) # p(i | t), intensity conditional flow pi_t_base = dist.Normal(self.i_base_loc, self.i_base_scale).to_event(1) pi_t = ConditionalTransformedDistribution( pi_t_base, self.intensity_flow ).condition(thickness) intensity = pyro.sample("intensity", pi_t) _ = self.context_nn return {"thickness": thickness, "intensity": intensity, "digit": digit} def guide(self, **obs) -> None: # guide for (optional) semi-supervised learning with pyro.plate("observations", obs["x"].shape[0]): # q(i | x) if obs["intensity"] is None: i_loc, i_logscale = self.encoder_i(obs["x"]).chunk(2, dim=-1) qi_t = dist.Normal(torch.tanh(i_loc), self.f(i_logscale)).to_event(1) obs["intensity"] = pyro.sample("intensity", qi_t) # q(t | x, i) if obs["thickness"] is None: t_loc, t_logscale = self.encoder_t(obs["x"], y=obs["intensity"]).chunk( 2, dim=-1 ) qt_x = dist.Normal(torch.tanh(t_loc), self.f(t_logscale)).to_event(1) obs["thickness"] = pyro.sample("thickness", qt_x) # q(y | x) if obs["digit"] is None: y_prob = F.softmax(self.encoder_y(obs["x"]), dim=-1) qy_x = dist.OneHotCategorical(probs=y_prob) # .to_event(1) pyro.sample("digit", qy_x) def model_anticausal(self, **obs) -> None: # assumes all variables are observed & continuous ones are in [-1,1] pyro.module("MorphoMNISTPGM", self) with pyro.plate("observations", obs["x"].shape[0]): # q(t | x, i) t_loc, t_logscale = self.encoder_t(obs["x"], y=obs["intensity"]).chunk( 2, dim=-1 ) qt_x = dist.Normal(torch.tanh(t_loc), self.f(t_logscale)).to_event(1) pyro.sample("thickness_aux", qt_x, obs=obs["thickness"]) # q(i | x) i_loc, i_logscale = self.encoder_i(obs["x"]).chunk(2, dim=-1) qi_t = dist.Normal(torch.tanh(i_loc), self.f(i_logscale)).to_event(1) pyro.sample("intensity_aux", qi_t, obs=obs["intensity"]) # q(y | x) y_prob = F.softmax(self.encoder_y(obs["x"]), dim=-1) qy_x = dist.OneHotCategorical(probs=y_prob) # .to_event(1) pyro.sample("digit_aux", qy_x, obs=obs["digit"]) def predict(self, **obs) -> Dict[str, Tensor]: # q(t | x, i) t_loc, t_logscale = self.encoder_t(obs["x"], y=obs["intensity"]).chunk( 2, dim=-1 ) t_loc = torch.tanh(t_loc) # q(i | x) i_loc, i_logscale = self.encoder_i(obs["x"]).chunk(2, dim=-1) i_loc = torch.tanh(i_loc) # q(y | x) y_prob = F.softmax(self.encoder_y(obs["x"]), dim=-1) return {"thickness": t_loc, "intensity": i_loc, "digit": y_prob} def svi_model(self, **obs) -> None: with pyro.plate("observations", obs["x"].shape[0]): pyro.condition(self.model, data=obs)() def guide_pass(self, **obs) -> None: pass class ChestPGM(BasePGM): def __init__(self, args: Hparams): super().__init__() self.variables = { "race": "categorical", "sex": "binary", "finding": "binary", "age": "continuous", } # Discrete variables that are not root nodes self.discrete_variables = {"finding": "binary"} # define base distributions for k in ["a"]: # , "f"]: self.register_buffer(f"{k}_base_loc", torch.zeros(1)) self.register_buffer(f"{k}_base_scale", torch.ones(1)) # age spline flow self.age_flow_components = T.ComposeTransformModule([T.Spline(1)]) # self.age_constraints = T.ComposeTransform([ # T.AffineTransform(loc=4.09541458484, scale=0.32548387126), # T.ExpTransform()]) self.age_flow = T.ComposeTransform( [ self.age_flow_components, # self.age_constraints, ] ) # Finding (conditional) via MLP, a -> f finding_net = DenseNN(1, [8, 16], param_dims=[2], nonlinearity=nn.Sigmoid()) self.finding_transform_GumbelMax = ConditionalGumbelMax( context_nn=finding_net, event_dim=0 ) # log space for sex and race self.sex_logit = nn.Parameter(np.log(1 / 2) * torch.ones(1)) self.race_logits = nn.Parameter(np.log(1 / 3) * torch.ones(1, 3)) input_shape = (args.input_channels, args.input_res, args.input_res) if args.enc_net == "cnn": # q(s | x) ~ Bernoulli(f(x)) self.encoder_s = CNN(input_shape, num_outputs=1) # q(r | x) ~ OneHotCategorical(logits=f(x)) self.encoder_r = CNN(input_shape, num_outputs=3) # q(f | x) ~ Bernoulli(f(x)) self.encoder_f = CNN(input_shape, num_outputs=1) # q(a | x, f) ~ Normal(mu(x), sigma(x)) self.encoder_a = CNN(input_shape, num_outputs=1, context_dim=1) def model(self) -> Dict[str, Tensor]: pyro.module("ChestPGM", self) # p(s), sex dist ps = dist.Bernoulli(logits=self.sex_logit).to_event(1) sex = pyro.sample("sex", ps) # p(a), age flow pa_base = dist.Normal(self.a_base_loc, self.a_base_scale).to_event(1) pa = dist.TransformedDistribution(pa_base, self.age_flow) age = pyro.sample("age", pa) # age_ = self.age_constraints.inv(age) _ = self.age_flow_components # register with pyro # p(r), race dist pr = dist.OneHotCategorical(logits=self.race_logits) # .to_event(1) race = pyro.sample("race", pr) # p(f | a), finding as OneHotCategorical conditioned on age # finding_dist_base = dist.Gumbel(self.f_base_loc, self.f_base_scale).to_event(1) finding_dist_base = dist.Gumbel(torch.zeros(1), torch.ones(1)).to_event(1) finding_dist = ConditionalTransformedDistributionGumbelMax( finding_dist_base, [self.finding_transform_GumbelMax] ).condition(age) finding = pyro.sample("finding", finding_dist) return { "sex": sex, "race": race, "age": age, "finding": finding, } def guide(self, **obs) -> None: with pyro.plate("observations", obs["x"].shape[0]): # q(s | x) if obs["sex"] is None: s_prob = torch.sigmoid(self.encoder_s(obs["x"])) pyro.sample("sex", dist.Bernoulli(probs=s_prob).to_event(1)) # q(r | x) if obs["race"] is None: r_probs = F.softmax(self.encoder_r(obs["x"]), dim=-1) qr_x = dist.OneHotCategorical(probs=r_probs) # .to_event(1) pyro.sample("race", qr_x) # q(f | x) if obs["finding"] is None: f_prob = torch.sigmoid(self.encoder_f(obs["x"])) qf_x = dist.Bernoulli(probs=f_prob).to_event(1) obs["finding"] = pyro.sample("finding", qf_x) # q(a | x, f) if obs["age"] is None: a_loc, a_logscale = self.encoder_a(obs["x"], y=obs["finding"]).chunk( 2, dim=-1 ) qa_xf = dist.Normal(a_loc, self.f(a_logscale)).to_event(1) pyro.sample("age_aux", qa_xf) def model_anticausal(self, **obs) -> None: # assumes all variables are observed, train classfiers pyro.module("ChestPGM", self) with pyro.plate("observations", obs["x"].shape[0]): # q(s | x) s_prob = torch.sigmoid(self.encoder_s(obs["x"])) qs_x = dist.Bernoulli(probs=s_prob).to_event(1) # with pyro.poutine.scale(scale=0.8): pyro.sample("sex_aux", qs_x, obs=obs["sex"]) # q(r | x) r_probs = F.softmax(self.encoder_r(obs["x"]), dim=-1) qr_x = dist.OneHotCategorical(probs=r_probs) # .to_event(1) # with pyro.poutine.scale(scale=0.5): pyro.sample("race_aux", qr_x, obs=obs["race"]) # q(f | x) f_prob = torch.sigmoid(self.encoder_f(obs["x"])) qf_x = dist.Bernoulli(probs=f_prob).to_event(1) pyro.sample("finding_aux", qf_x, obs=obs["finding"]) # q(a | x, f) a_loc, a_logscale = self.encoder_a(obs["x"], y=obs["finding"]).chunk( 2, dim=-1 ) qa_xf = dist.Normal(a_loc, self.f(a_logscale)).to_event(1) # with pyro.poutine.scale(scale=2): pyro.sample("age_aux", qa_xf, obs=obs["age"]) def predict(self, **obs) -> Dict[str, Tensor]: # q(s | x) s_prob = torch.sigmoid(self.encoder_s(obs["x"])) # q(r | x) r_probs = F.softmax(self.encoder_r(obs["x"]), dim=-1) # q(f | x) f_prob = torch.sigmoid(self.encoder_f(obs["x"])) # q(a | x, f) a_loc, _ = self.encoder_a(obs["x"], y=obs["finding"]).chunk(2, dim=-1) return { "sex": s_prob, "race": r_probs, "finding": f_prob, "age": a_loc, } def svi_model(self, **obs) -> None: with pyro.plate("observations", obs["x"].shape[0]): pyro.condition(self.model, data=obs)() def guide_pass(self, **obs) -> None: pass