import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torchvision.models import VGG19_Weights

class VGGPerceptualLoss(nn.Module):
    """
    Perceptual loss using VGG19 pretrained on ImageNet.
    We extract features at:
      - relu1_2 (index: 3)
      - relu2_2 (index: 8)
      - relu3_2 (index: 13)
      - relu4_2 (index: 22)
    Then compute L1 distance between those feature maps.
    Input images are in [-1, 1]. We convert to [0, 1], then normalize with ImageNet stats.
    """

    def __init__(self, device):
        super().__init__()
        vgg19_features = models.vgg19(weights=VGG19_Weights.DEFAULT).features.to(device).eval()
        # We only need layers up to index 22 (relu4_2). In torchvision's VGG19,
        # relu3_2 is at index 13 and relu4_2 at index 22 (indices 17 and 26 are
        # relu3_4 and relu4_4). The slices are contiguous so forward() can chain them.
        self.slices = nn.ModuleDict({
            "relu1_2": nn.Sequential(*list(vgg19_features.children())[:4]),    # conv1_1, relu1_1, conv1_2, relu1_2
            "relu2_2": nn.Sequential(*list(vgg19_features.children())[4:9]),   # pool1, conv2_1, relu2_1, conv2_2, relu2_2
            "relu3_2": nn.Sequential(*list(vgg19_features.children())[9:14]),  # pool2, conv3_1, relu3_1, conv3_2, relu3_2
            "relu4_2": nn.Sequential(*list(vgg19_features.children())[14:23])  # conv3_3 .. relu3_4, pool3, conv4_1, relu4_1, conv4_2, relu4_2
        })
        for param in self.parameters():
            param.requires_grad = False
        self.l1 = nn.L1Loss()
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1))

    def forward(self, img_gen, img_ref):
        """
        img_gen, img_ref: [B, 3, H, W] in range [-1, 1].
        Returns: sum of L1 distances between VGG feature maps at the chosen layers.
        """
        # Convert to [0, 1]
        gen = (img_gen + 1.0) / 2.0
        ref = (img_ref + 1.0) / 2.0
        # Normalize with ImageNet stats
        gen_norm = (gen - self.mean) / self.std
        ref_norm = (ref - self.mean) / self.std
        loss = 0.0
        x_gen = gen_norm
        x_ref = ref_norm
        # The slices are contiguous, so each iteration pushes the previous
        # activations through the next chunk of the network.
        for slice_mod in self.slices.values():
            x_gen = slice_mod(x_gen)
            x_ref = slice_mod(x_ref)
            loss += self.l1(x_gen, x_ref)
        return loss
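
# Usage sketch (illustrative only; the batch size, resolution, and random
# inputs below are assumptions for demonstration, not values from this repo):
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   perceptual = VGGPerceptualLoss(device)
#   fake = torch.rand(2, 3, 256, 256, device=device) * 2.0 - 1.0  # in [-1, 1]
#   real = torch.rand(2, 3, 256, 256, device=device) * 2.0 - 1.0
#   loss = perceptual(fake, real)  # scalar tensor: sum of 4 per-layer L1 terms
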
class SobelLoss(nn.Module):
    """
    Computes the Sobel loss between two images, which encourages edge similarity.
    This loss operates on the grayscale versions of the input images.
    """

    def __init__(self):
        super().__init__()
        # Sobel kernels for edge detection, registered as buffers so they
        # follow the module across .to(device) calls.
        self.register_buffer(
            "kernel_x",
            torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3),
        )
        self.register_buffer(
            "kernel_y",
            torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3),
        )
        self.l1 = nn.L1Loss()
        # Grayscale conversion weights (ITU-R BT.601)
        self.register_buffer(
            "rgb_to_gray_weights",
            torch.tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1),
        )

    def _get_edges(self, img):
        """
        Converts an RGB image to grayscale and applies Sobel filters.
        Args:
            img: [B, 3, H, W] image tensor in range [-1, 1].
        Returns:
            Gradient magnitude map [B, 1, H, W].
        """
        # Convert from [-1, 1] to [0, 1]
        img = (img + 1.0) / 2.0
        # Convert to grayscale with a 1x1 convolution (weights are buffers,
        # so they already live on the module's device)
        grayscale_img = F.conv2d(img, self.rgb_to_gray_weights)
        # Apply Sobel filters
        grad_x = F.conv2d(grayscale_img, self.kernel_x, padding=1)
        grad_y = F.conv2d(grayscale_img, self.kernel_y, padding=1)
        # Compute gradient magnitude; the epsilon keeps sqrt differentiable at zero
        edges = torch.sqrt(grad_x**2 + grad_y**2 + 1e-6)
        return edges

    def forward(self, img_gen, img_ref):
        """
        img_gen, img_ref: [B, 3, H, W] in range [-1, 1].
        Returns: L1 loss between the edge maps of the two images.
        """
        edges_gen = self._get_edges(img_gen)
        edges_ref = self._get_edges(img_ref)
        return self.l1(edges_gen, edges_ref)
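
if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module): combine
    # both losses on random inputs in [-1, 1]. The 0.1 edge-loss weight is an
    # illustrative assumption, not a tuned value from this repo.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    perceptual = VGGPerceptualLoss(device)
    sobel = SobelLoss().to(device)
    img_gen = torch.rand(2, 3, 128, 128, device=device) * 2.0 - 1.0
    img_ref = torch.rand(2, 3, 128, 128, device=device) * 2.0 - 1.0
    total = perceptual(img_gen, img_ref) + 0.1 * sobel(img_gen, img_ref)
    print(f"combined perceptual + edge loss: {total.item():.4f}")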