zhiyuan8 committed (verified)
Commit cdb5700 · 1 Parent(s): b3d89c3

Upload wkv.py

Files changed (1)
  1. wkv.py (+157 -119)
wkv.py CHANGED
@@ -1,16 +1,15 @@
import torch
from einops import rearrange

- from .hybrid_cache import TimeMixState, BlockState
import math
import torch.nn as nn
from torch.nn import functional as F
from .configuration_rwkv_hybrid import RwkvHybridConfig
- from typing import TYPE_CHECKING, Optional
- from transformers.cache_utils import Cache
+ from typing import Optional
+ from .hybrid_cache import HybridCache, AttnState, BlockState

try:
- import triton
+ import triton # pylint: disable=F401
from rwkvfla.ops.rwkv7 import (
fused_recurrent_rwkv7,
chunk_rwkv7,
@@ -33,17 +32,23 @@ except ImportError:
fused_recurrent_rwkv6 = native_recurrent_rwkv6
fused_addcmul_rwkv7 = torch_addcmul_rwkv7

+ from rwkvfla.utils import check_pytorch_version
+
+ if check_pytorch_version("2.6"):
+ compile_decorator = torch.compile
+ torch._dynamo.config.cache_size_limit = 512
+ else:
+ def compile_decorator(func):
+ return func
+

class Rwkv_Tmix_x070(nn.Module):
- def __init__(self, args: RwkvHybridConfig, layer_id, update_v_first, get_v_first):
+ def __init__(self, args: RwkvHybridConfig, layer_id, **kwargs):
super().__init__()
self.args = args
self.layer_id = layer_id
self.hidden_size = args.hidden_size

- self.update_v_first = update_v_first
- self.get_v_first = get_v_first
-
self.head_size = args.head_size
self.n_head = args.num_wkv_heads
assert args.hidden_size % self.n_head == 0
@@ -55,7 +60,7 @@ class Rwkv_Tmix_x070(nn.Module):
self.x_k = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
self.x_v = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
self.x_a = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
-
+
D_DECAY_LORA = 64
D_AAA_LORA = 64
D_MV_LORA = 32
@@ -122,7 +127,6 @@ class Rwkv_Tmix_x070(nn.Module):
)
nn.init.constant_(
self.x_a, 1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0))
-

def ortho_init(x, scale):
shape = x.shape
@@ -181,7 +185,7 @@ class Rwkv_Tmix_x070(nn.Module):
D_GATE_LORA, self.args.hidden_size), 0.1)
)
nn.init.constant_(
- self.x_g, 1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0))
+ self.x_g, 1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0))

nn.init.constant_(self.k_k, 0.85)
nn.init.constant_(self.k_a, 1.0)
@@ -196,14 +200,14 @@ class Rwkv_Tmix_x070(nn.Module):
nn.init.ones_(self.ln_x.weight)
nn.init.zeros_(self.ln_x.bias)

- def apply_wkv7_state(self, r, k, v, w, a, b, s,
- output_final_state,
- cu_seqlens,
- head_first
- ):
-
+ def apply_wkv7_state(
+ self, r, k, v, w, a, b, s,
+ output_final_state,
+ cu_seqlens
+ ):
if r.device.type == "cpu":
- r, w, k, v, a, b = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.n_head), (r, w, k, v, a, b))
+ r, w, k, v, a, b = map(lambda x: rearrange(
+ x, 'b l (h d) -> b h l d', h=self.n_head), (r, w, k, v, a, b))
o, state = native_recurrent_rwkv7(
r=r, k=k, v=v, w=w,
a=a, b=b,
@@ -215,8 +219,9 @@ class Rwkv_Tmix_x070(nn.Module):
state = state.transpose(-1, -2)
x = rearrange(o, "b h l d -> b l (h d)")
else:
- r, w, k, v, a, b = map(lambda x: rearrange(x, 'b l (h d) -> b l h d', h=self.n_head), (r, w, k, v, a, b))
- wkv7_func = chunk_rwkv7 if self.training else fused_recurrent_rwkv7
+ r, w, k, v, a, b = map(lambda x: rearrange(
+ x, 'b l (h d) -> b l h d', h=self.n_head), (r, w, k, v, a, b))
+ wkv7_func = chunk_rwkv7 if r.shape[1] != 1 else fused_recurrent_rwkv7
o, state = wkv7_func(
r=r, k=k, v=v, w=w,
a=a, b=b,
@@ -224,32 +229,27 @@ class Rwkv_Tmix_x070(nn.Module):
initial_state=s,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
- head_first=head_first,
+ head_first=False,
)
x = rearrange(o, "b l h d -> b l (h d)")
return x, state

+ @compile_decorator
def forward(
self,
hidden_states,
- last_state: TimeMixState,
- sequence_mask: Optional[torch.Tensor] = None,
+ last_state: AttnState,
use_cache: Optional[bool] = False,
cu_seqlens: Optional[torch.Tensor] = None,
+ v_first: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
**kwargs
):
- if sequence_mask is not None:
- hidden_states = hidden_states.mul(
- sequence_mask[:, -hidden_states.shape[-2]:, None])
-
shift_state = last_state.shift_state
B, T, C = hidden_states.size()

- if shift_state is not None:
- xx = torch.concat((shift_state.unsqueeze(
- 1), hidden_states[:, :-1]), dim=1) - hidden_states
- else:
- xx = self.time_shift(hidden_states) - hidden_states
+ xx = torch.concat((shift_state.unsqueeze(
+ 1), hidden_states[:, :-1]), dim=1) - hidden_states

lx = hidden_states[:, -1]

@@ -257,7 +257,8 @@ class Rwkv_Tmix_x070(nn.Module):
xr, xw, xk, xv, xa, xg = fused_addcmul_rwkv7(
hidden_states, xx, self.x_r, self.x_w, self.x_k, self.x_v, self.x_a, self.x_g)
else:
- xr, xw, xk, xv, xa, _ = fused_addcmul_rwkv7(hidden_states, xx, self.x_r, self.x_w, self.x_k, self.x_v, self.x_a)
+ xr, xw, xk, xv, xa, _ = fused_addcmul_rwkv7(
+ hidden_states, xx, self.x_r, self.x_w, self.x_k, self.x_v, self.x_a)

r = self.receptance(xr)
w = (
@@ -266,21 +267,23 @@ class Rwkv_Tmix_x070(nn.Module):
k = self.key(xk)
v = self.value(xv)
if self.layer_id == 0:
- self.update_v_first(v)
+ v_first = v
else:
- # Original implementation
- v = v + (self.get_v_first().to(v.device) - v) * torch.sigmoid(
+ v = torch.lerp(v, v_first, torch.sigmoid(
self.v0 + (xv @ self.v1) @ self.v2
- ) # add value residual
+ )) # add value residual

+ if attention_mask is not None:
+ v = v.mul(attention_mask[:, -v.shape[-2]:, None])
a = torch.sigmoid(
self.a0 + (xa @ self.a1) @ self.a2
) # a is "in-context learning rate"
if self.args.wkv_has_gate:
- g = torch.sigmoid(xg @ self.g1) @ self.g2
+ g = torch.sigmoid(xg @ self.g1) @ self.g2 + 1.0
kk = k * self.k_k
- kk = F.normalize(kk.view(B, T, self.n_head, -1), dim=-1, p=2.0).view(B, T, C)
- k = k * (1 + (a - 1) * self.k_a)
+ kk = F.normalize(kk.view(B, T, self.n_head, -1),
+ p=2.0, dim=-1, eps=1e-4 if kk.dtype == torch.float16 else 1e-12).view(B, T, C)
+ k = torch.lerp(k, k * a, self.k_a)

wkv_state = last_state.wkv_state
hidden_states, wkv_state = self.apply_wkv7_state(
@@ -292,66 +295,68 @@ class Rwkv_Tmix_x070(nn.Module):
(kk * a),
s=wkv_state,
output_final_state=use_cache,
- cu_seqlens=cu_seqlens,
- head_first=False
+ cu_seqlens=cu_seqlens
)
if self.args.wkv_has_group_norm:
hidden_states = self.ln_x(
hidden_states.view(B * T, C)).view(B, T, C)
- hidden_states = hidden_states + (
- (r.view(B, T, self.n_head, -1) * k.view(B, T, self.n_head, -1) * self.r_k).sum(
- dim=-1, keepdim=True
- )
- * v.view(B, T, self.n_head, -1)
- ).view(B, T, C)
+
+ # original code:
+ # weighted_sum_rk = (r.view(B, T, self.n_head, -1) * k.view(B, T, self.n_head, -1) * self.r_k).sum(
+ # dim=-1, keepdim=True
+ # )
+ weighted_sum_rk = torch.einsum('btij,btij,ij->btij', r.view(B, T, self.n_head, -1),
+ k.view(B, T, self.n_head, -1), self.r_k).sum(dim=-1, keepdim=True)
+ hidden_states = hidden_states + \
+ (weighted_sum_rk * v.view(B, T, self.n_head, -1)).view(B, T, C)
hidden_states = self.output(
hidden_states * g) if self.args.wkv_has_gate else self.output(hidden_states)
- return hidden_states, TimeMixState(lx, wkv_state)
+ return hidden_states, AttnState(lx, wkv_state), v_first


class Rwkv7Attention(nn.Module):
- def __init__(self, args: RwkvHybridConfig, layer_id, update_v_first, get_v_first):
+ def __init__(self, args: RwkvHybridConfig, layer_id):
super().__init__()
self.args = args
self.layer_idx = layer_id
- self.time_mixer = Rwkv_Tmix_x070(
- args, layer_id, update_v_first, get_v_first)
+ self.time_mixer = Rwkv_Tmix_x070(args, layer_id)

def forward(
self,
hidden_states: torch.Tensor,
- sequence_mask: Optional[torch.Tensor] = None,
- past_key_value: Optional[Cache] = None,
- use_cache: Optional[bool] = False,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ past_key_value: Optional[HybridCache] = None,
output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[torch.Tensor] = None,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ v_first: Optional[torch.Tensor] = None,
**kwargs
):
- if sequence_mask is not None:
- assert len(sequence_mask.shape) == 2, (
- "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
- "for padding purposes (0 indicating padding). "
- "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
- )
+
batch_size, token_length, _ = hidden_states.shape

- if past_key_value is not None and len(past_key_value) > self.layer_idx:
+ if use_cache and len(past_key_value) > self.layer_idx:
last_state = past_key_value[self.layer_idx][0]
else:
last_state = self.init_state(
batch_size, hidden_states.device, hidden_states.dtype
)

- attn_output, states = self.time_mixer(hidden_states=hidden_states,
- last_state=last_state.time_mix_state,
- sequence_mask=sequence_mask,
- use_cache=use_cache,
- **kwargs)
- last_state.time_mix_state = states
+ attn_output, states, v_first = self.time_mixer(hidden_states=hidden_states,
+ last_state=last_state.attn_state,
+ use_cache=use_cache,
+ cu_seqlens=cu_seqlens,
+ v_first=v_first,
+ **kwargs)

- if past_key_value is not None:
+ if use_cache:
+ last_state.attn_state = states
past_key_value.update(token_length, last_state, self.layer_idx)

- return attn_output, None
+ return attn_output, None, v_first

def init_state(self, batch_size, device, dtype) -> BlockState:
wkv_states = torch.zeros(
@@ -364,10 +369,10 @@ class Rwkv7Attention(nn.Module):
device=device,
dtype=torch.float32,
)
- token_shift = torch.zeros(
+ shift_states = torch.zeros(
(batch_size, self.args.hidden_size), device=device, dtype=dtype
)
- return BlockState(TimeMixState(token_shift, wkv_states), None)
+ return BlockState(AttnState(shift_states, wkv_states), None)


class Rwkv_Tmix_x060(nn.Module):
@@ -380,8 +385,6 @@ class Rwkv_Tmix_x060(nn.Module):
self.head_size = args.head_size
self.n_head = args.num_wkv_heads
assert args.hidden_size % self.n_head == 0
- H = self.n_head
- N = self.head_size

with torch.no_grad():
ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
@@ -445,7 +448,6 @@ class Rwkv_Tmix_x060(nn.Module):

self.time_faaaa = nn.Parameter(
tmp.reshape(self.n_head, self.head_size))
- # self.time_state = nn.Parameter(torch.zeros(self.n_head, self.head_size, self.head_size))

self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.receptance = nn.Linear(
@@ -465,27 +467,36 @@ class Rwkv_Tmix_x060(nn.Module):
def post_init(self):
pass

- def forward(self, x, last_state: TimeMixState):
+ @compile_decorator
+ def forward(
+ self,
+ hidden_states,
+ last_state: AttnState,
+ use_cache: Optional[bool] = False,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ v_first: Optional[torch.Tensor] = None,
+ **kwargs
+ ):
shift_state = last_state.shift_state
- B, T, C = x.size()
+ B, T, C = hidden_states.size()
H = self.n_head
- if shift_state is not None:
- xx = torch.concat((shift_state.unsqueeze(1), x[:, :-1]), dim=1) - x
- else:
- xx = self.time_shift(x) - x
- lx = x[:, -1]

- xxx = x + xx * self.time_maa_x
+ xx = torch.concat((shift_state.unsqueeze(
+ 1), hidden_states[:, :-1]), dim=1) - hidden_states
+
+ lx = hidden_states[:, -1]
+
+ xxx = hidden_states + xx * self.time_maa_x
xxx = torch.tanh(xxx @ self.time_maa_w1).view(B *
T, 5, -1).transpose(0, 1)
xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
mw, mk, mv, mr, mg = xxx.unbind(dim=0)

- xw = x + xx * (self.time_maa_w + mw)
- xk = x + xx * (self.time_maa_k + mk)
- xv = x + xx * (self.time_maa_v + mv)
- xr = x + xx * (self.time_maa_r + mr)
- xg = x + xx * (self.time_maa_g + mg)
+ xw = hidden_states + xx * (self.time_maa_w + mw)
+ xk = hidden_states + xx * (self.time_maa_k + mk)
+ xv = hidden_states + xx * (self.time_maa_v + mv)
+ xr = hidden_states + xx * (self.time_maa_r + mr)
+ xg = hidden_states + xx * (self.time_maa_g + mg)

r = self.receptance(xr)
k = self.key(xk)
@@ -496,16 +507,18 @@ class Rwkv_Tmix_x060(nn.Module):
w = self.time_decay + ww

wkv_state = last_state.wkv_state
- x, wkv_state = self.apply_wkv6_state(
+ hidden_states, wkv_state = self.apply_wkv6_state(
B, T, C, H, r, k, v, w, u=self.time_faaaa, s=wkv_state
)
if self.args.wkv_has_group_norm:
- x = self.ln_x(x.view(B * T, C)).view(B, T, C)
- x = self.output(x * g)
- return x, TimeMixState(lx, wkv_state)
+ hidden_states = self.ln_x(
+ hidden_states.view(B * T, C)).view(B, T, C)
+ hidden_states = self.output(hidden_states * g)
+ return hidden_states, AttnState(lx, wkv_state), None

def apply_wkv6_state(self, B, T, C, H, r, k, v, w, u, s):
- r, w, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.n_head), (r, w, k, v))
+ r, w, k, v = map(lambda x: rearrange(
+ x, 'b l (h d) -> b h l d', h=self.n_head), (r, w, k, v))

if r.device.type == "cpu":
wkv6_func = native_recurrent_rwkv6
@@ -535,31 +548,56 @@ class Rwkv6Attention(nn.Module):
self.layer_idx = layer_id
self.time_mixer = Rwkv_Tmix_x060(args, layer_id, **kwargs)

- def forward(self, hidden_states, past_key_value, **kwargs):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ past_key_value: Optional[HybridCache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[torch.Tensor] = None,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ v_first: Optional[torch.Tensor] = None,
+ **kwargs
+ ):
attn_output = hidden_states
- B, T, C = attn_output.size()
- if past_key_value is not None:
- if len(past_key_value) <= self.layer_idx:
- last_state = None
- else:
- last_state = past_key_value[self.layer_idx][0]
- if last_state is None:
- wkv_states = torch.zeros(
- (B, self.args.num_wkv_heads,
- self.args.head_size, self.args.head_size),
- device=attn_output.device,
- dtype=torch.float32,
- )
- token_shift = torch.zeros(
- (B, C), device=attn_output.device, dtype=attn_output.dtype
+
+ batch_size, token_length, _ = hidden_states.shape
+
+ if use_cache and len(past_key_value) > self.layer_idx:
+ last_state = past_key_value[self.layer_idx][0]
+ else:
+ last_state = self.init_state(
+ batch_size, hidden_states.device, hidden_states.dtype
)
- time_state = TimeMixState(token_shift, wkv_states)
- channel_state = None
- last_state = BlockState(time_state, channel_state)
- attn_output, states = self.time_mixer(
- attn_output, last_state.time_mix_state)
- last_state.time_mix_state = states
-
- if past_key_value is not None:
- past_key_value.update(T, last_state, self.layer_idx)
- return attn_output, None
+
+ attn_output, states, v_first = self.time_mixer(hidden_states=hidden_states,
+ last_state=last_state.attn_state,
+ use_cache=use_cache,
+ cu_seqlens=cu_seqlens,
+ v_first=v_first,
+ **kwargs)
+
+ if use_cache:
+ last_state.attn_state = states
+ past_key_value.update(token_length, last_state, self.layer_idx)
+
+ return attn_output, None, v_first
+
+ def init_state(self, batch_size, device, dtype) -> BlockState:
+ wkv_states = torch.zeros(
+ (
+ batch_size,
+ self.args.num_wkv_heads,
+ self.args.head_size,
+ self.args.head_size,
+ ),
+ device=device,
+ dtype=torch.float32,
+ )
+ shift_states = torch.zeros(
+ (batch_size, self.args.hidden_size), device=device, dtype=dtype
+ )
+ return BlockState(AttnState(shift_states, wkv_states), None)
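
The two torch.lerp rewrites in Rwkv_Tmix_x070.forward (the value residual and the k/k_a mix) are algebraic refactors of the previous expressions rather than behavioral changes, since lerp(x, y, w) = x + (y - x) * w. A standalone sanity check, not part of the commit (shapes are arbitrary):

import torch

# value residual: v + (v_first - v) * gate  ==  torch.lerp(v, v_first, gate)
v, v_first = torch.randn(2, 8, 64), torch.randn(2, 8, 64)
gate = torch.sigmoid(torch.randn(2, 8, 64))
assert torch.allclose(v + (v_first - v) * gate, torch.lerp(v, v_first, gate), atol=1e-6)

# key mix: k * (1 + (a - 1) * k_a)  ==  torch.lerp(k, k * a, k_a)
k = torch.randn(2, 8, 64)
a = torch.sigmoid(torch.randn(2, 8, 64))
k_a = torch.rand(64)
assert torch.allclose(k * (1 + (a - 1) * k_a), torch.lerp(k, k * a, k_a), atol=1e-6)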
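
With the update_v_first/get_v_first callbacks removed, the cross-layer value residual is now threaded explicitly: layer 0 produces v_first, deeper layers receive it as an argument, and each attention wrapper returns (attn_output, None, v_first). A hypothetical sketch of that calling pattern; the decoder stack below is illustrative only and not part of this file:

import torch.nn as nn

class ToyRwkvStack(nn.Module):
    # Illustrative only: chains Rwkv7Attention layers and threads v_first through them.
    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)  # e.g. [Rwkv7Attention(args, i) for i in range(n)]

    def forward(self, hidden_states, past_key_value=None, use_cache=False):
        v_first = None  # populated by layer 0, reused by every later layer
        for layer in self.layers:
            hidden_states, _, v_first = layer(
                hidden_states,
                past_key_value=past_key_value,
                use_cache=use_cache,
                v_first=v_first,
            )
        return hidden_states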