Update modeling_dass.py
modeling_dass.py  (+35 −3)
@@ -10,7 +10,6 @@ import warnings
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
-from timm.models.layers import DropPath, trunc_normal_
 from functools import partial
 from typing import Optional, Callable, Any, Union
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
@@ -718,6 +717,39 @@ def selective_scan_fn(
 ############## HuggingFace modeling file #################
 ##########################################################

+def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+    'survival rate' as the argument.
+
+    """
+    if drop_prob == 0. or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    if keep_prob > 0.0 and scale_by_keep:
+        random_tensor.div_(keep_prob)
+    return x * random_tensor
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    """
+    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+    def extra_repr(self):
+        return f'drop_prob={round(self.drop_prob,3):0.3f}'
+
 class DASSLinear2d(nn.Linear):
     def __init__(self, *args, groups=1, **kwargs):
         nn.Linear.__init__(self, *args, **kwargs)
@@ -1094,7 +1126,7 @@ class DASSPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
         if isinstance(module, nn.Linear):
-            trunc_normal_(module.weight, std=0.02)
+            nn.init.trunc_normal_(module.weight, std=0.02)
             if isinstance(module, nn.Linear) and module.bias is not None:
                 nn.init.constant_(module.bias, 0)
         elif isinstance(module, nn.LayerNorm):
@@ -1193,4 +1225,4 @@ __all__ = [
     "DASSModel",
     "DASSPreTrainedModel",
     "DASSForAudioClassification",
-]
+]
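Note: this commit inlines timm's stochastic-depth helpers (drop_path / DropPath) and drops the timm import, so the modeling file is self-contained. As a minimal usage sketch, a DropPath instance typically wraps the residual branch of a block; the ToyBlock module below is hypothetical and not part of modeling_dass.py (it assumes the DropPath class added above is in scope):

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Hypothetical residual block wired with the inlined DropPath."""
    def __init__(self, dim: int, drop_prob: float = 0.1):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim)
        self.drop_path = DropPath(drop_prob)  # stochastic depth on the branch

    def forward(self, x):
        # In training mode, drop_path zeroes the whole branch for a random
        # subset of samples (mask shape (batch, 1, ..., 1)) and rescales the
        # survivors by 1 / keep_prob, keeping the expected output unchanged.
        return x + self.drop_path(self.proj(self.norm(x)))

block = ToyBlock(16).train()
out = block(torch.randn(4, 16))  # shape-preserving: (4, 16) in, (4, 16) out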
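Note on the _init_weights change: the call switches from timm's trunc_normal_ to PyTorch's built-in nn.init.trunc_normal_. Both helpers share the same defaults (mean=0., std=1., a=-2., b=2.), so at this call site the swap should be behavior-preserving while removing the last timm dependency. A small self-contained check (layer sizes are arbitrary):

import torch.nn as nn

linear = nn.Linear(8, 8)
nn.init.trunc_normal_(linear.weight, std=0.02)  # what _init_weights now calls
nn.init.constant_(linear.bias, 0)

# Weights are drawn from N(0, 0.02^2) truncated to the default bounds
# [-2.0, 2.0]; with std=0.02 those bounds sit ~100 sigma out, so the
# truncation is effectively inactive here.
assert linear.weight.abs().max().item() <= 2.0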