counterfactuals / app_utils.py
fabio-deep
added links
146a6ea
import torch
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from PIL import Image
from matplotlib import rc, patches, colors
# Global matplotlib configuration, applied as a side effect of importing this
# module: serif font family, all text rendered through LaTeX (usetex) with
# amsmath/amssymb loaded in the preamble, and no interpolation for images.
# NOTE(review): usetex presumably needs a working LaTeX installation on the
# host - confirm for the deployment environment.
rc("font", **{"family": "serif", "serif": ["Roman"]})
rc("text", usetex=True)
rc("image", interpolation="none")
rc("text.latex", preamble=r"\usepackage{amsmath} \usepackage{amssymb}")
# Project-local helper; called in this module as `_max, _min = get_attr_max_min(k)`.
from datasets import get_attr_max_min
# Hammer icon stamped onto the graph figures next to intervened nodes.
# Loaded once at import time from a path relative to the current working
# directory, resized to 35x35 and divided by 255 (presumably to bring 8-bit
# pixel values into [0, 1] - confirm the image mode of hammer.png).
HAMMER = np.array(Image.open("./hammer.png").resize((35, 35))) / 255
class MidpointNormalize(colors.Normalize):
    """Normalisation that pins ``midpoint`` to 0.5 on a symmetric value range.

    Values are mapped piecewise-linearly: ``-E -> 0``, ``midpoint -> 0.5`` and
    ``+E -> 1``, where ``E = max(|vmin|, |vmax|)``.
    """

    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        """Store ``midpoint`` and forward the remaining arguments to ``Normalize``."""
        self.midpoint = midpoint
        super().__init__(vmin, vmax, clip)

    def __call__(self, value, clip=None):
        """Map ``value`` into [0, 1]; the ``clip`` argument is accepted but unused.

        ``vmin`` and ``vmax`` must already be set: no autoscaling happens here.
        """
        # Symmetric extent around zero, taken from the larger absolute bound.
        extent = np.max(np.abs([self.vmin, self.vmax]))
        anchors_in = [-extent, self.midpoint, extent]
        anchors_out = [0, 0.5, 1]
        mapped = np.interp(value, anchors_in, anchors_out)
        return np.ma.masked_array(mapped)
def postprocess(x):
    """Rescale a tensor as ``(x + 1) * 127.5`` and return it as a NumPy array.

    Size-1 dimensions are squeezed away, and the result is detached from the
    autograd graph and moved to the CPU before conversion.
    """
    rescaled = (x + 1.0) * 127.5
    host_tensor = rescaled.squeeze().detach().cpu()
    return host_tensor.numpy()
