Spaces:
Runtime error
Runtime error
File size: 7,483 Bytes
2e82449 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
import torch
import tqdm
import k_diffusion.sampling
import numpy as np
from modules import shared
from modules.models.diffusion.uni_pc import uni_pc
from modules.torch_utils import float64
@torch.no_grad()
def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
    """Run DDIM sampling, walking ``timesteps`` from the last entry back to the first.

    ``model`` is called as ``model(x, t, **extra_args)`` and its output is used as
    the noise prediction. The cumulative alpha table is read from
    ``model.inner_model.inner_model.alphas_cumprod``. ``eta`` scales the amount of
    fresh noise injected on each step (``eta == 0.0`` makes every sigma zero).

    The loop performs ``len(timesteps) - 1`` steps, visiting indices
    ``len(timesteps) - 1`` down to ``1``. After each step ``callback`` (if given)
    receives a dict with the new ``x`` and the current x0 estimate as ``denoised``.

    Returns the final latent ``x``.
    """
    if extra_args is None:
        extra_args = {}

    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
    alphas = alphas_cumprod[timesteps]
    # Alpha of the preceding timestep; the front is padded with index 0.
    prev_indices = torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))
    alphas_prev = alphas_cumprod[prev_indices].to(float64(x))
    sqrt_one_minus_alphas = torch.sqrt(1 - alphas)

    # DDIM sigma schedule, computed on CPU.
    alphas_prev_np = alphas_prev.cpu().numpy()
    alphas_cpu = alphas.cpu()
    sigmas = eta * np.sqrt((1 - alphas_prev_np) / (1 - alphas_cpu) * (1 - alphas_cpu / alphas_prev_np))

    batch_ones = x.new_ones((x.shape[0]))             # one timestep value per sample
    broadcast_ones = x.new_ones((x.shape[0], 1, 1, 1))  # broadcasts scalars over x

    for step in tqdm.trange(len(timesteps) - 1, disable=disable):
        idx = len(timesteps) - 1 - step
        eps = model(x, timesteps[idx].item() * batch_ones, **extra_args)

        alpha_t = alphas[idx].item() * broadcast_ones
        alpha_prev = alphas_prev[idx].item() * broadcast_ones
        sigma = sigmas[idx].item() * broadcast_ones
        sqrt_one_minus_alpha_t = sqrt_one_minus_alphas[idx].item() * broadcast_ones

        denoised = (x - sqrt_one_minus_alpha_t * eps) / alpha_t.sqrt()
        direction = (1. - alpha_prev - sigma ** 2).sqrt() * eps
        # NOTE(review): randn_like is looked up via k_diffusion.sampling.torch rather
        # than torch directly — presumably so a patched RNG on that module is used.
        noise = sigma * k_diffusion.sampling.torch.randn_like(x)
        x = alpha_prev.sqrt() * denoised + direction + noise

        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': 0, 'sigma_hat': 0, 'denoised': denoised})

    return x
