# Page residue captured along with this file (not part of the program):
# jsu27's picture | init | 4c1d330 | raw | history blame | 4.19 kB
import os
import numpy as np
import torch as th
from imageio import imread
from skimage.transform import resize as imresize
from ema_pytorch import EMA
from decomp_diffusion.model_and_diffusion_util import *
from decomp_diffusion.diffusion.respace import SpacedDiffusion
from decomp_diffusion.gen_image import *
from download import download_model
# Seed the NumPy and PyTorch global generators with the same fixed value (0).
np.random.seed(0)
th.manual_seed(0)
import gradio as gr
def get_pil_im(im, resolution=64, device='cuda'):
    """Resize an image array and return it as a batched, channel-first tensor.

    Args:
        im: Image as an array-like of shape (H, W) or (H, W, C). Channels
            beyond the third (e.g. alpha) are dropped; a greyscale or
            single-channel image is replicated to three channels.
        resolution: Side length of the square the image is resized to.
        device: Device the resulting tensor is moved to. Defaults to
            'cuda', which matches the previous hard-coded behaviour.

    Returns:
        A contiguous tensor of shape (1, 3, resolution, resolution) on
        `device`, for inputs with one, three or more channels.
    """
    im = np.asarray(im)
    if im.ndim == 2:
        # Greyscale input has no channel axis; add one so the channel
        # handling below applies uniformly.
        im = im[:, :, None]
    im = imresize(im, (resolution, resolution))
    if im.shape[2] == 1:
        # Replicate the single channel so the output always has three.
        im = np.repeat(im, 3, axis=2)
    im = im[:, :, :3]
    # (H, W, C) -> (C, H, W), then add a leading batch axis of size 1.
    im = th.Tensor(im).permute(2, 0, 1)[None, :, :, :].contiguous().to(device)
    return im
# generate image components and reconstruction
def gen_image_and_components(model, gd, im, num_components=4, sample_method='ddim', batch_size=1, image_size=64, device='cuda', num_images=1):
    """Generate a row of the original image, its components, and the reconstruction.

    The input image is encoded once with `model.encode_latent`. For every
    requested image, one sample is drawn per component (`latent_index` set
    to 0 .. num_components - 1) followed by one sample with
    `latent_index=None` (the reconstruction), and everything is laid out
    as a single row with `make_grid`.

    Args:
        model: Model exposing `encode_latent`; passed to the sampling loop.
        gd: Diffusion object providing `p_sample_loop`, `ddim_sample_loop`
            and `_wrap_model`.
        im: Input image array, forwarded to `get_pil_im`.
        num_components: Number of individual components to sample.
        sample_method: Either 'ddpm' or 'ddim'.
        batch_size: Batch size requested from the sampling loop.
        image_size: Side length of the (square) images.
        device: Device used for the input image and for sampling.
        num_images: Number of rows to generate. Each row repeats the full
            sampling procedure for the same input.

    Returns:
        The grid tensor produced by `make_grid` for one row. When
        `num_images` is greater than 1, the rows are concatenated along
        dim 1 (assumes `make_grid` returns a (C, H, W) tensor, so rows
        stack vertically -- TODO confirm).

    Raises:
        AssertionError: If `sample_method` is not 'ddpm' or 'ddim'.
        ValueError: If `num_images` is smaller than 1.
    """
    assert sample_method in ('ddpm', 'ddim')
    if num_images < 1:
        raise ValueError(f'num_images must be at least 1, got {num_images}')

    # Move the input onto `device` explicitly so it sits on the same device
    # as the model input and the samples it is later concatenated with.
    orig_img = get_pil_im(im, resolution=image_size).to(device)

    # The latent is only used as conditioning for sampling, so there is no
    # need to record an autograd graph for the encoder pass.
    with th.no_grad():
        latent = model.encode_latent(orig_img)

    if sample_method == 'ddpm':
        sample_loop_func = gd.p_sample_loop
    else:
        sample_loop_func = gd.ddim_sample_loop
        # Wrapping happens after encoding: the encoder call above uses the
        # unwrapped model, the DDIM loop uses the wrapped one.
        model = gd._wrap_model(model)

    def _sample(latent_index):
        # Run one full sampling pass conditioned on the encoded latent.
        return sample_loop_func(
            model,
            (batch_size, 3, image_size, image_size),
            device=device,
            clip_denoised=True,
            progress=True,
            model_kwargs={'latent': latent, 'latent_index': latent_index},
            cond_fn=None,
        )[:batch_size]

    rows = []
    for _ in range(num_images):
        all_samples = [orig_img]
        # individual components
        all_samples.extend(_sample(j) for j in range(num_components))
        # reconstruction
        all_samples.append(_sample(None))

        samples = th.cat(all_samples, dim=0).cpu()
        rows.append(make_grid(samples, nrow=samples.shape[0], padding=0))

    # A single row is returned as-is so the default call is unchanged.
    return rows[0] if len(rows) == 1 else th.cat(rows, dim=1)
