import gradio as gr import gradio.themes import torch import numpy as np from PIL import Image from transformers import AutoTokenizer from model import Generator as PikaPikaGen from utils import denormalize_image from plots import plot_attention_visualization import os DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') CHECKPOINT_PATH = "pikapikagen/model_checkpoint/checkpoint_epoch_150.pth" TOKENIZER_NAME = "prajjwal1/bert-mini" class PokemonGenerator: """Main class for the Pokemon generation demo""" def __init__(self): self.device = DEVICE self.tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME) self._load_model() def _load_model(self): """Load the trained PikaPikaGen model""" try: # Initialize model self.generator = PikaPikaGen().to(self.device) # Load checkpoint checkpoint = torch.load(CHECKPOINT_PATH, map_location=self.device, weights_only=True) # Load saved weights into model self.generator.load_state_dict(checkpoint['generator_state_dict']) print(f"✅ Generator loaded from checkpoint (epoch {checkpoint.get('epoch', 'unknown')})") # No training self.generator.eval() except Exception as e: print(f"❌ Error loading model: {e}") raise def _tensor_to_pil(self, tensor): """Convert tensor to PIL Image""" # tensor shape: (3, H, W) img_np = tensor.permute(1, 2, 0).clamp(0, 1).numpy() img_np = (img_np * 255).astype(np.uint8) return Image.fromarray(img_np) def generate_pokemon(self, description, num_samples=4, show_attention=False, resolution="both"): """ Generate Pokemon sprites from text description Args: description (str): Text description of the desired Pokemon num_samples (int): Number of samples to generate (1-8) show_attention (bool): Whether to show attention visualization resolution (str): Output resolution - "256x256", "64x64", or "both" Returns: tuple: (generated_images, attention_plot) """ if not description.strip(): return [], "❌ Please enter a description." # No reason to compute gradients with torch.no_grad(): tokens = self.tokenizer( description, max_length=128, padding='max_length', truncation=True, return_tensors='pt' ) text_ids = tokens['input_ids'].repeat(num_samples, 1).to(self.device) attention_mask = tokens['attention_mask'].repeat(num_samples, 1).to(self.device) generated_256, generated_64, attention_maps, initial_weights = self.generator( text_ids, attention_mask, return_attentions=True ) # Convert tensors to PIL images output_images = [] images_to_process = [] if resolution in ["256x256", "both"]: images_to_process.append(generated_256) if resolution in ["64x64", "both"]: images_to_process.append(generated_64) for img_batch in images_to_process: img_batch_denorm = denormalize_image(img_batch.cpu()) for i in range(num_samples): img_pil = self._tensor_to_pil(img_batch_denorm[i]) output_images.append(img_pil) attention_plot = None if show_attention: # Create directory if it doesn't exist output_dir = "attention_visualizations" os.makedirs(output_dir, exist_ok=True) # Create a more descriptive ID for the file # To avoid overwriting the same file with the same name plot_id = description.strip().replace(" ", "_")[:30] # Use the first sample for the attention visualization attention_plot = plot_attention_visualization( epoch=0, set_name="demo", output_dir=output_dir, generated_images=generated_256, # Full batch data from the model decoder_attention_maps=attention_maps, initial_context_weights=initial_weights, token_ids=text_ids, attention_mask=attention_mask, tokenizer=self.tokenizer, # Metadata for the specific sample description=description, pokemon_id=plot_id, sample_idx=0, show_inline=False ) return output_images, attention_plot print("Initializing PikaPikaGen Demo...") pokemon_gen = PokemonGenerator() def generate_pokemon_interface(description, num_samples, show_attention, resolution): images, attention_plot = pokemon_gen.generate_pokemon( description=description, num_samples=num_samples, show_attention=show_attention, resolution=resolution ) if images is None: return [], attention_plot # attention_plot contains error message if error status_msg = f"Generated {len(images)} Pokemon sprites" if resolution == "both": status_msg += f" ({num_samples} at 256x256 + {num_samples} at 64x64)" else: status_msg += f" at {resolution}" return images, attention_plot def create_interface(): with gr.Blocks( title="PikaPikaGen: AI Pokemon Generator", theme=gradio.themes.Soft(), css=""" .main-header { text-align: center; background: linear-gradient(45deg, #ff6b6b, #4ecdc4); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-size: 2.5em; font-weight: bold; margin-bottom: 0.5em; } .description { text-align: center; font-size: 1.1em; color: #666; margin-bottom: 1em; } """ ) as demo: gr.HTML("""