Upload wkv.py
wkv.py
CHANGED
@@ -1,16 +1,15 @@
 import torch
 from einops import rearrange
 
-from .hybrid_cache import TimeMixState, BlockState
 import math
 import torch.nn as nn
 from torch.nn import functional as F
 from .configuration_rwkv_hybrid import RwkvHybridConfig
-from typing import
-from
+from typing import Optional
+from .hybrid_cache import HybridCache, AttnState, BlockState
 
 try:
-    import triton
+    import triton  # pylint: disable=F401
     from rwkvfla.ops.rwkv7 import (
         fused_recurrent_rwkv7,
         chunk_rwkv7,
@@ -33,17 +32,23 @@ except ImportError:
     fused_recurrent_rwkv6 = native_recurrent_rwkv6
     fused_addcmul_rwkv7 = torch_addcmul_rwkv7
 
+from rwkvfla.utils import check_pytorch_version
+
+if check_pytorch_version("2.6"):
+    compile_decorator = torch.compile
+    torch._dynamo.config.cache_size_limit = 512
+else:
+    def compile_decorator(func):
+        return func
+
 
 class Rwkv_Tmix_x070(nn.Module):
-    def __init__(self, args: RwkvHybridConfig, layer_id,
+    def __init__(self, args: RwkvHybridConfig, layer_id, **kwargs):
         super().__init__()
         self.args = args
         self.layer_id = layer_id
         self.hidden_size = args.hidden_size
 
-        self.update_v_first = update_v_first
-        self.get_v_first = get_v_first
-
         self.head_size = args.head_size
         self.n_head = args.num_wkv_heads
         assert args.hidden_size % self.n_head == 0
@@ -55,7 +60,7 @@ class Rwkv_Tmix_x070(nn.Module):
         self.x_k = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
         self.x_v = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
         self.x_a = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
-
+
         D_DECAY_LORA = 64
         D_AAA_LORA = 64
         D_MV_LORA = 32
@@ -122,7 +127,6 @@ class Rwkv_Tmix_x070(nn.Module):
             )
             nn.init.constant_(
                 self.x_a, 1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0))
-
 
             def ortho_init(x, scale):
                 shape = x.shape
@@ -181,7 +185,7 @@ class Rwkv_Tmix_x070(nn.Module):
                     D_GATE_LORA, self.args.hidden_size), 0.1)
             )
             nn.init.constant_(
-
+                self.x_g, 1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0))
 
             nn.init.constant_(self.k_k, 0.85)
             nn.init.constant_(self.k_a, 1.0)
@@ -196,14 +200,14 @@ class Rwkv_Tmix_x070(nn.Module):
             nn.init.ones_(self.ln_x.weight)
             nn.init.zeros_(self.ln_x.bias)
 
-    def apply_wkv7_state(
-
-
-
-
-
+    def apply_wkv7_state(
+        self, r, k, v, w, a, b, s,
+        output_final_state,
+        cu_seqlens
+    ):
         if r.device.type == "cpu":
-            r, w, k, v, a, b = map(lambda x: rearrange(
+            r, w, k, v, a, b = map(lambda x: rearrange(
+                x, 'b l (h d) -> b h l d', h=self.n_head), (r, w, k, v, a, b))
             o, state = native_recurrent_rwkv7(
                 r=r, k=k, v=v, w=w,
                 a=a, b=b,
@@ -215,8 +219,9 @@ class Rwkv_Tmix_x070(nn.Module):
             state = state.transpose(-1, -2)
             x = rearrange(o, "b h l d -> b l (h d)")
         else:
-            r, w, k, v, a, b = map(lambda x: rearrange(
-
+            r, w, k, v, a, b = map(lambda x: rearrange(
+                x, 'b l (h d) -> b l h d', h=self.n_head), (r, w, k, v, a, b))
+            wkv7_func = chunk_rwkv7 if r.shape[1] != 1 else fused_recurrent_rwkv7
             o, state = wkv7_func(
                 r=r, k=k, v=v, w=w,
                 a=a, b=b,
@@ -224,32 +229,27 @@ class Rwkv_Tmix_x070(nn.Module):
                 initial_state=s,
                 output_final_state=output_final_state,
                 cu_seqlens=cu_seqlens,
-                head_first=
+                head_first=False,
             )
             x = rearrange(o, "b l h d -> b l (h d)")
         return x, state
 
+    @compile_decorator
     def forward(
         self,
         hidden_states,
-        last_state:
-        sequence_mask: Optional[torch.Tensor] = None,
+        last_state: AttnState,
         use_cache: Optional[bool] = False,
         cu_seqlens: Optional[torch.Tensor] = None,
+        v_first: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         **kwargs
     ):
-        if sequence_mask is not None:
-            hidden_states = hidden_states.mul(
-                sequence_mask[:, -hidden_states.shape[-2]:, None])
-
         shift_state = last_state.shift_state
         B, T, C = hidden_states.size()
 
-
-
-            1), hidden_states[:, :-1]), dim=1) - hidden_states
-        else:
-            xx = self.time_shift(hidden_states) - hidden_states
+        xx = torch.concat((shift_state.unsqueeze(
+            1), hidden_states[:, :-1]), dim=1) - hidden_states
 
         lx = hidden_states[:, -1]
 
@@ -257,7 +257,8 @@ class Rwkv_Tmix_x070(nn.Module):
             xr, xw, xk, xv, xa, xg = fused_addcmul_rwkv7(
                 hidden_states, xx, self.x_r, self.x_w, self.x_k, self.x_v, self.x_a, self.x_g)
         else:
-            xr, xw, xk, xv, xa, _ = fused_addcmul_rwkv7(
+            xr, xw, xk, xv, xa, _ = fused_addcmul_rwkv7(
+                hidden_states, xx, self.x_r, self.x_w, self.x_k, self.x_v, self.x_a)
 
         r = self.receptance(xr)
         w = (
@@ -266,21 +267,23 @@ class Rwkv_Tmix_x070(nn.Module):
         k = self.key(xk)
         v = self.value(xv)
         if self.layer_id == 0:
-
+            v_first = v
         else:
-
-            v = v + (self.get_v_first().to(v.device) - v) * torch.sigmoid(
+            v = torch.lerp(v, v_first, torch.sigmoid(
                 self.v0 + (xv @ self.v1) @ self.v2
-            ) # add value residual
+            ))  # add value residual
 
+        if attention_mask is not None:
+            v = v.mul(attention_mask[:, -v.shape[-2]:, None])
         a = torch.sigmoid(
             self.a0 + (xa @ self.a1) @ self.a2
         )  # a is "in-context learning rate"
         if self.args.wkv_has_gate:
-            g = torch.sigmoid(xg @ self.g1) @ self.g2
+            g = torch.sigmoid(xg @ self.g1) @ self.g2 + 1.0
         kk = k * self.k_k
-        kk = F.normalize(kk.view(B, T, self.n_head, -1),
-
+        kk = F.normalize(kk.view(B, T, self.n_head, -1),
+                         p=2.0, dim=-1, eps=1e-4 if kk.dtype == torch.float16 else 1e-12).view(B, T, C)
+        k = torch.lerp(k, k * a, self.k_a)
 
         wkv_state = last_state.wkv_state
         hidden_states, wkv_state = self.apply_wkv7_state(
@@ -292,66 +295,68 @@ class Rwkv_Tmix_x070(nn.Module):
             (kk * a),
             s=wkv_state,
             output_final_state=use_cache,
-            cu_seqlens=cu_seqlens
-            head_first=False
+            cu_seqlens=cu_seqlens
         )
         if self.args.wkv_has_group_norm:
            hidden_states = self.ln_x(
                hidden_states.view(B * T, C)).view(B, T, C)
-
-
-
-
-
-
+
+        # original code:
+        # weighted_sum_rk = (r.view(B, T, self.n_head, -1) * k.view(B, T, self.n_head, -1) * self.r_k).sum(
+        #     dim=-1, keepdim=True
+        # )
+        weighted_sum_rk = torch.einsum('btij,btij,ij->btij', r.view(B, T, self.n_head, -1),
+                                       k.view(B, T, self.n_head, -1), self.r_k).sum(dim=-1, keepdim=True)
+        hidden_states = hidden_states + \
+            (weighted_sum_rk * v.view(B, T, self.n_head, -1)).view(B, T, C)
         hidden_states = self.output(
             hidden_states * g) if self.args.wkv_has_gate else self.output(hidden_states)
-        return hidden_states,
+        return hidden_states, AttnState(lx, wkv_state), v_first
 
 
 class Rwkv7Attention(nn.Module):
-    def __init__(self, args: RwkvHybridConfig, layer_id
+    def __init__(self, args: RwkvHybridConfig, layer_id):
         super().__init__()
         self.args = args
         self.layer_idx = layer_id
-        self.time_mixer = Rwkv_Tmix_x070(
-            args, layer_id, update_v_first, get_v_first)
+        self.time_mixer = Rwkv_Tmix_x070(args, layer_id)
 
     def forward(
         self,
         hidden_states: torch.Tensor,
-
-
-
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        past_key_value: Optional[HybridCache] = None,
         output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[torch.Tensor] = None,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        v_first: Optional[torch.Tensor] = None,
        **kwargs
    ):
-
-        assert len(sequence_mask.shape) == 2, (
-            "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
-            "for padding purposes (0 indicating padding). "
-            "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
-        )
+
         batch_size, token_length, _ = hidden_states.shape
 
-        if
+        if use_cache and len(past_key_value) > self.layer_idx:
             last_state = past_key_value[self.layer_idx][0]
         else:
             last_state = self.init_state(
                 batch_size, hidden_states.device, hidden_states.dtype
             )
 
-        attn_output, states = self.time_mixer(hidden_states=hidden_states,
-
-
-
-
-
+        attn_output, states, v_first = self.time_mixer(hidden_states=hidden_states,
+                                                        last_state=last_state.attn_state,
+                                                        use_cache=use_cache,
+                                                        cu_seqlens=cu_seqlens,
+                                                        v_first=v_first,
+                                                        **kwargs)
 
-        if
+        if use_cache:
+            last_state.attn_state = states
             past_key_value.update(token_length, last_state, self.layer_idx)
 
-        return attn_output, None
+        return attn_output, None, v_first
 
     def init_state(self, batch_size, device, dtype) -> BlockState:
         wkv_states = torch.zeros(
@@ -364,10 +369,10 @@ class Rwkv7Attention(nn.Module):
             device=device,
             dtype=torch.float32,
         )
-
+        shift_states = torch.zeros(
             (batch_size, self.args.hidden_size), device=device, dtype=dtype
         )
-        return BlockState(
+        return BlockState(AttnState(shift_states, wkv_states), None)
 
 
 class Rwkv_Tmix_x060(nn.Module):
@@ -380,8 +385,6 @@ class Rwkv_Tmix_x060(nn.Module):
         self.head_size = args.head_size
         self.n_head = args.num_wkv_heads
         assert args.hidden_size % self.n_head == 0
-        H = self.n_head
-        N = self.head_size
 
         with torch.no_grad():
             ratio_0_to_1 = layer_id / (args.n_layer - 1)  # 0 to 1
@@ -445,7 +448,6 @@ class Rwkv_Tmix_x060(nn.Module):
 
             self.time_faaaa = nn.Parameter(
                 tmp.reshape(self.n_head, self.head_size))
-            # self.time_state = nn.Parameter(torch.zeros(self.n_head, self.head_size, self.head_size))
 
         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
         self.receptance = nn.Linear(
@@ -465,27 +467,36 @@ class Rwkv_Tmix_x060(nn.Module):
     def post_init(self):
         pass
 
-
+    @compile_decorator
+    def forward(
+        self,
+        hidden_states,
+        last_state: AttnState,
+        use_cache: Optional[bool] = False,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        v_first: Optional[torch.Tensor] = None,
+        **kwargs
+    ):
         shift_state = last_state.shift_state
-        B, T, C =
+        B, T, C = hidden_states.size()
         H = self.n_head
-        if shift_state is not None:
-            xx = torch.concat((shift_state.unsqueeze(1), x[:, :-1]), dim=1) - x
-        else:
-            xx = self.time_shift(x) - x
-        lx = x[:, -1]
 
-
+        xx = torch.concat((shift_state.unsqueeze(
+            1), hidden_states[:, :-1]), dim=1) - hidden_states
+
+        lx = hidden_states[:, -1]
+
+        xxx = hidden_states + xx * self.time_maa_x
         xxx = torch.tanh(xxx @ self.time_maa_w1).view(B *
                                                       T, 5, -1).transpose(0, 1)
         xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
         mw, mk, mv, mr, mg = xxx.unbind(dim=0)
 
-        xw =
-        xk =
-        xv =
-        xr =
-        xg =
+        xw = hidden_states + xx * (self.time_maa_w + mw)
+        xk = hidden_states + xx * (self.time_maa_k + mk)
+        xv = hidden_states + xx * (self.time_maa_v + mv)
+        xr = hidden_states + xx * (self.time_maa_r + mr)
+        xg = hidden_states + xx * (self.time_maa_g + mg)
 
         r = self.receptance(xr)
         k = self.key(xk)
@@ -496,16 +507,18 @@ class Rwkv_Tmix_x060(nn.Module):
         w = self.time_decay + ww
 
         wkv_state = last_state.wkv_state
-
+        hidden_states, wkv_state = self.apply_wkv6_state(
             B, T, C, H, r, k, v, w, u=self.time_faaaa, s=wkv_state
         )
         if self.args.wkv_has_group_norm:
-
-
-
+            hidden_states = self.ln_x(
+                hidden_states.view(B * T, C)).view(B, T, C)
+        hidden_states = self.output(hidden_states * g)
+        return hidden_states, AttnState(lx, wkv_state), None
 
     def apply_wkv6_state(self, B, T, C, H, r, k, v, w, u, s):
-        r, w, k, v = map(lambda x: rearrange(
+        r, w, k, v = map(lambda x: rearrange(
+            x, 'b l (h d) -> b h l d', h=self.n_head), (r, w, k, v))
 
         if r.device.type == "cpu":
             wkv6_func = native_recurrent_rwkv6
@@ -535,31 +548,56 @@ class Rwkv6Attention(nn.Module):
         self.layer_idx = layer_id
         self.time_mixer = Rwkv_Tmix_x060(args, layer_id, **kwargs)
 
-    def forward(
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        past_key_value: Optional[HybridCache] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[torch.Tensor] = None,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        v_first: Optional[torch.Tensor] = None,
+        **kwargs
+    ):
         attn_output = hidden_states
-
-
-
-
-
-
-
-
-            (B, self.args.num_wkv_heads,
-             self.args.head_size, self.args.head_size),
-            device=attn_output.device,
-            dtype=torch.float32,
-        )
-        token_shift = torch.zeros(
-            (B, C), device=attn_output.device, dtype=attn_output.dtype
+
+        batch_size, token_length, _ = hidden_states.shape
+
+        if use_cache and len(past_key_value) > self.layer_idx:
+            last_state = past_key_value[self.layer_idx][0]
+        else:
+            last_state = self.init_state(
+                batch_size, hidden_states.device, hidden_states.dtype
         )
-
-
-
-
-
-
-
-
-
-
+
+        attn_output, states, v_first = self.time_mixer(hidden_states=hidden_states,
+                                                        last_state=last_state.attn_state,
+                                                        use_cache=use_cache,
+                                                        cu_seqlens=cu_seqlens,
+                                                        v_first=v_first,
+                                                        **kwargs)
+
+        if use_cache:
+            last_state.attn_state = states
+            past_key_value.update(token_length, last_state, self.layer_idx)
+
+        return attn_output, None, v_first
+
+    def init_state(self, batch_size, device, dtype) -> BlockState:
+        wkv_states = torch.zeros(
+            (
+                batch_size,
+                self.args.num_wkv_heads,
+                self.args.head_size,
+                self.args.head_size,
+            ),
+            device=device,
+            dtype=torch.float32,
+        )
+        shift_states = torch.zeros(
+            (batch_size, self.args.hidden_size), device=device, dtype=dtype
+        )
+        return BlockState(AttnState(shift_states, wkv_states), None)
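
The new module header opts into torch.compile only when the runtime supports it and falls back to a no-op decorator otherwise. Below is a minimal, self-contained sketch of the same guard pattern; the packaging-based version check stands in for rwkvfla.utils.check_pytorch_version, and the decorated token-shift function is purely illustrative, not part of the file.

import torch
from packaging import version

# Guard torch.compile behind a version check; otherwise fall back to a no-op decorator.
if version.parse(torch.__version__.split("+")[0]) >= version.parse("2.6"):
    compile_decorator = torch.compile
    torch._dynamo.config.cache_size_limit = 512  # tolerate many recompilations across shapes
else:
    def compile_decorator(func):
        return func  # no-op on older PyTorch


@compile_decorator
def token_shift_delta(x: torch.Tensor) -> torch.Tensor:
    # Illustrative only: difference between the previous token's hidden state and the
    # current one, the same "xx" quantity the time-mix layers compute.
    shifted = torch.cat((torch.zeros_like(x[:, :1]), x[:, :-1]), dim=1)
    return shifted - x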
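
The value residual no longer goes through update_v_first/get_v_first callbacks stored on the module; instead v_first is returned by each layer and threaded into the next one, with layer 0 publishing its value projection and deeper layers blending toward it through a sigmoid gate (torch.lerp(v, v_first, gate) equals v + (v_first - v) * gate). A toy sketch of that data flow with random stand-in tensors:

import torch

B, T, C, n_layer = 2, 5, 8, 4
hidden = torch.randn(B, T, C)
v_first = None

for layer_id in range(n_layer):
    v = torch.randn(B, T, C)                    # stand-in for self.value(xv)
    gate = torch.sigmoid(torch.randn(B, T, C))  # stand-in for sigmoid(v0 + (xv @ v1) @ v2)
    if layer_id == 0:
        v_first = v                       # first layer publishes its values
    else:
        v = torch.lerp(v, v_first, gate)  # blend toward the first layer's values
    # v then feeds the wkv kernel for this layer; v_first is passed on unchanged.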
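
Both attention wrappers now build their per-layer recurrent state the same way: a float32 wkv matrix of shape (batch, num_wkv_heads, head_size, head_size) plus a token-shift buffer of shape (batch, hidden_size) in the model dtype, wrapped into AttnState/BlockState from .hybrid_cache. The sketch below uses hypothetical namedtuple stand-ins for those two classes, so only the tensor shapes and dtypes reflect the file itself.

from collections import namedtuple

import torch

# Hypothetical stand-ins for AttnState / BlockState from .hybrid_cache.
AttnStateStub = namedtuple("AttnStateStub", ["shift_state", "wkv_state"])
BlockStateStub = namedtuple("BlockStateStub", ["attn_state", "ffn_state"])


def init_block_state(batch_size, num_wkv_heads, head_size, hidden_size, device, dtype):
    # wkv state: one (head_size x head_size) matrix per head, kept in float32 for stability.
    wkv_states = torch.zeros(
        (batch_size, num_wkv_heads, head_size, head_size),
        device=device, dtype=torch.float32,
    )
    # Token-shift state: the last hidden vector of each sequence, in the model dtype.
    shift_states = torch.zeros((batch_size, hidden_size), device=device, dtype=dtype)
    return BlockStateStub(AttnStateStub(shift_states, wkv_states), None)


state = init_block_state(2, 4, 64, 256, device="cpu", dtype=torch.bfloat16)
print(state.attn_state.wkv_state.shape)  # torch.Size([2, 4, 64, 64])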