@torch.no_grad()
def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
    """DDIM sampling with CFG++ ("Manifold-constrained Classifier Free Guidance
    for Diffusion Models", 2024).

    Compared with plain DDIM, two things change:

    * the guidance scale multiplier on ``model`` is set to ``1 / 12.5``, mapping a
      CFG range of [0.0, 12.5] onto [0, 1.0];
    * the direction term of each step uses the model's unconditional noise
      prediction (``model.last_noise_uncond``) instead of the guided prediction.

    Like ``ddim``, this performs ``len(timesteps) - 1`` steps from the last
    timestep index down to 1, and reports each step to ``callback`` if given.

    Returns the final latent ``x``.
    """
    if extra_args is None:
        extra_args = {}

    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
    alphas = alphas_cumprod[timesteps]
    # Alpha of the preceding timestep; the front is padded with index 0.
    prev_indices = torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))
    alphas_prev = alphas_cumprod[prev_indices].to(float64(x))
    sqrt_one_minus_alphas = torch.sqrt(1 - alphas)

    # DDIM sigma schedule, computed on CPU.
    alphas_prev_np = alphas_prev.cpu().numpy()
    alphas_cpu = alphas.cpu()
    sigmas = eta * np.sqrt((1 - alphas_prev_np) / (1 - alphas_cpu) * (1 - alphas_cpu / alphas_prev_np))

    # NOTE: the attribute spelling ("miltiplier") is kept verbatim; whatever reads
    # it lives outside this file, so "fixing" it here would silently do nothing.
    model.cond_scale_miltiplier = 1 / 12.5
    # Presumably asks the model wrapper to record its unconditional prediction in
    # model.last_noise_uncond, which is read after every model call below.
    model.need_last_noise_uncond = True

    batch_ones = x.new_ones((x.shape[0]))             # one timestep value per sample
    broadcast_ones = x.new_ones((x.shape[0], 1, 1, 1))  # broadcasts scalars over x

    for step in tqdm.trange(len(timesteps) - 1, disable=disable):
        idx = len(timesteps) - 1 - step
        eps = model(x, timesteps[idx].item() * batch_ones, **extra_args)
        eps_uncond = model.last_noise_uncond

        alpha_t = alphas[idx].item() * broadcast_ones
        alpha_prev = alphas_prev[idx].item() * broadcast_ones
        sigma = sigmas[idx].item() * broadcast_ones
        sqrt_one_minus_alpha_t = sqrt_one_minus_alphas[idx].item() * broadcast_ones

        denoised = (x - sqrt_one_minus_alpha_t * eps) / alpha_t.sqrt()
        # CFG++: re-noise along the unconditional prediction, not the guided one.
        direction = (1. - alpha_prev - sigma ** 2).sqrt() * eps_uncond
        # NOTE(review): randn_like is looked up via k_diffusion.sampling.torch rather
        # than torch directly — presumably so a patched RNG on that module is used.
        noise = sigma * k_diffusion.sampling.torch.randn_like(x)
        x = alpha_prev.sqrt() * denoised + direction + noise

        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': 0, 'sigma_hat': 0, 'denoised': denoised})

    return x
@torch.no_grad()
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
    """Pseudo Linear Multi-Step (PLMS) sampling over ``timesteps``, last to first.

    The first step uses a pseudo improved-Euler update, which needs two model
    calls. Later steps reuse up to the three most recent noise predictions in
    Adams-Bashforth combinations of order 2, 3 and 4. Each combined prediction
    is then applied with a deterministic DDIM update (no injected noise).

    Performs ``len(timesteps) - 1`` steps and reports each one to ``callback``
    (if given). Returns the final latent ``x``.
    """
    if extra_args is None:
        extra_args = {}

    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
    alphas = alphas_cumprod[timesteps]
    # Alpha of the preceding timestep; the front is padded with index 0.
    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
    sqrt_one_minus_alphas = torch.sqrt(1 - alphas)

    batch_ones = x.new_ones([x.shape[0]])             # one timestep value per sample
    broadcast_ones = x.new_ones((x.shape[0], 1, 1, 1))  # broadcasts scalars over x
    eps_history = []  # up to three most recent raw predictions, oldest first

    def ddim_step(x_cur, eps, idx):
        """Deterministic DDIM step from ``x_cur`` using ``eps``; returns (x_prev, pred_x0)."""
        alpha_t = alphas[idx].item() * broadcast_ones
        alpha_prev = alphas_prev[idx].item() * broadcast_ones
        sqrt_one_minus_alpha_t = sqrt_one_minus_alphas[idx].item() * broadcast_ones
        pred_x0 = (x_cur - sqrt_one_minus_alpha_t * eps) / alpha_t.sqrt()
        direction = (1. - alpha_prev).sqrt() * eps  # component pointing back to x_t
        return alpha_prev.sqrt() * pred_x0 + direction, pred_x0

    for step in tqdm.trange(len(timesteps) - 1, disable=disable):
        idx = len(timesteps) - 1 - step
        t_cur = timesteps[idx].item() * batch_ones
        t_next = timesteps[max(idx - 1, 0)].item() * batch_ones
        eps = model(x, t_cur, **extra_args)

        history_len = len(eps_history)
        if history_len == 0:
            # Pseudo improved Euler (2nd order): average with a look-ahead prediction.
            x_trial, _ = ddim_step(x, eps, idx)
            eps_next = model(x_trial, t_next, **extra_args)
            eps_prime = (eps + eps_next) / 2
        elif history_len == 1:
            # 2nd-order pseudo linear multistep (Adams-Bashforth)
            eps_prime = (3 * eps - eps_history[-1]) / 2
        elif history_len == 2:
            # 3rd-order pseudo linear multistep (Adams-Bashforth)
            eps_prime = (23 * eps - 16 * eps_history[-1] + 5 * eps_history[-2]) / 12
        else:
            # 4th-order pseudo linear multistep (Adams-Bashforth)
            eps_prime = (55 * eps - 59 * eps_history[-1] + 37 * eps_history[-2] - 9 * eps_history[-3]) / 24

        x_prev, pred_x0 = ddim_step(x, eps_prime, idx)

        eps_history.append(eps)
        if len(eps_history) >= 4:
            del eps_history[0]

        x = x_prev
        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})

    return x
