Spaces:
Runtime error
Runtime error
import os | |
import numpy as np | |
import torch as th | |
from imageio import imread | |
from skimage.transform import resize as imresize | |
from ema_pytorch import EMA | |
from decomp_diffusion.model_and_diffusion_util import * | |
from decomp_diffusion.diffusion.respace import SpacedDiffusion | |
from decomp_diffusion.gen_image import * | |
from download import download_model | |
import gradio as gr | |
# Seed the global torch and NumPy random generators with one fixed value.
_SEED = 0
th.manual_seed(_SEED)
np.random.seed(_SEED)
def get_pil_im(im, resolution=64):
    """Convert an image array into a one-image, channels-first tensor batch.

    The image is resized to ``resolution`` x ``resolution``, cut down to its
    first three channels, moved from height-width-channel to
    channel-height-width order, and given a leading batch axis of size 1.

    Args:
        im: Image array with a trailing channel axis (it is indexed with
            three subscripts after resizing).
        resolution: Side length, in pixels, of the square output.

    Returns:
        A contiguous ``torch.Tensor`` with four dimensions:
        ``(1, channels, resolution, resolution)``.
    """
    resized = imresize(im, (resolution, resolution))
    # Keep at most the first three channels (e.g. drops an alpha channel, if present).
    rgb = resized[:, :, :3]
    chw = th.Tensor(rgb).permute(2, 0, 1)
    return chw.unsqueeze(0).contiguous()
# generate image components and reconstruction
def gen_image_and_components(model, gd, im, num_components=4, sample_method='ddim', batch_size=1, image_size=64, device='cuda', num_images=1):
    """Generate a row of the original image, its components, and the reconstruction.

    The input image is encoded once with ``model.encode_latent``. One sample
    is then drawn per component (``latent_index`` = 0 .. num_components - 1),
    followed by one sample with ``latent_index=None`` (the reconstruction).
    The input image and all samples are concatenated along the batch axis and
    laid out as a single-row grid with ``make_grid``.

    NOTE(review): the extracted source lost its indentation, so it could not
    be determined whether the original ``return`` sat inside or after the
    ``num_images`` loop. Both readings agree for ``num_images=1``, the only
    value used in this file. Here the row is regenerated ``num_images`` times
    and the last one is returned.

    Args:
        model: Object exposing ``encode_latent``; also handed to the sampler.
        gd: Diffusion object exposing ``p_sample_loop``, ``ddim_sample_loop``
            and ``_wrap_model``.
        im: Input image, converted with ``get_pil_im``.
        num_components: Number of individual components to sample.
        sample_method: ``'ddpm'`` or ``'ddim'``.
        batch_size: Batch size of each sampling call.
        image_size: Side length of the square images.
        device: Device the input image and the sampling run on.
        num_images: Number of times the row is generated (must be >= 1).

    Returns:
        The grid tensor produced by ``make_grid`` for the final row.

    Raises:
        AssertionError: If ``sample_method`` is not ``'ddpm'`` or ``'ddim'``.
        ValueError: If ``num_images`` is smaller than 1.
    """
    # Validate up front so bad arguments fail before any model work is done.
    assert sample_method in ('ddpm', 'ddim')
    if num_images < 1:
        raise ValueError('num_images must be at least 1')

    orig_img = get_pil_im(im, resolution=image_size).to(device)
    # Inference only: no autograd graph is needed for the encoder pass.
    with th.no_grad():
        latent = model.encode_latent(orig_img)
    model_kwargs = {'latent': latent}

    sample_loop_func = gd.p_sample_loop if sample_method == 'ddpm' else gd.ddim_sample_loop
    if sample_method == 'ddim':
        model = gd._wrap_model(model)

    sample_shape = (batch_size, 3, image_size, image_size)

    def _sample(latent_index):
        """Draw one batch conditioned on the given latent index (None = all)."""
        model_kwargs['latent_index'] = latent_index
        return sample_loop_func(
            model,
            sample_shape,
            device=device,
            clip_denoised=True,
            progress=True,
            model_kwargs=model_kwargs,
            cond_fn=None,
        )[:batch_size]

    # generate imgs
    grid = None
    for _ in range(num_images):
        all_samples = [orig_img]
        # individual components
        for j in range(num_components):
            all_samples.append(_sample(j))
        # reconstruction
        all_samples.append(_sample(None))

        samples = th.cat(all_samples, dim=0).cpu()
        grid = make_grid(samples, nrow=samples.shape[0], padding=0)
    return grid
