import os

import numpy as np
import torch as th
from imageio import imread
from skimage.transform import resize as imresize
from torchvision.utils import make_grid
from ema_pytorch import EMA

from decomp_diffusion.model_and_diffusion_util import *
from decomp_diffusion.diffusion.respace import SpacedDiffusion
from decomp_diffusion.gen_image import *
from download import download_model

import gradio as gr

# fix randomness
th.manual_seed(0)
np.random.seed(0)


def get_pil_im(im, resolution=64):
    """Resize an input image to `resolution` and return a (1, 3, H, W) float tensor."""
    im = imresize(im, (resolution, resolution))[:, :, :3]
    im = th.Tensor(im).permute(2, 0, 1)[None, :, :, :].contiguous()
    return im


# generate image components and reconstruction
def gen_image_and_components(model, gd, im, num_components=4, sample_method='ddim', batch_size=1, image_size=64, device='cuda', num_images=1):
    """Generate a row of the original image, its individual components, and the reconstruction."""
    orig_img = get_pil_im(im, resolution=image_size).to(device)
    latent = model.encode_latent(orig_img)
    model_kwargs = {'latent': latent}

    assert sample_method in ('ddpm', 'ddim')
    sample_loop_func = gd.p_sample_loop if sample_method == 'ddpm' else gd.ddim_sample_loop
    if sample_method == 'ddim':
        model = gd._wrap_model(model)

    # generate imgs
    for i in range(num_images):
        all_samples = [orig_img]
        # individual components: condition on one latent slice at a time
        for j in range(num_components):
            model_kwargs['latent_index'] = j
            sample = sample_loop_func(
                model,
                (batch_size, 3, image_size, image_size),
                device=device,
                clip_denoised=True,
                progress=True,
                model_kwargs=model_kwargs,
                cond_fn=None,
            )[:batch_size]
            # save indiv comp
            all_samples.append(sample)

        # reconstruction: condition on all latents at once
        model_kwargs['latent_index'] = None
        sample = sample_loop_func(
            model,
            (batch_size, 3, image_size, image_size),
            device=device,
            clip_denoised=True,
            progress=True,
            model_kwargs=model_kwargs,
            cond_fn=None,
        )[:batch_size]
        # save reconstruction
        all_samples.append(sample)

    samples = th.cat(all_samples, dim=0).cpu()
    grid = make_grid(samples, nrow=samples.shape[0], padding=0)
    return grid


def decompose_image(im):
    """Gradio handler for the (disabled) CLEVR decomposition demo."""
    sample_method = 'ddim'
    result = gen_image_and_components(clevr_model, GD[sample_method], im, sample_method=sample_method, num_images=1, device=device)
    return result.permute(1, 2, 0).numpy()


# load diffusion
GD = {}  # diffusion objects for ddim and ddpm
diffusion_kwargs = diffusion_defaults()
gd = create_gaussian_diffusion(**diffusion_kwargs)
GD['ddpm'] = gd

# set up ddim sampling with 50 evenly spaced timesteps
desired_timesteps = 50
num_timesteps = diffusion_kwargs['steps']
spacing = num_timesteps // desired_timesteps
spaced_ts = list(range(0, num_timesteps + 1, spacing))
betas = get_named_beta_schedule(diffusion_kwargs['noise_schedule'], num_timesteps)
diffusion_kwargs['betas'] = betas
del diffusion_kwargs['steps'], diffusion_kwargs['noise_schedule']
gd = SpacedDiffusion(spaced_ts, rescale_timesteps=True, original_num_steps=num_timesteps, **diffusion_kwargs)
GD['ddim'] = gd

# CLEVR decomposition demo (disabled; the CelebA-HQ recombination demo below is active)
# ckpt_path = download_model('clevr')  # 'clevr_model.pt'
# model_kwargs = unet_model_defaults()
# # model parameters
# model_kwargs.update(dict(
#     emb_dim=64,
#     enc_channels=128
# ))
# clevr_model = create_diffusion_model(**model_kwargs)
# clevr_model.eval()

# device = 'cuda' if th.cuda.is_available() else 'cpu'
# clevr_model.to(device)

# print(f'loading from {ckpt_path}')
# checkpoint = th.load(ckpt_path, map_location='cpu')
# clevr_model.load_state_dict(checkpoint)

# img_input = gr.inputs.Image(type="numpy", label="Input")
# img_output = gr.outputs.Image(type="numpy", label="Output")

# gr.Interface(
#     decompose_image,
#     inputs=img_input,
#     outputs=img_output,
#     examples=[
"sample_images/clevr_im_10.png", # "sample_images/clevr_im_25.png", # ], # ).launch() def combine_components_slice(model, gd, im1, im2, indices=None, sample_method='ddim', device='cuda', num_images=4, model_kwargs={}, desc='', save_dir='', dataset='clevr', image_size=64): """Combine by adding components together """ assert sample_method in ('ddpm', 'ddim') im1 = get_pil_im(im1, resolution=image_size).to(device) im2 = get_pil_im(im2, resolution=image_size).to(device) latent1 = model.encode_latent(im1) latent2 = model.encode_latent(im2) num_comps = model.num_components # get latent slices if indices == None: half = num_comps // 2 indices = [1] * half + [0] * half # first half 1, second half 0 indices = th.Tensor(indices) == 1 indices = indices.reshape(num_comps, 1) elif type(indices) == str: indices = indices.split(',') indices = [int(ind) for ind in indices] indices = th.Tensor(indices).reshape(-1, 1) == 1 assert len(indices) == num_comps indices = indices.to(device) latent1 = latent1.reshape(num_comps, -1).to(device) latent2 = latent2.reshape(num_comps, -1).to(device) combined_latent = th.where(indices, latent1, latent2) combined_latent = combined_latent.reshape(1, -1) model_kwargs['latent'] = combined_latent sample_loop_func = gd.p_sample_loop if sample_method == 'ddpm' else gd.ddim_sample_loop if sample_method == 'ddim': model = gd._wrap_model(model) # sampling loop sample = sample_loop_func( model, (1, 3, image_size, image_size), device=device, clip_denoised=True, progress=True, model_kwargs=model_kwargs, cond_fn=None, )[:1] return sample[0].cpu() def combine_images(im1, im2): sample_method = 'ddim' result = combine_components_slice(clevr_model, GD[sample_method], im1, im2, indices='1,0,1,0', sample_method=sample_method, num_images=1) return result.permute(1, 2, 0).numpy() ckpt_path = download_model('celebahq') # 'celeb_model.pt' model_kwargs = unet_model_defaults() # model parameters model_kwargs.update(dict( enc_channels=128 )) celeb_model = create_diffusion_model(**model_kwargs) celeb_model.eval() device = 'cuda' if th.cuda.is_available() else 'cpu' celeb_model.to(device) print(f'loading from {ckpt_path}') checkpoint = th.load(ckpt_path, map_location='cpu') celeb_model.load_state_dict(checkpoint) # Recombination img_input = gr.inputs.Image(type="numpy", label="Input") img_input2 = gr.inputs.Image(type="numpy", label="Input") img_output = gr.outputs.Image(type="numpy", label="Output") gr.Interface( combine_images, inputs=[img_input, img_input2], outputs=img_output, examples=[ ["sample_images/celebahq_im_15.jpg", "sample_images/celebahq_im_21.jpg"] ] ).launch()