class UniPCCFG(uni_pc.UniPC):
    """UniPC solver that evaluates a CFG-wrapped webui model.

    ``None`` is passed as the first positional argument to the base constructor
    (presumably its model function); model evaluation instead goes through the
    overridden ``model`` method, which calls ``cfg_model(x, t, **extra_args)``.
    Progress is reported through ``after_update`` — presumably invoked by the
    base solver after each update; confirm against ``uni_pc.UniPC``.
    """

    def __init__(self, cfg_model, extra_args, callback, *args, **kwargs):
        super().__init__(None, *args, **kwargs)

        def after_update(x, model_x):
            # callback is optional (unipc() defaults it to None); calling None
            # would raise TypeError, so only report when one was supplied.
            if callback is not None:
                callback({'x': x, 'i': self.index, 'sigma': 0, 'sigma_hat': 0, 'denoised': model_x})
            self.index += 1

        self.cfg_model = cfg_model
        # model() unpacks this with **, so a missing value must become {} (as in ddim/plms).
        self.extra_args = {} if extra_args is None else extra_args
        self.callback = callback
        self.index = 0  # step counter reported to the callback as 'i'
        self.after_update = after_update

    def get_model_input_time(self, t_continuous):
        """Convert continuous time to model input time: ``(t - 1/total_N) * 1000``."""
        return (t_continuous - 1. / self.noise_schedule.total_N) * 1000.

    def model(self, x, t):
        """Evaluate the CFG model on ``x`` at continuous time ``t``."""
        t_input = self.get_model_input_time(t)
        res = self.cfg_model(x, t_input, **self.extra_args)
        return res
def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False):
    """Sample with the UniPC multistep solver, configured from ``shared.opts``.

    The variant, skip type, order and lower-order-final setting come from the
    ``uni_pc_*`` options. ``len(timesteps)`` is used as the step count. For
    img2img, the start time is derived from the last timestep. ``disable`` is
    accepted for signature parity with the other samplers but is not used here.

    Returns the sampled latent.
    """
    # Match ddim()/plms(): a missing extra_args means "no extra model kwargs".
    # The solver unpacks it with ** on every model call, so None would crash.
    extra_args = {} if extra_args is None else extra_args

    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
    ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
    # this is likely off by a bit - if someone wants to fix it please by all means
    t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None
    unipc_sampler = UniPCCFG(model, extra_args, callback, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant)
    x = unipc_sampler.sample(
        x,
        steps=len(timesteps),
        t_start=t_start,
        skip_type=shared.opts.uni_pc_skip_type,
        method="multistep",
        order=shared.opts.uni_pc_order,
        lower_order_final=shared.opts.uni_pc_lower_order_final,
    )
    return x
|