HongyuanTao committed
Commit 3b8153b · verified · 1 Parent(s): 6790e0a

Update modeling_mmMamba_embedding.py

Files changed (1)
  1. modeling_mmMamba_embedding.py +104 -29
modeling_mmMamba_embedding.py CHANGED
@@ -14,52 +14,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
- import queue
- import threading
- import warnings
- from typing import List, Optional, Tuple, Union
- from functools import partial
+ from typing import Optional, Tuple

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
from torch import nn
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
- from transformers.modeling_outputs import (
-     BaseModelOutputWithPast,
-     CausalLMOutputWithPast,
-     SequenceClassifierOutputWithPast,
- )
+
from transformers.modeling_utils import PreTrainedModel
- from transformers.cache_utils import Cache
- from transformers.utils import (
-     add_start_docstrings,
-     add_start_docstrings_to_model_forward,
-     logging,
-     replace_return_docstrings,
- )
- from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
- import copy
+ from transformers.utils import logging
+
+ from fused_norm_gate import FusedRMSNormSwishGate
+
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
- from transformers.cache_utils import Cache
- import time
+
from timm.models.layers import DropPath

compute_ARank = False # [ARank] Set this to True to compute attention rank

- try:
-     from transformers.generation.streamers import BaseStreamer
- except: # noqa # pylint: disable=bare-except
-     BaseStreamer = None
-
from .configuration_mmMamba_embedding import mmMambaEmbeddingConfig

- import time
-
from .configuration_mmMamba import mmMambaConfig

try:
@@ -128,6 +106,103 @@ class mmMambaRMSNorm(nn.Module):
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

+
+ # Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->mmMamba
+ class mmMambaRotaryEmbedding(nn.Module):
+     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+         super().__init__()
+
+         self.dim = dim
+         self.max_position_embeddings = max_position_embeddings
+         self.base = base
+         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+         # Build here to make `torch.jit.trace` work.
+         self._set_cos_sin_cache(
+             seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+         )
+
+     def _set_cos_sin_cache(self, seq_len, device, dtype):
+         self.max_seq_len_cached = seq_len
+         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+
+         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+         # Different from paper, but it uses a different permutation in order to obtain the same calculation
+         emb = torch.cat((freqs, freqs), dim=-1)
+         self.cos_cached = emb.cos().to(dtype)
+         self.sin_cached = emb.sin().to(dtype)
+         #self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+         #self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+     def forward(self, x, seq_len=None):
+         # x: [bs, num_attention_heads, seq_len, head_size]
+         if seq_len > self.max_seq_len_cached:
+             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32)
+
+         return (
+             self.cos_cached[:seq_len].to(dtype=x.dtype),
+             self.sin_cached[:seq_len].to(dtype=x.dtype),
+         )
+
+
+ # Copied from transformers.model.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->mmMamba
+ class mmMambaLinearScalingRotaryEmbedding(mmMambaRotaryEmbedding):
+     """mmMambaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
+         self.scaling_factor = scaling_factor
+         super().__init__(dim, max_position_embeddings, base, device)
+
+     def _set_cos_sin_cache(self, seq_len, device, dtype):
+         self.max_seq_len_cached = seq_len
+         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+         t = t / self.scaling_factor
+
+         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+         # Different from paper, but it uses a different permutation in order to obtain the same calculation
+         emb = torch.cat((freqs, freqs), dim=-1)
+         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+
+ # Copied from transformers.model.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->mmMamba
+ class mmMambaDynamicNTKScalingRotaryEmbedding(mmMambaRotaryEmbedding):
+     """mmMambaRotaryEmbedding extended with Dynamic NTK scaling.
+     Credits to the Reddit users /u/bloc97 and /u/emozilla.
+     """
+
+     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
+         self.scaling_factor = scaling_factor
+         super().__init__(dim, max_position_embeddings, base, device)
+
+     def _set_cos_sin_cache(self, seq_len, device, dtype):
+         self.max_seq_len_cached = seq_len
+
+         if seq_len > self.max_position_embeddings:
+             base = self.base * (
+                 (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
+             ) ** (self.dim / (self.dim - 2))
+             inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+             self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+
+         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+         # Different from paper, but it uses a different permutation in order to obtain the same calculation
+         emb = torch.cat((freqs, freqs), dim=-1)
+         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+
+ # Copied from transformers.model.llama.modeling_llama.rotate_half
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
class mmMambaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
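
Note: the rotary classes added here mirror the Llama implementation, so attention code is expected to consume them through the usual cos/sin + rotate_half pattern. Below is a minimal usage sketch, assuming mmMambaRotaryEmbedding and rotate_half from the hunk above are in scope; the helper name apply_rotary_pos_emb and the tensor sizes are illustrative assumptions, not code from this commit.

import torch

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # cos/sin come from mmMambaRotaryEmbedding.forward with shape [seq_len, head_dim];
    # pick the rows for the given positions and broadcast over batch and heads.
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, head_dim]
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

bs, n_heads, seq_len, head_dim = 2, 8, 16, 64  # example sizes (assumptions)
rope = mmMambaRotaryEmbedding(head_dim, max_position_embeddings=2048)
q = torch.randn(bs, n_heads, seq_len, head_dim)
k = torch.randn(bs, n_heads, seq_len, head_dim)
cos, sin = rope(q, seq_len=seq_len)
position_ids = torch.arange(seq_len).unsqueeze(0).expand(bs, -1)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)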
 
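For reference, the dynamic NTK branch only rewrites inv_freq when the requested length exceeds the trained context. The snippet below reproduces that rescaling arithmetic with assumed example values (dim=128, base=10000, a 2048-token context extended to 4096 tokens); these numbers are illustrative, not taken from the model config.

import torch

dim, base, max_pos, scaling_factor = 128, 10000, 2048, 1.0  # assumed example values
seq_len = 4096  # requested length, twice the trained context

# Same formula as mmMambaDynamicNTKScalingRotaryEmbedding._set_cos_sin_cache:
# enlarge the RoPE base so the rotation frequencies stretch over the longer sequence.
new_base = base * ((scaling_factor * seq_len / max_pos) - (scaling_factor - 1)) ** (dim / (dim - 2))
inv_freq = 1.0 / (new_base ** (torch.arange(0, dim, 2).float() / dim))

print(round(new_base))  # ~20221, roughly 2x the original base for these values
print(inv_freq.shape)   # torch.Size([64]), one frequency per pair of head dims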