def mnist_graph(*args):
    """Draw the MNIST causal graph and return it as an RGBA pixel array.

    The first three positional arguments are read as intervention flags for
    the nodes ``t``, ``i`` and ``y`` (in that order). For every truthy flag the
    node gets a red outline, the edges pointing into it are drawn light grey,
    and the hammer icon is stamped onto the figure beside it.

    Every open pyplot figure is closed before the new one is created.

    Returns:
        ``np.ndarray`` built from the canvas renderer's RGBA buffer.

    Raises:
        IndexError: if fewer than three positional arguments are passed.
    """
    do_t, do_i, do_y = args[0], args[1], args[2]

    # LaTeX labels; the label strings double as the graph's node keys.
    x, t, i, y = r"$\mathbf{x}$", r"$t$", r"$i$", r"$y$"
    ut, ui, uy = r"$\mathbf{U}_t$", r"$\mathbf{U}_i$", r"$\mathbf{U}_y$"
    zx, ex = r"$\mathbf{z}_{1:L}$", r"$\boldsymbol{\epsilon}$"

    # Insertion order matters: the colour lists below are positional and must
    # line up with the graph's own node/edge iteration order.
    G = nx.DiGraph()
    G.add_edges_from(
        [
            (t, x),
            (i, x),
            (y, x),
            (t, i),
            (ut, t),
            (ui, i),
            (uy, y),
            (zx, x),
            (ex, x),
        ]
    )
    pos = {
        x: (1, 0),
        t: (0, 0.5),
        i: (1, 0.5),
        y: (0, 0),
        ut: (0, 1),
        ui: (1, 1),
        uy: (-1, 0),
        zx: (2, 0.375),
        ex: (2, 0),
    }

    shaded = {x, t, i, y}
    fill = {n: "lightgrey" if n in shaded else "white" for n in G}
    outline = dict.fromkeys(fill, "black")
    edge_colour = dict.fromkeys(G.edges, "black")

    # (flag, intervened node, incoming edges to grey out, hammer position as
    # fractions of the figure bbox)
    interventions = (
        (do_t, t, [(ut, t)], (0.26, 0.525)),
        (do_i, i, [(ui, i), (t, i)], (0.5175, 0.525)),
        (do_y, y, [(uy, y)], (0.26, 0.2)),
    )
    for active, target, incoming, _ in interventions:
        if not active:
            continue
        outline[target] = "red"
        for edge in incoming:
            edge_colour[edge] = "lightgrey"

    label_size = 30
    plt.close("all")
    fig, ax = plt.subplots(1, 1, figsize=(6, 4.1))
    ax.margins(x=0.06, y=0.15, tight=False)
    ax.axis("off")
    nx.draw_networkx(
        G,
        pos,
        font_size=label_size,
        node_size=3000,
        node_color=list(fill.values()),
        edgecolors=list(outline.values()),
        edge_color=list(edge_colour.values()),
        linewidths=2,
        width=2,
        arrowsize=25,
        arrowstyle="-|>",
        ax=ax,
    )
    # Fixed axis limits keep the layout identical whichever flags are set.
    ax.set_xlim((-1.348, 2.348))
    ax.set_ylim((-0.215, 1.215))

    # Box around the z / epsilon nodes, labelled U_x.
    ax.add_patch(
        patches.FancyBboxPatch(
            (1.75, -0.16),
            0.5,
            0.7,
            boxstyle="round, pad=0.05, rounding_size=0",
            linewidth=2,
            edgecolor="black",
            facecolor="none",
            linestyle="-",
        )
    )
    ax.text(1.85, 0.65, r"$\mathbf{U}_{\mathbf{x}}$", fontsize=label_size)

    # One hammer per active intervention, z-orders 10, 11, 12.
    for zorder, (active, _, _, (frac_x, frac_y)) in enumerate(interventions, start=10):
        if active:
            fig.figimage(
                HAMMER,
                frac_x * fig.bbox.xmax,
                frac_y * fig.bbox.ymax,
                zorder=zorder,
            )

    fig.tight_layout()
    fig.canvas.draw()
    return np.array(fig.canvas.renderer.buffer_rgba())
