import io
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoTokenizer

from utils import denormalize_image


def save_attention_visualization(
    epoch, model, tokenizer, batch, device, set_name, output_dir, show_inline=False
):
    print(f"Epoch {epoch}: Generating attention visualization for {set_name} set...")

    attention_data = generate_attention_data(model, tokenizer, batch, device)

    if attention_data:
        plot_attention_visualization(
            epoch=epoch,
            set_name=set_name,
            output_dir=output_dir,
            show_inline=show_inline,
            tokenizer=tokenizer,
            **attention_data,
        )
        print(f"Epoch {epoch}: Attention visualization saved for Pokémon #{attention_data['pokemon_id']}.")
    else:
        print(f"Epoch {epoch}: Skipped attention visualization due to missing data.")


def generate_attention_data(model, tokenizer, batch, device):
    """
    Runs the model on the first sample of the batch to generate the image and
    attention maps, returning a dict whose keys match the keyword arguments of
    plot_attention_visualization.
    """
    model.eval()
    with torch.no_grad():
        token_ids = batch["text"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # Ensure batch size is 1 for visualization
        if token_ids.dim() > 1:
            token_ids = token_ids[0].unsqueeze(0)
            attention_mask = attention_mask[0].unsqueeze(0)

        # Get the first sample from the batch
        pokemon_id = batch["idx"][0]
        description = batch["description"][0]

        generated_image, attention_maps, initial_context_weights = model(
            token_ids, attention_mask, return_attentions=True
        )

        decoder_attention_maps = [m for m in attention_maps if m is not None]
        if not decoder_attention_maps or initial_context_weights is None:
            print("Attention maps not available. Skipping data generation.")
            return None

        # Check that at least one non-padding, non-separator token remains to display
        tokens_all = tokenizer.convert_ids_to_tokens(token_ids.squeeze(0))
        display_tokens = [
            token
            for i, token in enumerate(tokens_all)
            if token not in [tokenizer.sep_token, tokenizer.pad_token]
            and attention_mask[0, i] == 1
        ]
        if not display_tokens:
            print(f"No valid tokens to display for '{description}'. Skipping.")
            return None

        # Keys mirror the signature of plot_attention_visualization so the dict
        # can be unpacked directly with **attention_data.
        return {
            "generated_images": generated_image.cpu(),
            "decoder_attention_maps": [m.cpu() for m in decoder_attention_maps],
            "initial_context_weights": initial_context_weights.cpu(),
            "token_ids": token_ids.cpu(),
            "attention_mask": attention_mask.cpu(),
            "description": description,
            "pokemon_id": pokemon_id,
        }
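

# --- Illustrative sketch (not part of the training code) ---------------------
# The two helpers above assume a batch dict with "text", "attention_mask",
# "idx" and "description" keys, and a model whose forward returns
# (generated_image, attention_maps, initial_context_weights) when called with
# return_attentions=True. The stub model, the "bert-base-uncased" checkpoint
# and the image normalization used below are assumptions made purely so the
# example is self-contained; substitute the project's real generator and
# tokenizer.
def _demo_generate_attention_data():
    class _StubTextToImageModel(torch.nn.Module):
        """Minimal stand-in mimicking the generator's return signature."""

        def forward(self, token_ids, attention_mask, return_attentions=False):
            batch_size, seq_len = token_ids.shape
            image = torch.rand(batch_size, 3, 64, 64) * 2 - 1  # assumed [-1, 1] range
            # One cross-attention map per "decoder layer": (B, num_patches, seq_len)
            attention_maps = [torch.softmax(torch.rand(batch_size, 64, seq_len), dim=-1)]
            context_weights = torch.softmax(torch.rand(batch_size, 1, seq_len), dim=-1)
            if return_attentions:
                return image, attention_maps, context_weights
            return image

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
    description = "A small blue turtle Pokémon with a hard shell."
    encoding = tokenizer(
        [description], padding="max_length", truncation=True,
        max_length=16, return_tensors="pt",
    )
    batch = {
        "text": encoding["input_ids"],
        "attention_mask": encoding["attention_mask"],
        "idx": [7],
        "description": [description],
    }
    return generate_attention_data(_StubTextToImageModel(), tokenizer, batch, device="cpu")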


def plot_attention_visualization(
    # Plot identification arguments
    epoch: int,
    set_name: str,
    output_dir: str | None,
    # Data generated by the model (can be full batches)
    generated_images: torch.Tensor,
    decoder_attention_maps: list[torch.Tensor],
    initial_context_weights: torch.Tensor,
    # Original text input (can be a full batch)
    token_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    tokenizer: AutoTokenizer,
    # Batch metadata (for the specific sample)
    description: str,
    pokemon_id: int | str,
    # Control options
    sample_idx: int = 0,
    show_inline: bool = False,
):
    """
    Generates and saves an attention visualization for a single sample from a batch.

    This function is self-contained: it accepts full batch tensors and internally
    handles sample selection and token preparation.

    Args:
        epoch (int): Epoch number (for title/filename).
        set_name (str): Set name (e.g., 'train', for title/filename).
        output_dir (str, optional): Folder to save the image. If None, the plot is not saved.
        generated_images (torch.Tensor): Tensor of generated images. Shape: (B, C, H, W).
        decoder_attention_maps (list[torch.Tensor]): List of attention tensors.
            Each tensor shape: (B, num_patches, seq_length).
        initial_context_weights (torch.Tensor): Initial attention weights. Shape: (B, 1, seq_length).
        token_ids (torch.Tensor): Input token IDs. Shape: (B, seq_length).
        attention_mask (torch.Tensor): Attention mask for tokens. Shape: (B, seq_length).
        tokenizer: The tokenizer object for id -> token conversion.
        description (str): The text prompt for the selected sample.
        pokemon_id (int or str): The ID of the selected sample.
        sample_idx (int, optional): Index of the sample in the batch to visualize. Defaults to 0.
        show_inline (bool, optional): If True, shows the plot. Defaults to False.

    Returns:
        PIL.Image.Image: The rendered figure, e.g. for logging.
    """
    # Select the specific sample using sample_idx and move to CPU
    img_tensor = generated_images[sample_idx].cpu()
    layer_maps = [m[sample_idx].cpu() for m in decoder_attention_maps if m is not None]
    initial_weights = initial_context_weights[sample_idx].cpu()
    token_ids_sample = token_ids[sample_idx].cpu()
    attention_mask_sample = attention_mask[sample_idx].cpu()

    # Token filtering logic
    tokens_all = tokenizer.convert_ids_to_tokens(token_ids_sample)
    display_tokens = []
    for i, token in enumerate(tokens_all):
        if (
            token not in [tokenizer.sep_token, tokenizer.pad_token]
            and attention_mask_sample[i] == 1
        ):
            display_tokens.append({"token": token, "index": i})

    img_tensor_cpu = denormalize_image(img_tensor).permute(1, 2, 0)
    num_decoder_layers = len(layer_maps)
    num_tokens = len(display_tokens)
    token_indices_to_display = [t["index"] for t in display_tokens]

    cols = min(num_tokens, 8)
    rows_per_layer = (num_tokens + cols - 1) // cols
    height_ratios = [3, 2] + [2 * rows_per_layer] * num_decoder_layers
    fig_height = sum(height_ratios)
    fig_width = max(20, 2.5 * cols)

    fig = plt.figure(figsize=(fig_width, fig_height))
    gs_main = fig.add_gridspec(len(height_ratios), 1, height_ratios=height_ratios, hspace=1.2)
    fig.suptitle(
        f"Epoch {epoch}: Attention for Pokémon #{pokemon_id} ({set_name.capitalize()})",
        fontsize=24,
    )

    ax_main_img = fig.add_subplot(gs_main[0])
    ax_main_img.imshow(img_tensor_cpu)
    ax_main_img.set_title("Generated Image", fontsize=18)
    ax_main_img.text(
        0.5, -0.1, f"Prompt: {description}",
        ha="center", va="top", transform=ax_main_img.transAxes, fontsize=14, wrap=True,
    )
    ax_main_img.axis("off")

    ax_initial_attn = fig.add_subplot(gs_main[1])
    initial_weights_squeezed = initial_weights.squeeze().numpy()
    token_strings = [t["token"] for t in display_tokens]
    relevant_weights = initial_weights_squeezed[token_indices_to_display]
    ax_initial_attn.bar(np.arange(len(token_strings)), relevant_weights, color="skyblue")
    ax_initial_attn.set_xticks(np.arange(len(token_strings)))
    ax_initial_attn.set_xticklabels(token_strings, rotation=45, ha="right", fontsize=10)
    ax_initial_attn.set_title("Initial Context Attention (Global)", fontsize=16)
    ax_initial_attn.set_ylabel("Weight", fontsize=12)
    ax_initial_attn.grid(axis="y", linestyle="--", alpha=0.7)

    # Iterate through each decoder layer's attention maps
    for i, layer_attn_map in enumerate(layer_maps):
        # layer_attn_map shape is now (num_patches, seq_len)
        map_size_flat = layer_attn_map.shape[0]
        map_side = int(np.sqrt(map_size_flat))
        layer_title = f"Decoder Cross-Attention Layer {i + 1} (Size: {map_side}x{map_side})"

        # Extract attention weights only for tokens we want to display
        relevant_attn_maps = layer_attn_map[:, token_indices_to_display]
        vmin, vmax = relevant_attn_maps.min(), relevant_attn_maps.max()

        # Create subplot grid for this layer
        gs_layer = gs_main[2 + i].subgridspec(
            rows_per_layer, cols + 1, wspace=0.2, hspace=0.4,
            width_ratios=[*([1] * cols), 0.1],
        )
        axes_in_layer = [
            fig.add_subplot(gs_layer[r, c])
            for r in range(rows_per_layer)
            for c in range(cols)
        ]

        # Add layer title above the token attention maps
        if axes_in_layer:
            y_pos = axes_in_layer[0].get_position().y1
            fig.text(
                0.5, y_pos + 0.01, layer_title,
                ha="center", va="bottom", fontsize=16, weight="bold",
            )

        # Plot attention heatmap for each token
        im = None
        for j, token_info in enumerate(display_tokens):
            if j >= len(axes_in_layer):
                break
            ax = axes_in_layer[j]
            attn_for_token = layer_attn_map[:, token_info["index"]]
            # Reshape flat attention to spatial grid
            heatmap = attn_for_token.reshape(map_side, map_side)
            im = ax.imshow(heatmap, cmap="jet", interpolation="nearest", vmin=vmin, vmax=vmax)
            ax.set_title(f"'{token_info['token']}'", fontsize=12)
            ax.axis("off")

        # Add colorbar for the layer
        if im:
            cax = fig.add_subplot(gs_layer[:, -1])
            cbar = fig.colorbar(im, cax=cax)
            cbar.ax.tick_params(labelsize=10)
            cbar.set_label("Attention Weight", rotation=270, labelpad=15, fontsize=12)

        # Hide unused subplots
        for j in range(num_tokens, len(axes_in_layer)):
            axes_in_layer[j].axis("off")

    plt.tight_layout(rect=(0, 0.03, 1, 0.96))

    if output_dir is not None:
        save_path = os.path.join(
            output_dir, f"{epoch:03d}_{set_name}_attention_visualization_{pokemon_id}.png"
        )
        plt.savefig(save_path, bbox_inches="tight")

    # Save figure to bytes for potential further use (e.g., logging)
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
    buf.seek(0)

    # Convert to PIL image
    attention_plot = Image.open(buf)

    if show_inline:
        plt.show()
    plt.close(fig)

    return attention_plot
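

# --- Illustrative sketch (not part of the training code) ---------------------
# Exercises plot_attention_visualization with random tensors to make the
# expected shapes concrete: (B, C, H, W) images, a list of (B, num_patches,
# seq_length) cross-attention maps with num_patches a perfect square, and
# (B, 1, seq_length) context weights. The tokenizer checkpoint, the 8x8 patch
# grid and the [-1, 1] image range are assumptions for a self-contained
# example; real inputs come from the generator.
def _demo_plot_attention_visualization():
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
    description = "A fiery lizard Pokémon with a flame on its tail."
    encoding = tokenizer(
        [description], padding="max_length", truncation=True,
        max_length=16, return_tensors="pt",
    )
    batch_size, seq_len, num_patches = 1, 16, 64  # 64 patches -> 8x8 spatial grid
    return plot_attention_visualization(
        epoch=0,
        set_name="train",
        output_dir=None,  # do not write to disk in this sketch
        generated_images=torch.rand(batch_size, 3, 64, 64) * 2 - 1,
        decoder_attention_maps=[
            torch.softmax(torch.rand(batch_size, num_patches, seq_len), dim=-1)
            for _ in range(2)
        ],
        initial_context_weights=torch.softmax(torch.rand(batch_size, 1, seq_len), dim=-1),
        token_ids=encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        tokenizer=tokenizer,
        description=description,
        pokemon_id=4,
        show_inline=False,
    )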


def save_plot_losses(losses_g, losses_d, output_dir="training_output", show_inline=True):
    """
    Generates and saves a plot of the generator and discriminator losses.
    """
    os.makedirs(output_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(losses_g, label="Generator Loss", color="blue")
    ax.plot(losses_d, label="Discriminator Loss", color="red")
    ax.set_title("Training Losses")
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")
    ax.legend()
    ax.grid(True)

    save_path = os.path.join(output_dir, "training_losses.png")
    plt.savefig(save_path)
    print(f"Loss plot saved to: {save_path}")

    if show_inline:
        plt.show()
    else:
        plt.close(fig)
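

# --- Illustrative sketch (not part of the training code) ---------------------
# save_plot_losses only needs two per-epoch loss lists; the numbers and the
# output directory below are made up to show the call.
def _demo_save_plot_losses(output_dir="demo_output"):
    losses_g = [2.1, 1.7, 1.4, 1.2, 1.1]   # dummy generator losses per epoch
    losses_d = [0.9, 0.8, 0.75, 0.7, 0.68]  # dummy discriminator losses per epoch
    save_plot_losses(losses_g, losses_d, output_dir=output_dir, show_inline=False)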
""" os.makedirs(output_dir, exist_ok=True) # Extract all unique loss keys from both training and validation all_keys = set() for losses_dict in train_losses_history + val_losses_history: all_keys.update(losses_dict.keys()) # Filter out non-numeric keys if any loss_keys = [key for key in all_keys if key not in ['epoch']] # Apply filter if specified if filter_losses is not None: loss_keys = [key for key in loss_keys if key in filter_losses] loss_keys = sorted(loss_keys) # Sort for consistent ordering # Create subplots n_losses = len(loss_keys) cols = min(3, n_losses) # Max 3 columns rows = (n_losses + cols - 1) // cols # Ceiling division fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4 * rows)) if n_losses == 1: axes = [axes] elif rows > 1: axes = axes.flatten() fig.suptitle("Training and Validation Losses", fontsize=16, y=0.98) for i, loss_key in enumerate(loss_keys): ax = axes[i] # Extract train and validation losses for this key train_values = [losses.get(loss_key, 0) for losses in train_losses_history] val_values = [losses.get(loss_key, 0) for losses in val_losses_history] epochs_train = range(1, len(train_values) + 1) epochs_val = range(1, len(val_values) + 1) # Plot training and validation curves if train_values: ax.plot(epochs_train, train_values, label=f"Train {loss_key}", color="blue", linewidth=1.5) if val_values: ax.plot(epochs_val, val_values, label=f"Val {loss_key}", color="red", linewidth=1.5, linestyle='--') ax.set_title(f"{loss_key.capitalize()} Loss", fontsize=12) ax.set_xlabel("Epoch") ax.set_ylabel("Loss") ax.legend() ax.grid(True, alpha=0.3) ax.set_ylim(bottom=0) # Hide unused subplots for i in range(n_losses, len(axes)): axes[i].set_visible(False) plt.tight_layout() # Save the plot save_path = os.path.join(output_dir, "non_gan_training_losses.png") plt.savefig(save_path, dpi=150, bbox_inches='tight') print(f"Non-GAN training losses plot saved to: {save_path}") if show_inline: plt.show() else: plt.close(fig) def save_comparison_grid(epoch, model, batch, set_name, device, output_dir="training_output", show_inline=True): """ Generates and saves/shows a horizontal comparison grid (real vs. generated). Automatically handles 256x256 or 64x64 output based on set_name. 
""" os.makedirs(output_dir, exist_ok=True) model.eval() token_ids = batch["text"].to(device) attention_mask = batch["attention_mask"].to(device) real_images = batch["image"] pokemon_ids = batch["idx"] descriptions = batch["description"] num_images = real_images.size(0) with torch.no_grad(): generated_images = model(token_ids, attention_mask) # Handle tuple output from generator (e.g., 256px and 64px images) if isinstance(generated_images, tuple): # Check if we want 64x64 or 256x256 based on set_name if "64" in set_name: generated_images = generated_images[1] # Use 64x64 output # Resize real images to 64x64 for comparison real_images = F.interpolate(real_images, size=(64, 64), mode='bilinear', align_corners=False) else: generated_images = generated_images[0] # Use 256x256 output fig, axs = plt.subplots(2, num_images, figsize=(4 * num_images, 8.5)) resolution = "64x64" if "64" in set_name else "256x256" fig.suptitle( f"Epoch {epoch} - {set_name.capitalize()} Comparison ({resolution})", fontsize=16, y=0.98 ) for i in range(num_images): ax_real = axs[0, i] ax_real.imshow(denormalize_image(real_images[i].cpu()).permute(1, 2, 0)) ax_real.set_title(f"#{pokemon_ids[i]}: {descriptions[i][:35]}...", fontsize=10) ax_real.axis("off") ax_gen = axs[1, i] ax_gen.imshow(denormalize_image(generated_images[i].cpu()).permute(1, 2, 0)) ax_gen.axis("off") axs[0, 0].text( -0.1, 0.5, "Real", ha="center", va="center", rotation="vertical", fontsize=14, transform=axs[0, 0].transAxes, ) axs[1, 0].text( -0.1, 0.5, "Generated", ha="center", va="center", rotation="vertical", fontsize=14, transform=axs[1, 0].transAxes, ) plt.tight_layout(rect=(0, 0, 1, 0.95)) # Save the figure and optionally show it save_path = os.path.join(output_dir, f"{epoch:03d}_{set_name}_comparison.png") plt.savefig(save_path) if show_inline: plt.show() else: plt.close(fig)