import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torchvision.models import VGG19_Weights

class VGGPerceptualLoss(nn.Module):
    """
    Perceptual loss using VGG19 pretrained on ImageNet.
    We extract features at:
      - relu1_2 (index 3)
      - relu2_2 (index 8)
      - relu3_2 (index 13)
      - relu4_2 (index 22)
    and compute the L1 distance between those feature maps.
    Input images are in [-1, 1]; we convert them to [0, 1], then normalize with ImageNet stats.
    """
    def __init__(self, device):
        super().__init__()
        vgg19_features = models.vgg19(weights=VGG19_Weights.DEFAULT).features.to(device).eval()
        # We only need layers up to index 22 (relu4_2).
        self.slices = nn.ModuleDict({
            "relu1_2": nn.Sequential(*list(vgg19_features.children())[:4]),    # conv1_1, relu1_1, conv1_2, relu1_2
            "relu2_2": nn.Sequential(*list(vgg19_features.children())[4:9]),   # pool1, conv2_1, relu2_1, conv2_2, relu2_2
            "relu3_2": nn.Sequential(*list(vgg19_features.children())[9:14]),  # pool2, conv3_1, relu3_1, conv3_2, relu3_2
            "relu4_2": nn.Sequential(*list(vgg19_features.children())[14:23])  # conv3_3 ... relu3_4, pool3, conv4_1, relu4_1, conv4_2, relu4_2
        })
        # The loss network is fixed: freeze all VGG parameters.
        for param in self.parameters():
            param.requires_grad = False
        self.l1 = nn.L1Loss()
        # ImageNet normalization stats, registered as buffers so they move with the module.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1))
    def forward(self, img_gen, img_ref):
        """
        img_gen, img_ref: [B, 3, H, W] in range [-1, 1].
        Returns: sum of L1 distances between VGG feature maps at the chosen layers.
        """
        # Convert from [-1, 1] to [0, 1].
        gen = (img_gen + 1.0) / 2.0
        ref = (img_ref + 1.0) / 2.0
        # Normalize with ImageNet statistics.
        gen_norm = (gen - self.mean) / self.std
        ref_norm = (ref - self.mean) / self.std
        loss = 0.0
        x_gen = gen_norm
        x_ref = ref_norm
        # Each slice continues from the previous one, so both images pass through VGG once.
        for slice_mod in self.slices.values():
            x_gen = slice_mod(x_gen)
            x_ref = slice_mod(x_ref)
            loss += self.l1(x_gen, x_ref)
        return loss
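
# Illustrative shape walk-through (assumes a 256x256 input; not part of the loss):
#   input   -> [B,   3, 256, 256]
#   relu1_2 -> [B,  64, 256, 256]
#   relu2_2 -> [B, 128, 128, 128]
#   relu3_2 -> [B, 256,  64,  64]
#   relu4_2 -> [B, 512,  32,  32]
# Later slices therefore compare coarser, more semantic structure, while the
# early ones penalize differences in fine texture.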

class SobelLoss(nn.Module):
    """
    Computes the Sobel loss between two images, which encourages edge similarity.
    This loss operates on grayscale versions of the input images.
    """
    def __init__(self):
        super().__init__()
        # Sobel kernels for horizontal and vertical gradients, registered as
        # buffers so they follow the module across devices.
        self.register_buffer("kernel_x", torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3))
        self.register_buffer("kernel_y", torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3))
        self.l1 = nn.L1Loss()
        # Grayscale conversion weights (ITU-R BT.601).
        self.register_buffer("rgb_to_gray_weights", torch.tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1))
    def _get_edges(self, img):
        """
        Converts an RGB image to grayscale and applies Sobel filters.
        Args:
            img: [B, 3, H, W] image tensor in range [-1, 1].
        Returns:
            Gradient magnitude map [B, 1, H, W].
        """
        # Convert from [-1, 1] to [0, 1].
        img = (img + 1.0) / 2.0
        # Convert to grayscale via a 1x1 convolution with the BT.601 weights.
        grayscale_img = F.conv2d(img, self.rgb_to_gray_weights)
        # Apply Sobel filters; padding=1 preserves the spatial size.
        grad_x = F.conv2d(grayscale_img, self.kernel_x, padding=1)
        grad_y = F.conv2d(grayscale_img, self.kernel_y, padding=1)
        # Gradient magnitude; the epsilon keeps sqrt differentiable at zero.
        edges = torch.sqrt(grad_x ** 2 + grad_y ** 2 + 1e-6)
        return edges
    def forward(self, img_gen, img_ref):
        """
        img_gen, img_ref: [B, 3, H, W] in range [-1, 1].
        Returns: L1 loss between the edge maps of the two images.
        """
        edges_gen = self._get_edges(img_gen)
        edges_ref = self._get_edges(img_ref)
        return self.l1(edges_gen, edges_ref)
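

if __name__ == "__main__":
    # Minimal smoke test: a sketch only, using random images in [-1, 1] on CPU.
    # The 0.1 weight on the Sobel term is an arbitrary illustration, not a tuned value.
    device = torch.device("cpu")
    perceptual_loss = VGGPerceptualLoss(device)
    sobel_loss = SobelLoss()
    img_gen = torch.rand(2, 3, 256, 256, device=device) * 2.0 - 1.0
    img_ref = torch.rand(2, 3, 256, 256, device=device) * 2.0 - 1.0
    total = perceptual_loss(img_gen, img_ref) + 0.1 * sobel_loss(img_gen, img_ref)
    print(f"perceptual + 0.1 * sobel: {total.item():.4f}")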