def brain_graph(*args):
    """Draw the brain causal graph and return it as an RGBA pixel array.

    The first five positional arguments are read as intervention flags for
    the nodes ``m``, ``s``, ``a``, ``b`` and ``v`` (in that order). For every
    truthy flag the node gets a red outline, the graph edges pointing into it
    are drawn light grey, and the hammer icon is stamped onto the figure
    beside it.

    The ``b -> v`` dependency is not an edge of the networkx graph; it is
    drawn by hand as a curved arrow, and only when ``v`` is not intervened on.

    Every open pyplot figure is closed before the new one is created.

    Returns:
        ``np.ndarray`` built from the canvas renderer's RGBA buffer.

    Raises:
        IndexError: if fewer than five positional arguments are passed.
    """
    do_m, do_s, do_a, do_b, do_v = args[0], args[1], args[2], args[3], args[4]

    # LaTeX labels; the label strings double as the graph's node keys.
    x, m, s, a, b, v = r"$\mathbf{x}$", r"$m$", r"$s$", r"$a$", r"$b$", r"$v$"
    um, us, ua, ub, uv = (
        r"$\mathbf{U}_m$",
        r"$\mathbf{U}_s$",
        r"$\mathbf{U}_a$",
        r"$\mathbf{U}_b$",
        r"$\mathbf{U}_v$",
    )
    zx, ex = r"$\mathbf{z}_{1:L}$", r"$\boldsymbol{\epsilon}$"

    # Insertion order matters: the colour lists below are positional and must
    # line up with the graph's own node/edge iteration order.
    G = nx.DiGraph()
    G.add_edges_from(
        [
            (m, x),
            (s, x),
            (b, x),
            (v, x),
            (zx, x),
            (ex, x),
            (a, b),
            (a, v),
            (s, b),
            (um, m),
            (us, s),
            (ua, a),
            (ub, b),
            (uv, v),
        ]
    )
    pos = {
        x: (0, 0),
        zx: (-0.25, -1),
        ex: (0.25, -1),
        a: (0, 1),
        ua: (0, 2),
        s: (1, 0),
        us: (1, -1),
        b: (1, 1),
        ub: (1, 2),
        m: (-1, 0),
        um: (-1, -1),
        v: (-1, 1),
        uv: (-1, 2),
    }

    shaded = {x, m, s, a, b, v}
    node_c = {n: "lightgrey" if n in shaded else "white" for n in G}
    node_line_c = dict.fromkeys(node_c, "black")
    edge_c = dict.fromkeys(G.edges, "black")

    # (flag, intervened node, incoming graph edges to grey out, hammer position
    # as fractions of the figure bbox).
    # Only real edges of G may be listed: adding a key for (b, v) would make
    # the colour list longer than the edge list it is matched against.
    interventions = (
        (do_m, m, [(um, m)], (0.0075, 0.395)),
        (do_s, s, [(us, s)], (0.72, 0.395)),
        (do_a, a, [(ua, a)], (0.363, 0.64)),
        (do_b, b, [(ub, b), (s, b), (a, b)], (0.72, 0.64)),
        (do_v, v, [(uv, v), (a, v)], (0.0075, 0.64)),
    )
    for active, target, incoming, _ in interventions:
        if not active:
            continue
        node_line_c[target] = "red"
        for edge in incoming:
            edge_c[edge] = "lightgrey"

    fs = 30
    plt.close("all")
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    ax.margins(x=0.1, y=0.08, tight=False)
    ax.axis("off")
    nx.draw_networkx(
        G,
        pos,
        font_size=fs,
        node_size=3000,
        node_color=list(node_c.values()),
        edgecolors=list(node_line_c.values()),
        edge_color=list(edge_c.values()),
        linewidths=2,
        width=2,
        arrowsize=25,
        arrowstyle="-|>",
        ax=ax,
    )
    # Fixed axis limits keep the layout identical whichever flags are set.
    ax.set_xlim((-1.32, 1.32))
    ax.set_ylim((-1.414, 2.414))

    # Box around the z / epsilon nodes.
    ax.add_patch(
        patches.FancyBboxPatch(
            (-0.5, -1.325),
            1,
            0.65,
            boxstyle="round, pad=0.05, rounding_size=0",
            linewidth=2,
            edgecolor="black",
            facecolor="none",
            linestyle="-",
        )
    )

    # One hammer per active intervention, z-orders 10..14.
    for zorder, (active, _, _, (frac_x, frac_y)) in enumerate(interventions, start=10):
        if active:
            fig.figimage(
                HAMMER,
                frac_x * fig.bbox.xmax,
                frac_y * fig.bbox.ymax,
                zorder=zorder,
            )

    if not do_v:
        # b -> v, drawn as an arc above the top row of nodes.
        ax.add_patch(
            patches.FancyArrowPatch(
                (0.86, 1.21),
                (-0.86, 1.21),
                connectionstyle="arc3,rad=.3",
                linewidth=2,
                arrowstyle="simple, head_width=10, head_length=10",
                color="k",
            )
        )

    fig.tight_layout()
    fig.canvas.draw()
    return np.array(fig.canvas.renderer.buffer_rgba())
