import os
import time
import logging
from pathlib import Path

import torch
from torchvision.io import read_image
import torchvision.transforms.v2 as transforms
from torchvision.utils import make_grid
import gradio as gr
from diffusers import AutoencoderKL, EulerDiscreteScheduler
from transformers import SiglipImageProcessor, SiglipVisionModel
from huggingface_hub import hf_hub_download
import spaces

from esrgan_model import UpscalerESRGAN
from model import create_model

device = "cuda"

# Custom timer logger only
timer_logger = logging.getLogger("TIMER")
timer_logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
# Attach a stream handler with formatter
handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s', datefmt="%Y-%m-%d %H:%M:%S"))
timer_logger.addHandler(handler)
timer_logger.propagate = False  # Avoid duplicate logs


# Custom transform to pad images to square
class PadToSquare:
    def __call__(self, img):
        _, h, w = img.shape
        max_side = max(h, w)
        pad_h = (max_side - h) // 2
        pad_w = (max_side - w) // 2
        padding = (pad_w, pad_h, max_side - w - pad_w, max_side - h - pad_h)
        return transforms.functional.pad(img, padding, padding_mode="edge")


# Timer decorator
def timer_func(func):
    def wrapper(*args, **kwargs):
        t0 = time.time()
        result = func(*args, **kwargs)
        timer_logger.info(f"{func.__name__} took {time.time() - t0:.2f} seconds")
        return result
    return wrapper


@timer_func
def load_model(model_class_name, model_filename, repo_id: str = "rizavelioglu/tryoffdiff"):
    path_model = hf_hub_download(repo_id=repo_id, filename=model_filename, force_download=False)
    state_dict = torch.load(path_model, weights_only=True, map_location=device)
    state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
    model = create_model(model_class_name).to(device)
    # model = torch.compile(model)
    model.load_state_dict(state_dict, strict=True)
    return model.eval()


def validate_garment_selection(garment_types):
    """Validate garment type selection and return selected types and label indices."""
    label_map = {"Upper-Body": 0, "Lower-Body": 1, "Dress": 2}
    valid_single = ["Upper-Body", "Lower-Body", "Dress"]
    valid_tuple = ["Upper-Body", "Lower-Body"]

    if not garment_types:
        raise gr.Error("Please select at least one garment type.")
    if len(garment_types) == 1 and garment_types[0] in valid_single:
        selected, label_indices = garment_types, [label_map[garment_types[0]]]
    elif sorted(garment_types) == sorted(valid_tuple):
        selected, label_indices = valid_tuple, [label_map[t] for t in valid_tuple]
    else:
        raise gr.Error("Invalid selection. Choose one garment type or Upper-Body and Lower-Body together.")
    return selected, label_indices


def generate_multi_image_wrapper(input_image, garment_types, seed=42, guidance_scale=2.0, num_inference_steps=50, is_upscale=False):
    """Wrapper function that validates input before calling the GPU function."""
    # Validate selection before entering GPU context
    selected, label_indices = validate_garment_selection(garment_types)
    return generate_multi_image(input_image, selected, label_indices, seed, guidance_scale, num_inference_steps, is_upscale)


@spaces.GPU(duration=10)
@torch.no_grad()
@timer_func
def generate_multi_image(input_image, selected, label_indices, seed=42, guidance_scale=2.0, num_inference_steps=50, is_upscale=False):
    batch_size = len(selected)
    scheduler.set_timesteps(num_inference_steps)
    generator = torch.Generator(device=device).manual_seed(seed)
    x = torch.randn(batch_size, 4, 64, 64, generator=generator, device=device)

    # Process inputs
    cond_image = img_enc_transform(read_image(input_image))
    inputs = {k: v.to(device) for k, v in img_processor(images=cond_image, return_tensors="pt").items()}
    cond_emb = img_enc(**inputs).last_hidden_state.to(device)
    cond_emb = cond_emb.expand(batch_size, *cond_emb.shape[1:])
    uncond_emb = torch.zeros_like(cond_emb) if guidance_scale > 1 else None
    label = torch.tensor(label_indices, device=device, dtype=torch.int64)
    model = models["multi"]

    with torch.autocast(device):
        for t in scheduler.timesteps:
            t = t.to(device)  # Ensure t is on the correct device
            if guidance_scale > 1:
                noise_pred = model(torch.cat([x] * 2), t, torch.cat([uncond_emb, cond_emb]), torch.cat([label, label])).chunk(2)
                noise_pred = noise_pred[0] + guidance_scale * (noise_pred[1] - noise_pred[0])  # Classifier-free guidance
            else:
                noise_pred = model(x, t, cond_emb, label)  # Standard prediction

            # Scheduler step
            scheduler_output = scheduler.step(noise_pred, t, x)
            x = scheduler_output.prev_sample

    # Decode predictions from latent space
    decoded = vae.decode(1 / vae.config.scaling_factor * scheduler_output.pred_original_sample).sample
    images = (decoded / 2 + 0.5).cpu()
    grid = make_grid(images, nrow=len(images), normalize=True, scale_each=True)
    output_image = transforms.ToPILImage()(grid)
    return upscaler(output_image) if is_upscale else output_image  # Optionally upscale the output image


@spaces.GPU(duration=10)
@torch.no_grad()
@timer_func
def generate_upper_image(input_image, seed=42, guidance_scale=2.0, num_inference_steps=50, is_upscale=False):
    model = models["upper"]
    scheduler.set_timesteps(num_inference_steps)
    scheduler.timesteps = scheduler.timesteps.to(device)
    generator = torch.Generator(device=device).manual_seed(seed)
    x = torch.randn(1, 4, 64, 64, generator=generator, device=device)

    # Process input image
    cond_image = img_enc_transform(read_image(input_image))
    inputs = {k: v.to(device) for k, v in img_processor(images=cond_image, return_tensors="pt").items()}
    cond_emb = img_enc(**inputs).last_hidden_state.to(device)
    uncond_emb = torch.zeros_like(cond_emb) if guidance_scale > 1 else None

    with torch.autocast(device):
        for t in scheduler.timesteps:
            t = t.to(device)  # Ensure t is on the correct device
            if guidance_scale > 1:  # Classifier-free guidance
                noise_pred = model(torch.cat([x] * 2), t, torch.cat([uncond_emb, cond_emb])).chunk(2)
                noise_pred = noise_pred[0] + guidance_scale * (noise_pred[1] - noise_pred[0])
            else:  # Standard prediction
                noise_pred = model(x, t, cond_emb)

            # Scheduler step
            scheduler_output = scheduler.step(noise_pred, t, x)
            x = scheduler_output.prev_sample

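    # Note (added comment): pred_original_sample is the scheduler's running estimate of the
    # fully denoised latent; it is rescaled by 1/vae.config.scaling_factor before VAE decoding.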
    # Decode predictions from latent space
    decoded = vae.decode(1 / vae.config.scaling_factor * scheduler_output.pred_original_sample).sample
    images = (decoded / 2 + 0.5).cpu()
    grid = make_grid(images, nrow=len(images), normalize=True, scale_each=True)
    output_image = transforms.ToPILImage()(grid)
    return upscaler(output_image) if is_upscale else output_image  # Optionally upscale the output image


@spaces.GPU(duration=10)
@torch.no_grad()
@timer_func
def generate_lower_image(input_image, seed=42, guidance_scale=2.0, num_inference_steps=50, is_upscale=False):
    model = models["lower"]
    scheduler.set_timesteps(num_inference_steps)
    scheduler.timesteps = scheduler.timesteps.to(device)
    generator = torch.Generator(device=device).manual_seed(seed)
    x = torch.randn(1, 4, 64, 64, generator=generator, device=device)

    # Process input image
    cond_image = img_enc_transform(read_image(input_image))
    inputs = {k: v.to(device) for k, v in img_processor(images=cond_image, return_tensors="pt").items()}
    cond_emb = img_enc(**inputs).last_hidden_state.to(device)
    uncond_emb = torch.zeros_like(cond_emb) if guidance_scale > 1 else None

    with torch.autocast(device):
        for t in scheduler.timesteps:
            t = t.to(device)  # Ensure t is on the correct device
            if guidance_scale > 1:  # Classifier-free guidance
                noise_pred = model(torch.cat([x] * 2), t, torch.cat([uncond_emb, cond_emb])).chunk(2)
                noise_pred = noise_pred[0] + guidance_scale * (noise_pred[1] - noise_pred[0])
            else:  # Standard prediction
                noise_pred = model(x, t, cond_emb)

            # Scheduler step
            scheduler_output = scheduler.step(noise_pred, t, x)
            x = scheduler_output.prev_sample

    # Decode predictions from latent space
    decoded = vae.decode(1 / vae.config.scaling_factor * scheduler_output.pred_original_sample).sample
    images = (decoded / 2 + 0.5).cpu()
    grid = make_grid(images, nrow=len(images), normalize=True, scale_each=True)
    output_image = transforms.ToPILImage()(grid)
    return upscaler(output_image) if is_upscale else output_image  # Optionally upscale the output image


@spaces.GPU(duration=10)
@torch.no_grad()
@timer_func
def generate_dress_image(input_image, seed=42, guidance_scale=2.0, num_inference_steps=50, is_upscale=False):
    model = models["dress"]
    scheduler.set_timesteps(num_inference_steps)
    scheduler.timesteps = scheduler.timesteps.to(device)
    generator = torch.Generator(device=device).manual_seed(seed)
    x = torch.randn(1, 4, 64, 64, generator=generator, device=device)

    # Process input image
    cond_image = img_enc_transform(read_image(input_image))
    inputs = {k: v.to(device) for k, v in img_processor(images=cond_image, return_tensors="pt").items()}
    cond_emb = img_enc(**inputs).last_hidden_state.to(device)
    uncond_emb = torch.zeros_like(cond_emb) if guidance_scale > 1 else None

    with torch.autocast(device):
        for t in scheduler.timesteps:
            t = t.to(device)  # Ensure t is on the correct device
            if guidance_scale > 1:  # Classifier-free guidance
                noise_pred = model(torch.cat([x] * 2), t, torch.cat([uncond_emb, cond_emb])).chunk(2)
                noise_pred = noise_pred[0] + guidance_scale * (noise_pred[1] - noise_pred[0])
            else:  # Standard prediction
                noise_pred = model(x, t, cond_emb)

            # Scheduler step
            scheduler_output = scheduler.step(noise_pred, t, x)
            x = scheduler_output.prev_sample

    # Decode predictions from latent space
    decoded = vae.decode(1 / vae.config.scaling_factor * scheduler_output.pred_original_sample).sample
    images = (decoded / 2 + 0.5).cpu()
    grid = make_grid(images, nrow=len(images), normalize=True, scale_each=True)
    output_image = transforms.ToPILImage()(grid)
    return upscaler(output_image) if is_upscale else output_image  # Optionally upscale the output image
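
# Illustrative usage (added comment, assumption rather than part of the app wiring): once the
# module-level scheduler, vae, img_processor, img_enc, and models dict are initialized, a
# generator can be called directly with a path to a person image; "person.jpg" is a placeholder.
#   garment = generate_upper_image("person.jpg", seed=0, guidance_scale=2.0,
#                                  num_inference_steps=20, is_upscale=False)
#   garment.save("tryoff_upper.png")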


def create_multi_tab():
    description = r"""
In total, 4 models are available for generating garments (one in each tab):
- Multi-Garment: Generate multiple garments (e.g., upper-body and lower-body) sequentially.
- Upper-Body: Generate upper-body garments (e.g., tops, jackets, etc.).
- Lower-Body: Generate lower-body garments (e.g., pants, skirts, etc.).
- Dress: Generate dresses.
How to use:
1. Upload a reference image,
2. Adjust the parameters as needed,
3. Click "Generate" to create the garment(s).

💡 Individual models perform slightly better than the multi-garment model, but the latter is more versatile.