import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import io
from PIL import Image
from utils import denormalize_image
import torch.nn.functional as F
from transformers import AutoTokenizer

def save_attention_visualization(
epoch, model, tokenizer, batch, device, set_name, output_dir, show_inline=False
):
print(f"Epoch {epoch}: Generating attention visualization for {set_name} set...")
attention_data = generate_attention_data(model, tokenizer, batch, device)
if attention_data:
        plot_attention_visualization(
            epoch=epoch,
            set_name=set_name,
            output_dir=output_dir,
            show_inline=show_inline,
            tokenizer=tokenizer,
            **attention_data,
        )
print(f"Epoch {epoch}: Attention visualization saved for Pokémon #{attention_data['pokemon_id']}.")
else:
print(f"Epoch {epoch}: Skipped attention visualization due to missing data.")

def generate_attention_data(model, tokenizer, batch, device):
"""
Runs the model to generate the image and attention maps, filtering the padding tokens.
"""
model.eval()
with torch.no_grad():
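        # Expected batch keys: "text", "attention_mask", "idx", "description"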
token_ids = batch["text"].to(device)
attention_mask = batch["attention_mask"].to(device)
# Ensure batch size is 1 for visualization
if token_ids.dim() > 1:
token_ids = token_ids[0].unsqueeze(0)
attention_mask = attention_mask[0].unsqueeze(0)
# Get the first sample from the batch
pokemon_id = batch["idx"][0]
description = batch["description"][0]
generated_image, attention_maps, initial_context_weights = model(
token_ids, attention_mask, return_attentions=True
)
decoder_attention_maps = [m for m in attention_maps if m is not None]
if not decoder_attention_maps or initial_context_weights is None:
print("Attention maps not available. Skipping data generation.")
return None
        # Check that at least one non-special, non-padding token remains
        tokens_all = tokenizer.convert_ids_to_tokens(token_ids.squeeze(0).tolist())
        has_display_tokens = any(
            token not in (tokenizer.sep_token, tokenizer.pad_token)
            and attention_mask[0, i] == 1
            for i, token in enumerate(tokens_all)
        )
        if not has_display_tokens:
            print(f"No valid tokens to display for '{description}'. Skipping.")
            return None
        # Key names must match the keyword arguments of plot_attention_visualization
        return {
            "generated_images": generated_image.cpu(),
            "decoder_attention_maps": [m.cpu() for m in decoder_attention_maps],
            "initial_context_weights": initial_context_weights.cpu(),
            "token_ids": token_ids.cpu(),
            "attention_mask": attention_mask.cpu(),
            "description": description,
            "pokemon_id": pokemon_id,
        }

def plot_attention_visualization(
# Plot identification arguments
epoch: int,
set_name: str,
output_dir: str | None,
# Data generated by the model (can be full batches)
generated_images: torch.Tensor,
decoder_attention_maps: list[torch.Tensor],
initial_context_weights: torch.Tensor,
# Original text input (can be a full batch)
token_ids: torch.Tensor,
attention_mask: torch.Tensor,
tokenizer: AutoTokenizer,
# Batch metadata (for the specific sample)
description: str,
pokemon_id: int | str,
# Control options
sample_idx: int = 0,
show_inline: bool = False,
):
"""
Generates and saves an attention visualization for a single sample from a batch.
This function is self-contained: it accepts full batch tensors and internally
handles sample selection and token preparation.
Args:
epoch (int): Epoch number (for title/filename).
set_name (str): Set name (e.g., 'train', for title/filename).
output_dir (str, optional): Folder to save the image. If None, the plot is not saved.
generated_images (torch.Tensor): Tensor of generated images.
Shape: (B, C, H, W).
decoder_attention_maps (list[torch.Tensor]): List of attention tensors.
Each tensor shape: (B, num_patches, seq_length).
initial_context_weights (torch.Tensor): Initial attention weights.
Shape: (B, 1, seq_length).
        token_ids (torch.Tensor): Input token IDs.
            Shape: (B, seq_length).
attention_mask (torch.Tensor): Attention mask for tokens.
Shape: (B, seq_length).
tokenizer: The tokenizer object for id -> token conversion.
description (str): The text prompt for the selected sample.
pokemon_id (int or str): The ID of the selected sample.
sample_idx (int, optional): Index of the sample in the batch to visualize.
Defaults to 0.
        show_inline (bool, optional): If True, shows the plot. Defaults to False.

    Returns:
        PIL.Image.Image: The rendered figure, for logging or further use.
    """
# Select the specific sample using sample_idx and move to CPU
img_tensor = generated_images[sample_idx].cpu()
layer_maps = [m[sample_idx].cpu() for m in decoder_attention_maps if m is not None]
initial_weights = initial_context_weights[sample_idx].cpu()
token_ids_sample = token_ids[sample_idx].cpu()
attention_mask_sample = attention_mask[sample_idx].cpu()
# Token filtering logic
tokens_all = tokenizer.convert_ids_to_tokens(token_ids_sample)
display_tokens = []
for i, token in enumerate(tokens_all):
if (
token not in [tokenizer.sep_token, tokenizer.pad_token]
and attention_mask_sample[i] == 1
):
display_tokens.append({"token": token, "index": i})
img_tensor_cpu = denormalize_image(img_tensor).permute(1, 2, 0)
    num_decoder_layers = len(layer_maps)
    num_tokens = len(display_tokens)
    # Guard against an empty token list (a zero-column grid would crash below)
    if num_tokens == 0:
        print(f"No valid tokens to display for '{description}'. Skipping plot.")
        return None
    token_indices_to_display = [t["index"] for t in display_tokens]
    cols = min(num_tokens, 8)
    rows_per_layer = (num_tokens + cols - 1) // cols
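    # Row heights: generated image, initial-attention bars, then one block per decoder layer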
height_ratios = [3, 2] + [2 * rows_per_layer] * num_decoder_layers
fig_height = sum(height_ratios)
fig_width = max(20, 2.5 * cols)
fig = plt.figure(figsize=(fig_width, fig_height))
gs_main = fig.add_gridspec(len(height_ratios), 1, height_ratios=height_ratios, hspace=1.2)
fig.suptitle(f"Epoch {epoch}: Attention for Pokémon #{pokemon_id} ({set_name.capitalize()})", fontsize=24)
ax_main_img = fig.add_subplot(gs_main[0])
ax_main_img.imshow(img_tensor_cpu)
ax_main_img.set_title("Generated Image", fontsize=18)
ax_main_img.text(0.5, -0.1, f"Prompt: {description}", ha="center", va="top",
transform=ax_main_img.transAxes, fontsize=14, wrap=True)
ax_main_img.axis("off")
ax_initial_attn = fig.add_subplot(gs_main[1])
initial_weights_squeezed = initial_weights.squeeze().numpy()
token_strings = [t["token"] for t in display_tokens]
    relevant_weights = initial_weights_squeezed[token_indices_to_display]
ax_initial_attn.bar(np.arange(len(token_strings)), relevant_weights, color="skyblue")
ax_initial_attn.set_xticks(np.arange(len(token_strings)))
ax_initial_attn.set_xticklabels(token_strings, rotation=45, ha="right", fontsize=10)
ax_initial_attn.set_title("Initial Context Attention (Global)", fontsize=16)
ax_initial_attn.set_ylabel("Weight", fontsize=12)
ax_initial_attn.grid(axis="y", linestyle="--", alpha=0.7)
# Iterate through each decoder layer's attention maps
for i, layer_attn_map in enumerate(layer_maps):
# layer_attn_map shape is now (num_patches, seq_len)
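        # Assumes a square patch grid (num_patches == map_side ** 2)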
map_size_flat = layer_attn_map.shape[0]
map_side = int(np.sqrt(map_size_flat))
layer_title = f"Decoder Cross-Attention Layer {i+1} (Size: {map_side}x{map_side})"
# Extract attention weights only for tokens we want to display
relevant_attn_maps = layer_attn_map[:, token_indices_to_display]
vmin, vmax = relevant_attn_maps.min(), relevant_attn_maps.max()
# Create subplot grid for this layer
gs_layer = gs_main[2 + i].subgridspec(rows_per_layer, cols + 1, wspace=0.2, hspace=0.4, width_ratios=[*([1] * cols), 0.1])
axes_in_layer = [fig.add_subplot(gs_layer[r, c]) for r in range(rows_per_layer) for c in range(cols)]
# Add layer title above the token attention maps
if axes_in_layer:
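            # Position is read before tight_layout, so the title may shift slightly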
y_pos = axes_in_layer[0].get_position().y1
fig.text(0.5, y_pos + 0.01, layer_title, ha="center", va="bottom", fontsize=16, weight="bold")
# Plot attention heatmap for each token
im = None
for j, token_info in enumerate(display_tokens):
if j >= len(axes_in_layer):
break
ax = axes_in_layer[j]
attn_for_token = layer_attn_map[:, token_info["index"]]
# Reshape flat attention to spatial grid
heatmap = attn_for_token.reshape(map_side, map_side)
im = ax.imshow(heatmap, cmap="jet", interpolation="nearest", vmin=vmin, vmax=vmax)
ax.set_title(f"'{token_info['token']}'", fontsize=12)
ax.axis("off")
# Add colorbar for the layer
if im:
cax = fig.add_subplot(gs_layer[:, -1])
cbar = fig.colorbar(im, cax=cax)
cbar.ax.tick_params(labelsize=10)
cbar.set_label("Attention Weight", rotation=270, labelpad=15, fontsize=12)
# Hide unused subplots
for j in range(num_tokens, len(axes_in_layer)):
axes_in_layer[j].axis("off")
plt.tight_layout(rect=(0, 0.03, 1, 0.96))
    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
        save_path = os.path.join(output_dir, f"{epoch:03d}_{set_name}_attention_visualization_{pokemon_id}.png")
        plt.savefig(save_path, bbox_inches="tight")
# Save figure to bytes for potential further use (e.g., logging)
buf = io.BytesIO()
fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
buf.seek(0)
    # Convert to a PIL image; load immediately so the buffer can be released
    attention_plot = Image.open(buf)
    attention_plot.load()
if show_inline:
plt.show()
plt.close(fig)
return attention_plot
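
# Direct-call sketch for plot_attention_visualization (a hedged example: the
# tensor shapes and the model's return signature are assumptions taken from the
# docstring above, not verified against the actual model):
#
#     out, attn_maps, ctx_weights = model(token_ids, attention_mask, return_attentions=True)
#     img = plot_attention_visualization(
#         epoch=5, set_name="val", output_dir=None,
#         generated_images=out, decoder_attention_maps=attn_maps,
#         initial_context_weights=ctx_weights, token_ids=token_ids,
#         attention_mask=attention_mask, tokenizer=tokenizer,
#         description="a small blue turtle", pokemon_id=7,
#     )
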
def save_plot_losses(losses_g, losses_d, output_dir="training_output", show_inline=True):
"""
Generates and saves a plot of the generator and discriminator losses.
"""
os.makedirs(output_dir, exist_ok=True)
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(losses_g, label="Generator Loss", color="blue")
ax.plot(losses_d, label="Discriminator Loss", color="red")
ax.set_title("Training Losses")
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.legend()
ax.grid(True)
save_path = os.path.join(output_dir, "training_losses.png")
plt.savefig(save_path)
print(f"Loss plot saved to: {save_path}")
if show_inline:
plt.show()
else:
plt.close(fig)

def save_plot_non_gan_losses(train_losses_history, val_losses_history, output_dir="training_output", show_inline=True, filter_losses=None):
"""
Generates and saves plots of losses for non-GAN models with multiple loss components.
Args:
train_losses_history (list[dict]): List of dicts containing training losses per epoch.
e.g., [{'l1': 0.5, 'sobel': 0.3}, ...]
val_losses_history (list[dict]): List of dicts containing validation losses per epoch.
output_dir (str): Directory to save the plot.
show_inline (bool): Whether to display the plot inline.
filter_losses (list[str], optional): List of loss names to plot.
If None, plots all found losses.
"""
os.makedirs(output_dir, exist_ok=True)
# Extract all unique loss keys from both training and validation
all_keys = set()
for losses_dict in train_losses_history + val_losses_history:
all_keys.update(losses_dict.keys())
# Filter out non-numeric keys if any
loss_keys = [key for key in all_keys if key not in ['epoch']]
# Apply filter if specified
if filter_losses is not None:
loss_keys = [key for key in loss_keys if key in filter_losses]
    loss_keys = sorted(loss_keys)  # Sort for consistent ordering
    if not loss_keys:
        print("No loss keys to plot; skipping.")
        return
    # Create subplots
    n_losses = len(loss_keys)
    cols = min(3, n_losses)  # Max 3 columns
    rows = (n_losses + cols - 1) // cols  # Ceiling division
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4 * rows), squeeze=False)
    axes = axes.flatten()
fig.suptitle("Training and Validation Losses", fontsize=16, y=0.98)
for i, loss_key in enumerate(loss_keys):
ax = axes[i]
# Extract train and validation losses for this key
train_values = [losses.get(loss_key, 0) for losses in train_losses_history]
val_values = [losses.get(loss_key, 0) for losses in val_losses_history]
epochs_train = range(1, len(train_values) + 1)
epochs_val = range(1, len(val_values) + 1)
# Plot training and validation curves
if train_values:
ax.plot(epochs_train, train_values, label=f"Train {loss_key}", color="blue", linewidth=1.5)
if val_values:
ax.plot(epochs_val, val_values, label=f"Val {loss_key}", color="red", linewidth=1.5, linestyle='--')
ax.set_title(f"{loss_key.capitalize()} Loss", fontsize=12)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_ylim(bottom=0)
# Hide unused subplots
for i in range(n_losses, len(axes)):
axes[i].set_visible(False)
plt.tight_layout()
# Save the plot
save_path = os.path.join(output_dir, "non_gan_training_losses.png")
plt.savefig(save_path, dpi=150, bbox_inches='tight')
print(f"Non-GAN training losses plot saved to: {save_path}")
if show_inline:
plt.show()
else:
plt.close(fig)

def save_comparison_grid(epoch, model, batch, set_name, device, output_dir="training_output", show_inline=True):
"""
Generates and saves/shows a horizontal comparison grid (real vs. generated).
Automatically handles 256x256 or 64x64 output based on set_name.
"""
os.makedirs(output_dir, exist_ok=True)
model.eval()
token_ids = batch["text"].to(device)
attention_mask = batch["attention_mask"].to(device)
real_images = batch["image"]
pokemon_ids = batch["idx"]
descriptions = batch["description"]
num_images = real_images.size(0)
with torch.no_grad():
generated_images = model(token_ids, attention_mask)
# Handle tuple output from generator (e.g., 256px and 64px images)
if isinstance(generated_images, tuple):
# Check if we want 64x64 or 256x256 based on set_name
if "64" in set_name:
generated_images = generated_images[1] # Use 64x64 output
# Resize real images to 64x64 for comparison
real_images = F.interpolate(real_images, size=(64, 64), mode='bilinear', align_corners=False)
else:
generated_images = generated_images[0] # Use 256x256 output
    # squeeze=False keeps axs 2-D even when num_images == 1
    fig, axs = plt.subplots(2, num_images, figsize=(4 * num_images, 8.5), squeeze=False)
resolution = "64x64" if "64" in set_name else "256x256"
fig.suptitle(
f"Epoch {epoch} - {set_name.capitalize()} Comparison ({resolution})", fontsize=16, y=0.98
)
for i in range(num_images):
ax_real = axs[0, i]
ax_real.imshow(denormalize_image(real_images[i].cpu()).permute(1, 2, 0))
ax_real.set_title(f"#{pokemon_ids[i]}: {descriptions[i][:35]}...", fontsize=10)
ax_real.axis("off")
ax_gen = axs[1, i]
ax_gen.imshow(denormalize_image(generated_images[i].cpu()).permute(1, 2, 0))
ax_gen.axis("off")
axs[0, 0].text(
-0.1,
0.5,
"Real",
ha="center",
va="center",
rotation="vertical",
fontsize=14,
transform=axs[0, 0].transAxes,
)
axs[1, 0].text(
-0.1,
0.5,
"Generated",
ha="center",
va="center",
rotation="vertical",
fontsize=14,
transform=axs[1, 0].transAxes,
)
plt.tight_layout(rect=(0, 0, 1, 0.95))
# Save the figure and optionally show it
save_path = os.path.join(output_dir, f"{epoch:03d}_{set_name}_comparison.png")
plt.savefig(save_path)
if show_inline:
plt.show()
else:
plt.close(fig)
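

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the training pipeline). `PokemonGenerator`
# and `build_dataloader` are hypothetical stand-ins for the project's actual
# model class and data loading; only the tokenizer call below is a real API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint

    # model = PokemonGenerator().to(device)          # hypothetical model class
    # batch = next(iter(build_dataloader("val")))    # hypothetical dataloader
    # save_attention_visualization(
    #     epoch=1, model=model, tokenizer=tokenizer, batch=batch,
    #     device=device, set_name="val", output_dir="training_output",
    # )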