def chest_graph(*args):
    """Draw the chest causal graph and return it as an RGBA pixel array.

    The first four positional arguments are read as intervention flags for
    the nodes ``r``, ``s``, ``d`` and ``a`` (in that order). For every truthy
    flag the node gets a red outline, the edges pointing into it are drawn
    light grey, and the hammer icon is stamped onto the figure beside it.

    Every open pyplot figure is closed before the new one is created.

    Returns:
        ``np.ndarray`` built from the canvas renderer's RGBA buffer.

    Raises:
        IndexError: if fewer than four positional arguments are passed.
    """
    do_r, do_s, do_d, do_a = args[0], args[1], args[2], args[3]

    # LaTeX labels; the label strings double as the graph's node keys.
    x, a, d, r, s = r"$\mathbf{x}$", r"$a$", r"$d$", r"$r$", r"$s$"
    ua, ud, ur, us = (
        r"$\mathbf{U}_a$",
        r"$\mathbf{U}_d$",
        r"$\mathbf{U}_r$",
        r"$\mathbf{U}_s$",
    )
    zx, ex = r"$\mathbf{z}_{1:L}$", r"$\boldsymbol{\epsilon}$"

    # Insertion order matters: the colour lists below are positional and must
    # line up with the graph's own node/edge iteration order.
    G = nx.DiGraph()
    G.add_edges_from(
        [
            (ua, a),
            (ud, d),
            (ur, r),
            (us, s),
            (a, d),
            (d, x),
            (r, x),
            (s, x),
            (ex, x),
            (zx, x),
            (a, x),
        ]
    )
    pos = {
        x: (0, 0),
        s: (1, 0),
        a: (-1, 1),
        d: (0, 1),
        r: (1, 1),
        ua: (-1, 2),
        ud: (0, 2),
        ur: (1, 2),
        us: (1, -1),
        zx: (-0.25, -1),
        ex: (0.25, -1),
    }

    shaded = {x, a, d, r, s}
    fill = {n: "lightgrey" if n in shaded else "white" for n in G}
    outline = dict.fromkeys(fill, "black")
    edge_colour = dict.fromkeys(G.edges, "black")

    # (flag, intervened node, incoming edges to grey out, hammer position as
    # fractions of the figure bbox)
    interventions = (
        (do_r, r, [(ur, r)], (0.72, 0.64)),
        (do_s, s, [(us, s)], (0.72, 0.395)),
        (do_d, d, [(ud, d), (a, d)], (0.363, 0.64)),
        (do_a, a, [(ua, a)], (0.0075, 0.64)),
    )
    for active, target, incoming, _ in interventions:
        if not active:
            continue
        outline[target] = "red"
        for edge in incoming:
            edge_colour[edge] = "lightgrey"

    label_size = 30
    plt.close("all")
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    ax.margins(x=0.1, y=0.08, tight=False)
    ax.axis("off")
    nx.draw_networkx(
        G,
        pos,
        font_size=label_size,
        node_size=3000,
        node_color=list(fill.values()),
        edgecolors=list(outline.values()),
        edge_color=list(edge_colour.values()),
        linewidths=2,
        width=2,
        arrowsize=25,
        arrowstyle="-|>",
        ax=ax,
    )
    # Fixed axis limits keep the layout identical whichever flags are set.
    ax.set_xlim((-1.32, 1.32))
    ax.set_ylim((-1.414, 2.414))

    # Box around the z / epsilon nodes, labelled U_x.
    ax.add_patch(
        patches.FancyBboxPatch(
            (-0.5, -1.325),
            1,
            0.65,
            boxstyle="round, pad=0.05, rounding_size=0",
            linewidth=2,
            edgecolor="black",
            facecolor="none",
            linestyle="-",
        )
    )
    ax.text(-0.9, -1.075, r"$\mathbf{U}_{\mathbf{x}}$", fontsize=label_size)

    # One hammer per active intervention, z-orders 10..13.
    for zorder, (active, _, _, (frac_x, frac_y)) in enumerate(interventions, start=10):
        if active:
            fig.figimage(
                HAMMER,
                frac_x * fig.bbox.xmax,
                frac_y * fig.bbox.ymax,
                zorder=zorder,
            )

    fig.tight_layout()
    fig.canvas.draw()
    return np.array(fig.canvas.renderer.buffer_rgba())
# (shift, scale) applied to the log of each continuous ukbb attribute:
# (log(v) - shift) / scale. Presumably the mean and std of the log-attribute
# over the VAE's training data - confirm against the training code.
_UKBB_LOG_STATS = {
    "age": (4.112339973449707, 0.11769197136163712),
    "brain_volume": (13.965583801269531, 0.09537758678197861),
    "ventricle_volume": (10.345998764038086, 0.43127763271331787),
}


