# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).

# --------------------------------------------------------
# Heads for downstream tasks
# --------------------------------------------------------

"""
A head is a module whose __init__ defines only the head hyperparameters.
A method setup(croconet) takes a CroCoNet and sets all layers according to the head and croconet attributes.
The forward takes the features as well as a dictionary img_info containing the keys 'width' and 'height'.
"""

import torch
import torch.nn as nn

from .dpt_block import DPTOutputAdapter


class PixelwiseTaskWithDPT(nn.Module):
    """DPT module for CroCo.

    By default, hooks_idx will be equal to:
    * for encoder-only: 4 equally spread layers
    * for encoder+decoder: last encoder + 3 equally spread layers of the decoder

    The constructor only records hyperparameters; the actual DPT layers are
    created by ``setup(croconet)``, which must be called before ``forward``.
    """

    # Kept immutable so that the default can never be modified in place;
    # every instance receives its own list copy.
    _DEFAULT_LAYER_DIMS = (96, 192, 384, 768)

    def __init__(
        self,
        *,
        hooks_idx=None,
        layer_dims=None,
        output_width_ratio=1,
        num_channels=1,
        postprocess=None,
        **kwargs,
    ):
        """Store the head hyperparameters (no layer is built here).

        Args:
            hooks_idx: indices of the backbone blocks to hook; when None they
                are derived from the backbone depth in ``setup``.
            layer_dims: forwarded to ``DPTOutputAdapter`` as ``layer_dims``;
                None selects a fresh ``[96, 192, 384, 768]`` list.
            output_width_ratio: forwarded to ``DPTOutputAdapter``.
            num_channels: forwarded to ``DPTOutputAdapter``.
            postprocess: optional callable applied to the DPT output in
                ``forward``; skipped when falsy.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.return_all_blocks = True  # backbone needs to return all layers
        self.postprocess = postprocess
        self.output_width_ratio = output_width_ratio
        self.num_channels = num_channels
        self.hooks_idx = hooks_idx
        # A list literal as default argument would be shared by every
        # instance; build a new list per instance instead.
        if layer_dims is None:
            layer_dims = list(self._DEFAULT_LAYER_DIMS)
        self.layer_dims = layer_dims

    def setup(self, croconet):
        """Build the DPT adapter according to the attributes of ``croconet``.

        Reads ``croconet.enc_depth`` and ``croconet.enc_embed_dim``, plus
        ``croconet.dec_depth`` and ``croconet.dec_embed_dim`` when the network
        has a ``dec_blocks`` attribute (encoder + decoder case).

        Raises:
            KeyError: if ``hooks_idx`` was not given, the network has a
                decoder, and ``croconet.dec_depth`` is not 8, 12 or 24.
        """
        if self.hooks_idx is None:
            if hasattr(croconet, "dec_blocks"):  # encoder + decoder
                step = {8: 3, 12: 4, 24: 8}[croconet.dec_depth]
                last_idx = croconet.dec_depth + croconet.enc_depth - 1
            else:  # encoder only
                step = croconet.enc_depth // 4
                last_idx = croconet.enc_depth - 1
            # Four hooks, `step` apart, ending on the last block (ascending).
            self.hooks_idx = [last_idx - i * step for i in range(3, -1, -1)]
            print(
                f" PixelwiseTaskWithDPT: automatically setting hook_idxs={self.hooks_idx}"
            )

        self.dpt = DPTOutputAdapter(
            output_width_ratio=self.output_width_ratio,
            num_channels=self.num_channels,
            hooks=self.hooks_idx,
            layer_dims=self.layer_dims,
        )

        # Hooks below enc_depth use the encoder width, the others the decoder
        # width.
        dim_tokens = [
            croconet.enc_embed_dim
            if hook < croconet.enc_depth
            else croconet.dec_embed_dim
            for hook in self.hooks_idx
        ]
        self.dpt.init(dim_tokens_enc=dim_tokens)

    def forward(self, x, img_info):
        """Run the DPT adapter on ``x`` and apply the optional postprocess.

        Args:
            x: backbone features, passed unchanged to the DPT adapter
                (presumably the outputs of all blocks, since
                ``return_all_blocks`` is True -- confirm against the backbone).
            img_info: dictionary with the keys ``'height'`` and ``'width'``.

        Returns:
            The adapter output, passed through ``self.postprocess`` when one
            is set.
        """
        out = self.dpt(x, image_size=(img_info["height"], img_info["width"]))
        if self.postprocess:
            out = self.postprocess(out)
        return out