import torch
import torch.nn as nn

from model_blocks.text_encoder import TextEncoder


class Discriminator256(nn.Module):
    def __init__(self, text_dim=256, img_channels=3):
        super(Discriminator256, self).__init__()
        self.text_encoder = TextEncoder()  # Separate text encoder for the discriminators

        self.img_path = nn.Sequential(
            # 256x256 -> 128x128
            nn.Conv2d(img_channels, 16, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 128x128 -> 64x64
            nn.Conv2d(16, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            # 64x64 -> 32x32
            nn.Conv2d(32, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            # 32x32 -> 16x16
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # 16x16 -> 8x8
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # 8x8 -> 4x4
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Projects the pooled text embedding to a 512-d conditioning vector
        self.text_path = nn.Sequential(
            nn.Linear(text_dim, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 512),
        )

        # Unconditional classifier (real/fake without text conditioning)
        self.unconditional_classifier = nn.Sequential(
            nn.Linear(512 * 4 * 4, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 1),
        )

        # Conditional classifier (text-conditioned real/fake)
        self.conditional_classifier = nn.Sequential(
            nn.Linear(512 * 4 * 4 + 512, 1024),  # flattened image features concatenated with the text embedding
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 1),
        )

    def forward(self, images, text_features=None, text_mask=None, return_both=True):
        # Encode image
        img_features = self.img_path(images)
        img_features_flat = img_features.view(img_features.size(0), -1)  # Flatten to (batch, 512 * 4 * 4)

        unconditional_output = self.unconditional_classifier(img_features_flat)
        if not return_both:
            return unconditional_output

        if text_features is None or text_mask is None:
            raise ValueError("text_features and text_mask are required for text conditioning")

        # Encode text and mean-pool over the token dimension
        global_full_text = self.text_encoder(text_features, text_mask)
        global_text = global_full_text.mean(dim=1)
        text_features_encoded = self.text_path(global_text)

        # Combine image and text features
        combined = torch.cat([img_features_flat, text_features_encoded], dim=1)
        conditional_output = self.conditional_classifier(combined)

        return unconditional_output, conditional_output


class Discriminator64(nn.Module):
    def __init__(self, text_dim=256, img_channels=3):
        super(Discriminator64, self).__init__()
        self.text_encoder = TextEncoder()

        self.img_path = nn.Sequential(
            # 64x64 -> 32x32
            nn.Conv2d(img_channels, 16, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 32x32 -> 16x16
            nn.Conv2d(16, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            # 16x16 -> 8x8
            nn.Conv2d(32, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            # 8x8 -> 4x4
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Projects the pooled text embedding to a 512-d conditioning vector
        self.text_path = nn.Sequential(
            nn.Linear(text_dim, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 512),
        )

        # Unconditional classifier (real/fake without text conditioning)
        self.unconditional_classifier = nn.Sequential(
            nn.Linear(128 * 4 * 4, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 1),
        )

        # Conditional classifier (text-conditioned real/fake)
        self.conditional_classifier = nn.Sequential(
            nn.Linear(128 * 4 * 4 + 512, 1024),  # flattened image features concatenated with the text embedding
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 1),
        )

    def forward(self, images, text_features=None, text_mask=None, return_both=True):
        img_features = self.img_path(images)
        img_features_flat = img_features.view(img_features.size(0), -1)  # Flatten to (batch, 128 * 4 * 4)

        unconditional_output = self.unconditional_classifier(img_features_flat)
        if not return_both:
            return unconditional_output

        if text_features is None or text_mask is None:
            raise ValueError("text_features and text_mask are required for text conditioning")

        # Encode text and mean-pool over the token dimension
        global_full_text = self.text_encoder(text_features, text_mask)
        global_text = global_full_text.mean(dim=1)
        text_features_encoded = self.text_path(global_text)

        combined = torch.cat([img_features_flat, text_features_encoded], dim=1)
        conditional_output = self.conditional_classifier(combined)

        return unconditional_output, conditional_output
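

# Minimal smoke-test sketch (not part of the original module). It exercises only the
# unconditional path, since the shapes TextEncoder expects for text_features and
# text_mask depend on that module's implementation; the conditional path additionally
# requires those two inputs and returns a second, text-conditioned logit.
if __name__ == "__main__":
    d64 = Discriminator64()
    d256 = Discriminator256()

    fake_64 = torch.randn(4, 3, 64, 64)    # (batch, channels, H, W)
    fake_256 = torch.randn(4, 3, 256, 256)

    # Unconditional real/fake logits; each has shape (4, 1).
    print(d64(fake_64, return_both=False).shape)
    print(d256(fake_256, return_both=False).shape)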