def decompose_image(im):
    """Decompose ``im`` with DDIM sampling and return the result as a NumPy array.

    The grid tensor from ``gen_image_and_components`` is permuted from
    channel-first to channel-last order before conversion.

    NOTE(review): this uses the module-level ``clevr_model``, but the only
    code defining it in the visible part of this file is commented out.
    Confirm it is defined before wiring this function into an interface.
    """
    method = 'ddim'
    grid = gen_image_and_components(
        clevr_model,
        GD[method],
        im,
        sample_method=method,
        num_images=1,
        device=device,
    )
    channel_last = grid.permute(1, 2, 0)
    return channel_last.numpy()
# load diffusion
diffusion_kwargs = diffusion_defaults()
gd = create_gaussian_diffusion(**diffusion_kwargs)
GD = {'ddpm': gd}  # diffusion objects for ddim and ddpm

# set up ddim sampling
desired_timesteps = 50
# From here on the kwargs feed SpacedDiffusion: 'steps' and 'noise_schedule'
# are taken out and an explicit 'betas' entry is put in their place.
num_timesteps = diffusion_kwargs.pop('steps')
spacing = num_timesteps // desired_timesteps
spaced_ts = list(range(0, num_timesteps + 1, spacing))
betas = get_named_beta_schedule(diffusion_kwargs.pop('noise_schedule'), num_timesteps)
diffusion_kwargs['betas'] = betas
gd = SpacedDiffusion(
    spaced_ts,
    rescale_timesteps=True,
    original_num_steps=num_timesteps,
    **diffusion_kwargs,
)
GD['ddim'] = gd
# ckpt_path = download_model('clevr') # 'clevr_model.pt' | |
# model_kwargs = unet_model_defaults() | |
# # model parameters | |
# model_kwargs.update(dict( | |
# emb_dim=64, | |
# enc_channels=128 | |
# )) | |
# clevr_model = create_diffusion_model(**model_kwargs) | |
# clevr_model.eval() | |
# device = 'cuda' if th.cuda.is_available() else 'cpu' | |
# clevr_model.to(device) | |
# print(f'loading from {ckpt_path}') | |
# checkpoint = th.load(ckpt_path, map_location='cpu') | |
# clevr_model.load_state_dict(checkpoint) | |
# img_input = gr.inputs.Image(type="numpy", label="Input") | |
# img_output = gr.outputs.Image(type="numpy", label="Output") | |
# gr.Interface( | |
# decompose_image, | |
# inputs=img_input, | |
# outputs=img_output, | |
# examples=[ | |
# "sample_images/clevr_im_10.png", | |
# "sample_images/clevr_im_25.png", | |
# ], | |
# ).launch() | |
def _component_mask(indices, num_comps):
    """Return the per-component selector used to pick latents from the first image."""
    if isinstance(indices, th.Tensor):
        # Tensors are passed through untouched, as before.
        mask = indices
    else:
        if indices is None:
            # first half 1, second half 0 (the second half gets the extra
            # slot when num_comps is odd)
            half = num_comps // 2
            indices = [1] * half + [0] * (num_comps - half)
        elif isinstance(indices, str):
            indices = [int(ind) for ind in indices.split(',')]
        mask = th.Tensor(list(indices)).reshape(-1, 1) == 1
    assert len(mask) == num_comps
    return mask


