saurabhati committed on
Commit
c093c1b
·
verified ·
1 Parent(s): 30dee73

Update modeling_dass.py

Browse files
Files changed (1) hide show
  1. modeling_dass.py +35 -3
modeling_dass.py CHANGED
@@ -10,7 +10,6 @@ import warnings
10
  import torch.nn as nn
11
  import torch.nn.functional as F
12
  import torch.utils.checkpoint as checkpoint
13
- from timm.models.layers import DropPath, trunc_normal_
14
  from functools import partial
15
  from typing import Optional, Callable, Any, Union
16
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
@@ -718,6 +717,39 @@ def selective_scan_fn(
718
  ############## HuggingFace modeling file #################
719
  ##########################################################
720
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
721
  class DASSLinear2d(nn.Linear):
722
  def __init__(self, *args, groups=1, **kwargs):
723
  nn.Linear.__init__(self, *args, **kwargs)
@@ -1094,7 +1126,7 @@ class DASSPreTrainedModel(PreTrainedModel):
1094
  def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
1095
  """Initialize the weights"""
1096
  if isinstance(module, nn.Linear):
1097
- trunc_normal_(module.weight, std=0.02)
1098
  if isinstance(module, nn.Linear) and module.bias is not None:
1099
  nn.init.constant_(module.bias, 0)
1100
  elif isinstance(module, nn.LayerNorm):
@@ -1193,4 +1225,4 @@ __all__ = [
1193
  "DASSModel",
1194
  "DASSPreTrainedModel",
1195
  "DASSForAudioClassification",
1196
- ]
 
10
  import torch.nn as nn
11
  import torch.nn.functional as F
12
  import torch.utils.checkpoint as checkpoint
 
13
  from functools import partial
14
  from typing import Optional, Callable, Any, Union
15
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
 
717
  ############## HuggingFace modeling file #################
718
  ##########################################################
719
 
720
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    A single Bernoulli keep/drop decision is drawn for every index along
    dim 0 of ``x`` and broadcast over all remaining dims, so each sample is
    either kept whole or zeroed whole.

    Background on the "drop path" vs. "DropConnect" naming:
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956

    Args:
        x: Input tensor; dim 0 is treated as the per-sample dimension.
        drop_prob: Probability of zeroing a sample.
        training: When False the input is returned untouched.
        scale_by_keep: When True (and the keep probability is non-zero),
            kept samples are divided by the keep probability.

    Returns:
        ``x`` itself when ``drop_prob == 0`` or ``training`` is False;
        otherwise a new tensor equal to ``x`` times the per-sample mask.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # Mask shape is (N, 1, ..., 1) so it works with tensors of any rank,
    # not just 2D ConvNet feature maps.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    # Skip the rescale when keep_prob is 0 to avoid dividing by zero.
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor
738
+
739
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Thin ``nn.Module`` wrapper that forwards to ``drop_path`` using the
    module's own ``training`` flag.

    Args:
        drop_prob: Probability passed to ``drop_path`` on every forward call.
        scale_by_keep: Forwarded to ``drop_path`` unchanged.
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        """Apply ``drop_path`` to ``x`` with this module's settings."""
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        """Return the drop probability, rounded to three decimals, for ``repr``."""
        return f'drop_prob={round(self.drop_prob,3):0.3f}'
752
+
753
  class DASSLinear2d(nn.Linear):
754
  def __init__(self, *args, groups=1, **kwargs):
755
  nn.Linear.__init__(self, *args, **kwargs)
 
1126
  def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
1127
  """Initialize the weights"""
1128
  if isinstance(module, nn.Linear):
1129
+ nn.init.trunc_normal_(module.weight, std=0.02)
1130
  if isinstance(module, nn.Linear) and module.bias is not None:
1131
  nn.init.constant_(module.bias, 0)
1132
  elif isinstance(module, nn.LayerNorm):
 
1225
  "DASSModel",
1226
  "DASSPreTrainedModel",
1227
  "DASSForAudioClassification",
1228
+ ]