def decompose_image(im):
    """Run the DDIM decomposition on `im` and return the result grid as an array.

    Uses the module-level `clevr_model` and the 'ddim' entry of `GD`.
    """
    method = 'ddim'
    grid = gen_image_and_components(
        clevr_model,
        GD[method],
        im,
        sample_method=method,
        num_images=1,
    )
    # Move the first axis (presumably channels) to the end before converting.
    as_array = grid.permute(1, 2, 0).numpy()
    return as_array
# Diffusion objects, keyed by sampling method ('ddpm' and 'ddim').
diffusion_kwargs = diffusion_defaults()
GD = {'ddpm': create_gaussian_diffusion(**diffusion_kwargs)}

# DDIM samples on a subset of the timesteps, roughly `desired_timesteps` of them.
desired_timesteps = 50
# 'steps' and 'noise_schedule' are replaced by an explicit beta schedule below,
# so they are taken out of the kwargs as they are read.
num_timesteps = diffusion_kwargs.pop('steps')
spacing = num_timesteps // desired_timesteps
# NOTE(review): the range end is num_timesteps + 1, so num_timesteps itself is
# included whenever it is a multiple of `spacing` -- confirm SpacedDiffusion
# tolerates that index.
spaced_ts = list(range(0, num_timesteps + 1, spacing))
betas = get_named_beta_schedule(diffusion_kwargs.pop('noise_schedule'), num_timesteps)
diffusion_kwargs['betas'] = betas

gd = SpacedDiffusion(
    spaced_ts,
    rescale_timesteps=True,
    original_num_steps=num_timesteps,
    **diffusion_kwargs,
)
GD['ddim'] = gd
# Manual alternative for getting the checkpoint:
# !wget https://www.dropbox.com/s/bqpc3ymstz9m05z/clevr_model.pt
ckpt_path = download_model('clevr')  # 'clevr_model.pt'
device = 'cuda'

# Default model settings, with the two values this checkpoint overrides.
model_kwargs = unet_model_defaults()
model_kwargs.update({'emb_dim': 64, 'enc_channels': 128})

# Build the model, switch it to eval mode and place it on the target device.
clevr_model = create_diffusion_model(**model_kwargs)
clevr_model.eval()
clevr_model.to(device)

# The checkpoint is read onto the CPU and then loaded into the model.
print(f'loading from {ckpt_path}')
checkpoint = th.load(ckpt_path, map_location='cpu')
clevr_model.load_state_dict(checkpoint)
# Gradio UI: one image in, one image out, with two bundled example inputs.
img_input = gr.inputs.Image(type="numpy", label="Input")
img_output = gr.outputs.Image(type="numpy", label="Output")

gr.Interface(
    fn=decompose_image,
    inputs=img_input,
    outputs=img_output,
    # Example paths are resolved relative to this file's directory.
    examples=[
        os.path.join(os.path.dirname(__file__), rel_path)
        for rel_path in (
            "sample_images/clevr_im_10.png",
            "sample_images/clevr_im_25.png",
        )
    ],
).launch()