def combine_components_slice(model, gd, im1, im2, indices=None, sample_method='ddim', device='cuda', num_images=4, model_kwargs=None, desc='', save_dir='', dataset='clevr', image_size=64):
    """Sample one image from a per-component mix of two images' latents.

    Both images are encoded with ``model.encode_latent``, each latent is
    reshaped to ``(model.num_components, -1)``, and for every component the
    row from ``im1`` is kept where ``indices`` equals 1 and the row from
    ``im2`` otherwise. The mixed latent is flattened to ``(1, -1)`` and used
    as the ``'latent'`` entry of the sampler's ``model_kwargs``.

    Args:
        model: Object exposing ``encode_latent`` and ``num_components``.
        gd: Diffusion object exposing ``p_sample_loop``, ``ddim_sample_loop``
            and ``_wrap_model``.
        im1: First image, converted with ``get_pil_im``.
        im2: Second image, converted with ``get_pil_im``.
        indices: Which components come from ``im1``. ``None`` selects the
            first half of the components; a string is parsed as
            comma-separated integers (e.g. ``'1,0,1,0'``); a list/tuple of
            0/1 values is accepted; a tensor is used as given.
        sample_method: ``'ddpm'`` or ``'ddim'``.
        device: Device for the images, latents and sampling.
        num_images: Unused here; kept for interface compatibility.
        model_kwargs: Extra keyword arguments for the model. The mapping is
            copied, so the caller's dict is not modified.
        desc: Unused here; kept for interface compatibility.
        save_dir: Unused here; kept for interface compatibility.
        dataset: Unused here; kept for interface compatibility.
        image_size: Side length of the square images.

    Returns:
        The first sampled image as a CPU tensor.

    Raises:
        AssertionError: If ``sample_method`` is invalid or the number of
            ``indices`` does not match ``model.num_components``.
    """
    assert sample_method in ('ddpm', 'ddim')
    # Copy so that 'latent' never leaks into the caller's dict or into a
    # default shared between calls.
    model_kwargs = dict(model_kwargs) if model_kwargs is not None else {}

    im1 = get_pil_im(im1, resolution=image_size).to(device)
    im2 = get_pil_im(im2, resolution=image_size).to(device)
    # Inference only: no autograd graph is needed for the encoder passes.
    with th.no_grad():
        latent1 = model.encode_latent(im1)
        latent2 = model.encode_latent(im2)

    num_comps = model.num_components
    # get latent slices
    mask = _component_mask(indices, num_comps).to(device)
    latent1 = latent1.reshape(num_comps, -1).to(device)
    latent2 = latent2.reshape(num_comps, -1).to(device)
    combined_latent = th.where(mask, latent1, latent2).reshape(1, -1)
    model_kwargs['latent'] = combined_latent

    sample_loop_func = gd.p_sample_loop if sample_method == 'ddpm' else gd.ddim_sample_loop
    if sample_method == 'ddim':
        model = gd._wrap_model(model)

    # sampling loop
    sample = sample_loop_func(
        model,
        (1, 3, image_size, image_size),
        device=device,
        clip_denoised=True,
        progress=True,
        model_kwargs=model_kwargs,
        cond_fn=None,
    )[:1]
    return sample[0].cpu()
def combine_images(im1, im2):
    """Recombine two images and return the sampled result as a NumPy array.

    Uses DDIM sampling with ``indices='1,0,1,0'``, i.e. components 0 and 2
    are taken from ``im1`` and components 1 and 3 from ``im2``. The sampled
    tensor is permuted from channel-first to channel-last order before
    conversion.

    Args:
        im1: First input image.
        im2: Second input image.

    Returns:
        The sampled image as a channel-last NumPy array.
    """
    sample_method = 'ddim'
    result = combine_components_slice(
        # The model loaded by this script is ``celeb_model``; ``clevr_model``
        # is only defined in commented-out code.
        celeb_model,
        GD[sample_method],
        im1,
        im2,
        indices='1,0,1,0',
        sample_method=sample_method,
        num_images=1,
        # Sample on the device the model was moved to, not the 'cuda' default.
        device=device,
    )
    return result.permute(1, 2, 0).numpy()
ckpt_path = download_model('celebahq') # 'celeb_model.pt'

# model parameters: start from the defaults and override enc_channels
model_kwargs = unet_model_defaults()
model_kwargs['enc_channels'] = 128

celeb_model = create_diffusion_model(**model_kwargs)
celeb_model.eval()

# Use the GPU when one is available, otherwise fall back to the CPU.
if th.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
celeb_model.to(device)

print(f'loading from {ckpt_path}')
# NOTE(review): th.load unpickles the checkpoint file; confirm the download
# source is trusted.
checkpoint = th.load(ckpt_path, map_location='cpu')
celeb_model.load_state_dict(checkpoint)
# Recombination
# gr.Image is used for both inputs and outputs: the gr.inputs / gr.outputs
# namespaces are deprecated in Gradio 3 and no longer exist in Gradio 4.
img_input = gr.Image(type="numpy", label="Input")
img_input2 = gr.Image(type="numpy", label="Input")
img_output = gr.Image(type="numpy", label="Output")
gr.Interface(
    combine_images,
    inputs=[img_input, img_input2],
    outputs=img_output,
    # One example row: a pair of images, one per input component.
    examples=[
        ["sample_images/celebahq_im_15.jpg",
         "sample_images/celebahq_im_21.jpg"]
    ]
).launch()