def vae_preprocess(args, pa):
    """Turn a dict of parent tensors into the VAE's conditioning tensor.

    When ``"ukbb"`` occurs in ``args.hps`` the parents are first re-encoded:
    every key except ``"mri_seq"`` and ``"sex"`` is mapped from [-1, 1] back to
    its original range via ``get_attr_max_min``, and ``age``, ``brain_volume``
    and ``ventricle_volume`` are then log-standardised (the VAE was trained on
    log-standardised parents, the PGM on [-1, 1]-normalised ones).

    The parents named in ``args.parents_x`` are concatenated along dim 1
    (1-D tensors get a trailing axis first) and tiled over an
    ``args.input_res`` x ``args.input_res`` grid.

    Args:
        args: object providing ``hps``, ``parents_x``, ``input_res``, ``device``.
        pa: mapping from parent name to tensor. It is not modified.

    Returns:
        Float tensor on ``args.device`` with two trailing spatial axes of size
        ``args.input_res``.
    """
    if "ukbb" in args.hps:
        # Build a fresh dict so the caller's mapping keeps its original values.
        converted = {}
        for k, v in pa.items():
            if k not in ("mri_seq", "sex"):
                # [-1, 1] -> [0, 1] -> original attribute range
                _max, _min = get_attr_max_min(k)
                v = (v + 1) / 2 * (_max - _min) + _min
            if k in _UKBB_LOG_STATS:
                shift, scale = _UKBB_LOG_STATS[k]
                # clamp guards log() against non-positive values
                v = (torch.log(v.clamp(min=1e-12)) - shift) / scale
            converted[k] = v
        pa = converted

    # Concatenate the parents, then expand to input resolution for conditioning.
    columns = [
        pa[k] if len(pa[k].shape) > 1 else pa[k][..., None] for k in args.parents_x
    ]
    stacked = torch.cat(columns, dim=1)
    tiled = stacked[..., None, None].repeat(1, 1, *(args.input_res,) * 2)
    return tiled.to(args.device).float()
def preprocess_brain(args, obs):
    """Normalise a brain observation dict in place and return that same dict.

    ``obs["x"]`` gains a leading axis, is cast to float on ``args.device`` and
    rescaled as ``(x - 127.5) / 127.5``. Every other entry is cast to float,
    moved to ``args.device`` and reshaped to ``(1, 1)``; ``age``,
    ``brain_volume`` and ``ventricle_volume`` are additionally min-max scaled
    to [-1, 1] using ``get_attr_max_min``.
    """
    device = args.device
    continuous = ("age", "brain_volume", "ventricle_volume")

    image = obs["x"][None, ...].float().to(device)
    obs["x"] = (image - 127.5) / 127.5  # [-1,1]

    # Snapshot the keys up front since entries are reassigned inside the loop.
    for name in list(obs):
        if name == "x":
            continue
        value = obs[name].float().to(device).view(1, 1)
        if name in continuous:
            upper, lower = get_attr_max_min(name)
            value = (value - lower) / (upper - lower)  # [0,1]
            value = 2 * value - 1  # [-1,1]
        obs[name] = value
    return obs
def get_fig_arr(x, width=4, height=4, dpi=144, cmap="Greys_r", norm=None):
    """Render ``x`` with ``imshow`` on a borderless figure and return the pixels.

    Args:
        x: image data passed straight to ``Axes.imshow``.
        width, height: figure size, passed to ``plt.figure(figsize=...)``.
        dpi: figure resolution.
        cmap: colormap name. With the default ``"Greys_r"`` the data is shown
            on a fixed 0..255 scale and ``norm`` is ignored.
        norm: normalisation passed to ``imshow`` for any other colormap.

    Returns:
        ``np.ndarray`` copied from the canvas renderer's RGBA buffer.
    """
    fig = plt.figure(figsize=(width, height), dpi=dpi)
    try:
        # Axes spanning the whole figure, without a frame.
        ax = plt.axes([0, 0, 1, 1], frameon=False)
        if cmap == "Greys_r":
            ax.imshow(x, cmap=cmap, vmin=0, vmax=255)
        else:
            ax.imshow(x, cmap=cmap, norm=norm)
        ax.axis("off")
        fig.canvas.draw()
        # np.array copies the buffer, so the result outlives the figure.
        return np.array(fig.canvas.renderer.buffer_rgba())
    finally:
        # Release the figure; otherwise pyplot keeps one alive per call.
        plt.close(fig)
def normalize(x, x_min=None, x_max=None, zero_one=False):
    """Min-max scale ``x`` to [-1, 1], or to [0, 1] when ``zero_one`` is true.

    ``x_min`` / ``x_max`` default to ``x.min()`` / ``x.max()`` when left as
    ``None``. No guard is applied for ``x_max == x_min``; the division is then
    by zero.
    """
    lower = x.min() if x_min is None else x_min
    upper = x.max() if x_max is None else x_max
    unit = (x - lower) / (upper - lower)  # [0,1]
    if zero_one:
        return unit
    return 2 * unit - 1  # [-1,1]