hvlgo committed
Commit 2f40c01 · verified · 1 Parent(s): b2779c7

Upload SundialForPrediction

config.json ADDED
@@ -0,0 +1,30 @@
+ {
+   "_name_or_path": "checkpoints/sundial_base_gift_eval_677",
+   "architectures": [
+     "SundialForPrediction"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_sundial.SundialConfig",
+     "AutoModelForCausalLM": "modeling_sundial.SundialForPrediction"
+   },
+   "diffusion_batch_mul": 4,
+   "dropout_rate": 0.1,
+   "flow_loss_depth": 3,
+   "hidden_act": "silu",
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "input_token_len": 16,
+   "intermediate_size": 3072,
+   "max_position_embeddings": 10000,
+   "model_type": "sundial",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "num_sampling_steps": 50,
+   "output_token_lens": [
+     720
+   ],
+   "rope_theta": 10000,
+   "torch_dtype": "float32",
+   "transformers_version": "4.40.1",
+   "use_cache": true
+ }
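
The `auto_map` block above wires the custom classes shipped in this commit into transformers' Auto classes, so the checkpoint loads with `trust_remote_code=True`. A minimal loading sketch (the repository id is a placeholder, not taken from this commit):

```python
import torch
from transformers import AutoModelForCausalLM

# "<org>/<sundial-repo>" is a hypothetical Hub id; substitute the actual path of this checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    "<org>/<sundial-repo>",
    trust_remote_code=True,        # lets auto_map resolve modeling_sundial.SundialForPrediction
    torch_dtype=torch.float32,     # matches "torch_dtype" in config.json
)
print(type(model).__name__)        # SundialForPrediction
```
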
configuration_sundial.py ADDED
@@ -0,0 +1,46 @@
+ from typing import List
+ from transformers import PretrainedConfig
+
+
+ class SundialConfig(PretrainedConfig):
+     model_type = "sundial"
+     keys_to_ignore_at_inference = ["past_key_values"]
+
+     def __init__(
+         self,
+         input_token_len: int = 16,
+         hidden_size: int = 768,
+         intermediate_size: int = 3072,
+         output_token_lens: List[int] = [720],
+         num_hidden_layers: int = 12,
+         num_attention_heads: int = 12,
+         hidden_act: str = "silu",
+         use_cache: bool = True,
+         rope_theta: int = 10000,
+         dropout_rate: float = 0.1,
+         initializer_range: float = 0.02,
+         max_position_embeddings: int = 10000,
+         flow_loss_depth: int = 3,
+         num_sampling_steps: int = 50,
+         diffusion_batch_mul: int = 4,
+         **kwargs,
+     ):
+         self.input_token_len = input_token_len
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.hidden_act = hidden_act
+         self.output_token_lens = output_token_lens
+         self.use_cache = use_cache
+         self.rope_theta = rope_theta
+         self.dropout_rate = dropout_rate
+         self.initializer_range = initializer_range
+         self.max_position_embeddings = max_position_embeddings
+         self.flow_loss_depth = flow_loss_depth
+         self.num_sampling_steps = num_sampling_steps
+         self.diffusion_batch_mul = diffusion_batch_mul
+
+         super().__init__(
+             **kwargs,
+         )
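
As a quick sanity check, the defaults above mirror the values serialized in `config.json`; a config can also be constructed and saved on its own. A small sketch, assuming `configuration_sundial.py` is importable from the working directory:

```python
from configuration_sundial import SundialConfig

config = SundialConfig()                         # defaults defined in the class above
assert config.hidden_size == 768
assert config.output_token_lens == [720]
config.save_pretrained("./sundial_config_demo")  # writes a config.json with model_type "sundial"
```
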
flow_loss.py ADDED
@@ -0,0 +1,236 @@
+ import torch
+ import torch.nn as nn
+ import math
+
+
+ class FlowLoss(nn.Module):
+     """Flow Loss"""
+
+     def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps):
+         super(FlowLoss, self).__init__()
+         self.in_channels = target_channels
+         self.net = SimpleMLPAdaLN(
+             in_channels=target_channels,
+             model_channels=width,
+             out_channels=target_channels,
+             z_channels=z_channels,
+             num_res_blocks=depth
+         )
+         self.num_sampling_steps = num_sampling_steps
+
+     def forward(self, target, z, mask=None, mask_y=None):
+         noise = torch.randn_like(target)
+         t = torch.rand(target.shape[0], device=target.device)
+
+         noised_target = t[:, None] * target + (1 - t[:, None]) * noise
+
+         predict_v = self.net(noised_target, t * 1000, z)
+
+         weights = 1.0 / \
+             torch.arange(1, self.in_channels + 1, dtype=torch.float32, device=target.device)
+         if mask_y is not None:
+             loss = (mask_y * weights * (predict_v - target) ** 2).sum(dim=-1)
+         else:
+             loss = (weights * (predict_v - target) ** 2).sum(dim=-1)
+
+         if mask is not None:
+             loss = (loss * mask).sum() / mask.sum()
+         return loss.mean()
+
+     def sample(self, z, num_samples=1):
+         z = z.repeat(num_samples, 1)
+         noise = torch.randn(z.shape[0], self.in_channels).cuda()
+         x = noise
+         dt = 1.0 / self.num_sampling_steps
+         for i in range(self.num_sampling_steps):
+             t = (torch.ones((x.shape[0])) * i /
+                  self.num_sampling_steps).to(x.device)
+             pred = self.net(x, t * 1000, z)
+             x = x + (pred - noise) * dt
+         x = x.reshape(num_samples, -1, self.in_channels).transpose(0, 1)
+         return x
+
+
+ def modulate(x, shift, scale):
+     return x * (1 + scale) + shift
+
+
+ class TimestepEmbedder(nn.Module):
+     """
+     Embeds scalar timesteps into vector representations.
+     """
+
+     def __init__(self, hidden_size, frequency_embedding_size=256):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+             nn.SiLU(),
+             nn.Linear(hidden_size, hidden_size, bias=True),
+         )
+         self.frequency_embedding_size = frequency_embedding_size
+
+     @staticmethod
+     def timestep_embedding(t, dim, max_period=10000):
+         """
+         Create sinusoidal timestep embeddings.
+         :param t: a 1-D Tensor of N indices, one per batch element.
+                   These may be fractional.
+         :param dim: the dimension of the output.
+         :param max_period: controls the minimum frequency of the embeddings.
+         :return: an (N, D) Tensor of positional embeddings.
+         """
+         # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+         half = dim // 2
+         freqs = torch.exp(
+             -math.log(max_period) * torch.arange(start=0,
+                                                  end=half, dtype=torch.float32) / half
+         ).to(device=t.device)
+         args = t[:, None].float() * freqs[None]
+         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+         if dim % 2:
+             embedding = torch.cat(
+                 [embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+         return embedding
+
+     def forward(self, t):
+         t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+         t_emb = self.mlp(t_freq)
+         return t_emb
+
+
+ class ResBlock(nn.Module):
+     """
+     A residual block that can optionally change the number of channels.
+     :param channels: the number of input channels.
+     """
+
+     def __init__(
+         self,
+         channels
+     ):
+         super().__init__()
+         self.channels = channels
+
+         self.in_ln = nn.LayerNorm(channels, eps=1e-6)
+         self.mlp = nn.Sequential(
+             nn.Linear(channels, channels, bias=True),
+             nn.SiLU(),
+             nn.Linear(channels, channels, bias=True),
+         )
+
+         self.adaLN_modulation = nn.Sequential(
+             nn.SiLU(),
+             nn.Linear(channels, 3 * channels, bias=True)
+         )
+
+     def forward(self, x, y):
+         shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(
+             y).chunk(3, dim=-1)
+         h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
+         h = self.mlp(h)
+         return x + gate_mlp * h
+
+
+ class FinalLayer(nn.Module):
+     """
+     The final layer adopted from DiT.
+     """
+
+     def __init__(self, model_channels, out_channels):
+         super().__init__()
+         self.norm_final = nn.LayerNorm(
+             model_channels, elementwise_affine=False, eps=1e-6)
+         self.linear = nn.Linear(model_channels, out_channels, bias=True)
+         self.adaLN_modulation = nn.Sequential(
+             nn.SiLU(),
+             nn.Linear(model_channels, 2 * model_channels, bias=True)
+         )
+
+     def forward(self, x, c):
+         shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+         x = modulate(self.norm_final(x), shift, scale)
+         x = self.linear(x)
+         return x
+
+
+ class SimpleMLPAdaLN(nn.Module):
+     """
+     The MLP for Diffusion Loss.
+     :param in_channels: channels in the input Tensor.
+     :param model_channels: base channel count for the model.
+     :param out_channels: channels in the output Tensor.
+     :param z_channels: channels in the condition.
+     :param num_res_blocks: number of residual blocks per downsample.
+     """
+
+     def __init__(
+         self,
+         in_channels,
+         model_channels,
+         out_channels,
+         z_channels,
+         num_res_blocks,
+     ):
+         super().__init__()
+
+         self.in_channels = in_channels
+         self.model_channels = model_channels
+         self.out_channels = out_channels
+         self.num_res_blocks = num_res_blocks
+
+         self.time_embed = TimestepEmbedder(model_channels)
+         self.cond_embed = nn.Linear(z_channels, model_channels)
+
+         self.input_proj = nn.Linear(in_channels, model_channels)
+
+         res_blocks = []
+         for i in range(num_res_blocks):
+             res_blocks.append(ResBlock(
+                 model_channels,
+             ))
+
+         self.res_blocks = nn.ModuleList(res_blocks)
+         self.final_layer = FinalLayer(model_channels, out_channels)
+
+         self.initialize_weights()
+
+     def initialize_weights(self):
+         def _basic_init(module):
+             if isinstance(module, nn.Linear):
+                 torch.nn.init.xavier_uniform_(module.weight)
+                 if module.bias is not None:
+                     nn.init.constant_(module.bias, 0)
+         self.apply(_basic_init)
+
+         # Initialize timestep embedding MLP
+         nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
+         nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
+
+         # Zero-out adaLN modulation layers
+         for block in self.res_blocks:
+             nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+             nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+         # Zero-out output layers
+         nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+         nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+         nn.init.constant_(self.final_layer.linear.weight, 0)
+         nn.init.constant_(self.final_layer.linear.bias, 0)
+
+     def forward(self, x, t, c):
+         """
+         Apply the model to an input batch.
+         :param x: an [N x C] Tensor of inputs.
+         :param t: a 1-D batch of timesteps.
+         :param c: conditioning from AR transformer.
+         :return: an [N x C] Tensor of outputs.
+         """
+         x = self.input_proj(x)
+         t = self.time_embed(t)
+         c = self.cond_embed(c)
+         y = t + c
+
+         for block in self.res_blocks:
+             x = block(x, y)
+
+         return self.final_layer(x, y)
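
`FlowLoss` is a flow-matching head: `forward` interpolates between Gaussian noise (t = 0) and the target window (t = 1), asks `SimpleMLPAdaLN` for a prediction conditioned on `z`, and averages a squared error whose weights decay along the horizon; `sample` integrates the learned field with `num_sampling_steps` Euler steps of size `1/num_sampling_steps`. A shape-check sketch with dimensions taken from `config.json` (note that `sample` allocates its noise with `.cuda()`, so it needs a GPU as written):

```python
import torch
from flow_loss import FlowLoss  # assumes flow_loss.py is importable

target_len, z_dim = 720, 768                   # output_token_lens[-1] and hidden_size
loss_fn = FlowLoss(target_channels=target_len, z_channels=z_dim,
                   depth=3, width=z_dim, num_sampling_steps=50)

targets = torch.randn(8, target_len)           # 8 flattened future windows
cond = torch.randn(8, z_dim)                   # decoder hidden states used as conditioning
loss = loss_fn(targets, cond)                  # scalar flow-matching loss
print(loss.shape)                              # torch.Size([])

# On a CUDA machine, loss_fn.cuda().sample(cond.cuda(), num_samples=4) returns a [8, 4, 720] tensor.
```
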
generation_config.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "_from_model_config": true,
+   "transformers_version": "4.40.1"
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:414435b508391f92afadd2aaeec418c806776aeccbce12e638d73a139ca5ca78
+ size 513341448
modeling_sundial.py ADDED
@@ -0,0 +1,574 @@
+ from typing import Optional, Tuple, List, Union
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from transformers import PreTrainedModel, Cache, DynamicCache
+ from transformers.activations import ACT2FN
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+ from transformers.modeling_outputs import MoeModelOutputWithPast, MoeCausalLMOutputWithPast
+ from .configuration_sundial import SundialConfig
+ from .ts_generation_mixin import TSGenerationMixin
+ from .flow_loss import FlowLoss
+
+
+ def rotate_half(x):
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2:]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+     cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+     sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+
+ class SundialPatchEmbedding(nn.Module):
+     def __init__(self, config: SundialConfig):
+         super().__init__()
+         self.dropout = nn.Dropout(config.dropout_rate)
+         self.hidden_layer = nn.Linear(
+             config.input_token_len * 2, config.intermediate_size)
+         self.act = ACT2FN[config.hidden_act]
+         self.output_layer = nn.Linear(
+             config.intermediate_size, config.hidden_size)
+         self.residual_layer = nn.Linear(
+             config.input_token_len * 2, config.hidden_size)
+         self.input_token_len = config.input_token_len
+
+     def forward(self, x):
+         mask = torch.ones_like(x, dtype=torch.float32)
+         input_length = x.shape[-1]
+         padding_length = (self.input_token_len - (input_length %
+                           self.input_token_len)) % self.input_token_len
+         x = F.pad(x, (padding_length, 0))
+         mask = F.pad(mask, (padding_length, 0))
+         x = x.unfold(dimension=-1, size=self.input_token_len,
+                      step=self.input_token_len)
+         mask = mask.unfold(
+             dimension=-1, size=self.input_token_len, step=self.input_token_len)
+
+         x = torch.cat([x, mask], dim=-1)
+         hid = self.act(self.hidden_layer(x))
+         out = self.dropout(self.output_layer(hid))
+         res = self.residual_layer(x)
+         out = out + res
+         return out
+
+
+ class SundialRotaryEmbedding(torch.nn.Module):
+     def __init__(self, dim, max_position_embeddings=10000, base=10000, device=None):
+         super().__init__()
+         self.dim = dim
+         self.max_position_embeddings = max_position_embeddings
+         self.base = base
+         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim,
+                           2, dtype=torch.int64).float().to(device) / self.dim))
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+         # Build here to make `torch.jit.trace` work.
+         self._set_cos_sin_cache(
+             seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+         )
+
+     def _set_cos_sin_cache(self, seq_len, device, dtype):
+         self.max_seq_len_cached = seq_len
+         t = torch.arange(self.max_seq_len_cached, device=device,
+                          dtype=torch.int64).type_as(self.inv_freq)
+
+         freqs = torch.outer(t, self.inv_freq)
+         # Different from paper, but it uses a different permutation in order to obtain the same calculation
+         emb = torch.cat((freqs, freqs), dim=-1)
+         self.register_buffer(
+             "cos_cached", emb.cos().to(dtype), persistent=False)
+         self.register_buffer(
+             "sin_cached", emb.sin().to(dtype), persistent=False)
+
+     def forward(self, x, seq_len=None):
+         # x: [bs, num_attention_heads, seq_len, head_size]
+         if seq_len > self.max_seq_len_cached:
+             self._set_cos_sin_cache(
+                 seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+         return (
+             self.cos_cached[:seq_len].to(dtype=x.dtype),
+             self.sin_cached[:seq_len].to(dtype=x.dtype),
+         )
+
+
+ class SundialAttention(nn.Module):
+     def __init__(self, config: SundialConfig, layer_idx: Optional[int] = None):
+         super().__init__()
+         self.layer_idx = layer_idx
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = self.hidden_size // self.num_heads
+         self.attention_dropout = config.dropout_rate
+         self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+         self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+         self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+         self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+         self.rotary_emb = SundialRotaryEmbedding(
+             self.head_dim, max_position_embeddings=config.max_position_embeddings)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Cache] = None,
+         output_attentions: bool = False,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         bsz, q_len, _ = hidden_states.size()
+
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         query_states = query_states.view(
+             bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         key_states = key_states.view(
+             bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         value_states = value_states.view(
+             bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+         kv_seq_len = key_states.shape[-2]
+         if past_key_value is not None:
+             kv_seq_len += past_key_value.get_usable_length(
+                 kv_seq_len, self.layer_idx)
+         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+         query_states, key_states = apply_rotary_pos_emb(
+             query_states, key_states, cos, sin, position_ids)
+
+         if past_key_value is not None:
+             key_states, value_states = past_key_value.update(
+                 key_states, value_states, self.layer_idx)
+
+         attn_output = F.scaled_dot_product_attention(
+             query_states, key_states, value_states, attention_mask, dropout_p=(self.attention_dropout if self.training else 0.0))
+
+         attn_output = attn_output.transpose(1, 2).contiguous()
+         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+         attn_output = self.o_proj(attn_output)
+
+         # scaled_dot_product_attention does not expose attention weights
+         attn_weights = None
+
+         return attn_output, attn_weights, past_key_value
+
+
+ class SundialMLP(nn.Module):
+     def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str):
+         super().__init__()
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.gate_proj = nn.Linear(
+             self.hidden_size, self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(
+             self.hidden_size, self.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(
+             self.intermediate_size, self.hidden_size, bias=False)
+         self.act_fn = ACT2FN[hidden_act]
+
+     def forward(self, hidden_state):
+         return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
+
+ class SundialDecoderLayer(nn.Module):
+     def __init__(self, config: SundialConfig, layer_idx: int):
+         super().__init__()
+         self.self_attn = SundialAttention(config, layer_idx)
+
+         self.ffn_layer = SundialMLP(
+             hidden_size=config.hidden_size,
+             intermediate_size=config.intermediate_size,
+             hidden_act=config.hidden_act,
+         )
+         self.norm1 = torch.nn.LayerNorm(config.hidden_size)
+         self.norm2 = torch.nn.LayerNorm(config.hidden_size)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         output_attentions: Optional[bool] = False,
+         use_cache: Optional[bool] = False,
+         **kwargs,
+     ) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor], Optional[torch.FloatTensor]]:
+         residual = hidden_states
+
+         hidden_states = self.norm1(hidden_states)
+
+         # Self Attention
+         hidden_states, self_attn_weights, present_key_value = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_value=past_key_value,
+             output_attentions=output_attentions,
+             use_cache=use_cache,
+         )
+         hidden_states = residual + hidden_states
+
+         # Fully Connected
+         residual = hidden_states
+         hidden_states = self.norm2(hidden_states)
+         hidden_states = self.ffn_layer(hidden_states)
+         hidden_states = residual + hidden_states
+
+         if not output_attentions:
+             self_attn_weights = None
+
+         if not use_cache:
+             present_key_value = None
+         return hidden_states, self_attn_weights, present_key_value
+
+
+ class SundialPreTrainedModel(PreTrainedModel):
+     config_class = SundialConfig
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["SundialDecoderLayer"]
+     _skip_keys_device_placement = "past_key_values"
+     _supports_flash_attn_2 = True
+     _supports_sdpa = False
+     _supports_cache_class = True
+
+     def _init_weights(self, module):
+         std = self.config.initializer_range
+         if isinstance(module, torch.nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, torch.nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+
+
+ class SundialModel(SundialPreTrainedModel):
+     def __init__(self, config: SundialConfig):
+         super().__init__(config)
+         self.embed_layer = SundialPatchEmbedding(config)
+         self.layers = nn.ModuleList(
+             [SundialDecoderLayer(config, layer_idx)
+              for layer_idx in range(config.num_hidden_layers)]
+         )
+         self.norm = torch.nn.LayerNorm(config.hidden_size)
+         self.gradient_checkpointing = False
+
+     def forward(
+         self,
+         input_ids: torch.FloatTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, MoeModelOutputWithPast]:
+         # input_ids is the input of time series, its shape is [batch_size, seq_len]
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # retrieve input_ids and inputs_embeds
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError(
+                 "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+         elif input_ids is not None:
+             batch_size, seq_length = input_ids.shape
+         elif inputs_embeds is not None:
+             batch_size, seq_length, _ = inputs_embeds.shape
+         else:
+             raise ValueError(
+                 "You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_layer(input_ids)
+             seq_length = inputs_embeds.shape[1]
+
+         if self.gradient_checkpointing and self.training:
+             if use_cache:
+                 use_cache = False
+
+         past_key_values_length = 0
+
+         if use_cache:
+             use_legacy_cache = not isinstance(past_key_values, Cache)
+             if use_legacy_cache:
+                 past_key_values = DynamicCache.from_legacy_cache(
+                     past_key_values)
+             past_key_values_length = past_key_values.get_usable_length(
+                 seq_length)
+
+         if position_ids is None:
+             device = input_ids.device if input_ids is not None else inputs_embeds.device
+             position_ids = torch.arange(
+                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+             )
+             # position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+             position_ids = position_ids.view(-1, seq_length)
+         else:
+             position_ids = position_ids.view(-1, seq_length).long()
+
+         # 4d mask is passed through the layers
+         attention_mask = _prepare_4d_causal_attention_mask(
+             attention_mask,
+             (batch_size, seq_length),
+             inputs_embeds,
+             past_key_values_length,
+             sliding_window=None,
+         )
+
+         hidden_states = inputs_embeds
+
+         # decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         next_decoder_cache = None
+
+         for decoder_layer in self.layers:
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     decoder_layer.__call__,
+                     hidden_states,
+                     attention_mask,
+                     position_ids,
+                     past_key_values,
+                     output_attentions,
+                     use_cache,
+                 )
+             else:
+                 layer_outputs = decoder_layer(
+                     hidden_states,
+                     attention_mask=attention_mask,
+                     position_ids=position_ids,
+                     past_key_value=past_key_values,
+                     output_attentions=output_attentions,
+                     use_cache=use_cache,
+                 )
+
+             hidden_states = layer_outputs[0]
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+
+             if use_cache:
+                 next_decoder_cache = layer_outputs[2]
+
+         hidden_states = self.norm(hidden_states)
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         next_cache = None
+         if use_cache:
+             next_cache = next_decoder_cache.to_legacy_cache(
+             ) if use_legacy_cache else next_decoder_cache
+
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+                 if v is not None
+             )
+         return MoeModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=next_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+         )
+
+
+ class SundialForPrediction(SundialPreTrainedModel, TSGenerationMixin):
+     def __init__(self, config: SundialConfig):
+         super().__init__(config)
+         self.config = config
+         self.model = SundialModel(self.config)
+         self.flow_loss = FlowLoss(self.config.output_token_lens[-1], self.config.hidden_size,
+                                   self.config.flow_loss_depth, self.config.hidden_size, self.config.num_sampling_steps)
+         self.post_init()
+
+     def set_decoder(self, decoder):
+         self.model = decoder
+
+     def get_decoder(self):
+         return self.model
+
+     def forward(
+         self,
+         input_ids: torch.FloatTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.FloatTensor] = None,
+         loss_masks: Optional[torch.FloatTensor] = None,
+         mask_y: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         max_output_length: Optional[int] = None,
+         revin: Optional[bool] = False,
+         num_samples: Optional[int] = 1,
+     ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if revin:
+             means = input_ids.mean(1, keepdim=True).detach()
+             stdev = input_ids.std(dim=1, keepdim=True, unbiased=False).detach()
+             stdev = torch.where(stdev > 1e-2, stdev, torch.tensor(1.0, device=input_ids.device))
+             input_ids = (input_ids - means) / stdev
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
+         predictions = None
+
+         loss = None
+         if labels is not None:
+             if revin:
+                 labels = (labels - means) / stdev
+             output_token_len = self.config.output_token_lens[-1]
+             seq_len = hidden_states.shape[1] * self.config.input_token_len
+             labels = labels[:, :seq_len -
+                             self.config.input_token_len + output_token_len]
+             shift_labels = labels.unfold(
+                 dimension=-1, size=output_token_len, step=self.config.input_token_len)
+
+             bsz, L, _ = shift_labels.shape
+             shift_labels = shift_labels.reshape(
+                 bsz * L, -1).repeat(self.config.diffusion_batch_mul, 1)
+             hidden_states = hidden_states.reshape(
+                 bsz * L, -1).repeat(self.config.diffusion_batch_mul, 1)
+             loss_masks = loss_masks.reshape(
+                 bsz * L).repeat(self.config.diffusion_batch_mul)
+             mask_y = mask_y.repeat(L * self.config.diffusion_batch_mul, 1)
+
+             loss = self.flow_loss(shift_labels, hidden_states, loss_masks, mask_y)
+         else:
+             if max_output_length is None:
+                 output_token_len = self.config.output_token_lens[0]
+                 max_output_length = output_token_len
+             else:
+                 output_token_len = self.config.output_token_lens[0]
+                 for h in self.config.output_token_lens[1:]:
+                     if h > max_output_length:
+                         break
+                     else:
+                         output_token_len = h
+
+             bsz = hidden_states.shape[0]
+             hidden_states = hidden_states[:, -1, :]
+             predictions = self.flow_loss.sample(hidden_states, num_samples)
+             if output_token_len > max_output_length:
+                 predictions = predictions[:, :, :max_output_length]
+             if revin:
+                 predictions = predictions * stdev + means
+         if not return_dict:
+             output = (predictions,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return MoeCausalLMOutputWithPast(
+             loss=loss,
+             logits=predictions,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+     def prepare_inputs_for_generation(
+         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, revin=False, num_samples=1, **kwargs
+     ):
+         # Omit tokens covered by past_key_values
+         if past_key_values is not None:
+             if isinstance(past_key_values, Cache):
+                 cache_length = past_key_values.get_seq_length()
+                 if isinstance(past_key_values, DynamicCache):
+                     past_length = past_key_values.seen_tokens
+                 else:
+                     past_length = cache_length
+
+                 max_cache_length = past_key_values.get_max_length()
+             else:
+                 cache_length = past_length = past_key_values[0][0].shape[2]
+                 max_cache_length = None
+
+             # Keep only the unprocessed tokens:
+             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+             #     some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+             #     input)
+             if attention_mask is not None and attention_mask.shape[1] > (input_ids.shape[1] // self.config.input_token_len):
+                 input_ids = input_ids[:, -
+                                       (attention_mask.shape[1] - past_length) * self.config.input_token_len:]
+             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+             #     input_ids based on the past_length.
+             elif past_length < (input_ids.shape[1] // self.config.input_token_len):
+                 input_ids = input_ids[:, past_length *
+                                       self.config.input_token_len:]
+             # 3 - Otherwise (past_length >= (input_ids.shape[1] // self.config.input_token_len)), let's assume input_ids only has unprocessed tokens.
+
+             # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+             if (
+                 max_cache_length is not None
+                 and attention_mask is not None
+                 and cache_length + (input_ids.shape[1] // self.config.input_token_len) > max_cache_length
+             ):
+                 attention_mask = attention_mask[:, -max_cache_length:]
+
+         position_ids = kwargs.get("position_ids", None)
+         if attention_mask is not None and position_ids is None:
+             # create position_ids on the fly for batch generation
+             position_ids = attention_mask.long().cumsum(-1) - 1
+             position_ids.masked_fill_(attention_mask == 0, 1)
+             if past_key_values:
+                 position_ids = position_ids[:, -
+                                             (input_ids.shape[1] // self.config.input_token_len):]
+
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and past_key_values is None:
+             model_inputs = {"inputs_embeds": inputs_embeds}
+         else:
+             model_inputs = {"input_ids": input_ids}
+
+         model_inputs.update(
+             {
+                 "position_ids": position_ids,
+                 "past_key_values": past_key_values,
+                 "use_cache": kwargs.get("use_cache"),
+                 "attention_mask": attention_mask,
+                 "revin": revin,
+                 "num_samples": num_samples,
+             }
+         )
+         return model_inputs
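
End to end, `SundialForPrediction` takes a raw `[batch_size, context_length]` series, patches it with `SundialPatchEmbedding`, runs the decoder stack, and (when no `labels` are given) draws future windows from the flow head. A minimal forecasting sketch, assuming the checkpoint has already been loaded as `model` (for example via the `AutoModelForCausalLM` call sketched after `config.json`):

```python
import torch

context = torch.randn(2, 2880)   # [batch_size, lookback_length] raw values (illustrative data)

model.eval()
with torch.no_grad():
    # revin=True normalizes each series inside generate() and de-normalizes the output;
    # num_samples draws multiple trajectories from the flow head.
    forecast = model.generate(context, max_new_tokens=720, num_samples=20, revin=True)

print(forecast.shape)            # [2, 20, 720]: batch, samples, prediction horizon
```
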
ts_generation_mixin.py ADDED
@@ -0,0 +1,281 @@
+ import warnings
+ from typing import Any, Dict, List, Optional, Union, Callable
+ import torch
+ from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList
+ from transformers.generation import validate_stopping_criteria, EosTokenCriteria
+ from transformers.generation.utils import GenerateNonBeamOutput, GenerateEncoderDecoderOutput, GenerateDecoderOnlyOutput, GenerationConfig, GenerateOutput
+ from transformers.utils import ModelOutput
+
+
+ class TSGenerationMixin(GenerationMixin):
+     @torch.no_grad()
+     def generate(
+         self,
+         inputs: Optional[torch.Tensor] = None,
+         generation_config: Optional[GenerationConfig] = None,
+         logits_processor: Optional[LogitsProcessorList] = None,
+         stopping_criteria: Optional[StoppingCriteriaList] = None,
+         prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+         synced_gpus: Optional[bool] = None,
+         assistant_model: Optional["PreTrainedModel"] = None,
+         streamer: Optional["BaseStreamer"] = None,
+         negative_prompt_ids: Optional[torch.Tensor] = None,
+         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+         revin: Optional[bool] = True,
+         num_samples: Optional[int] = 1,
+         **kwargs,
+     ) -> Union[GenerateOutput, torch.LongTensor]:
+         if len(inputs.shape) != 2:
+             raise ValueError('Input shape must be: [batch_size, seq_len]')
+         if revin:
+             means = inputs.mean(dim=-1, keepdim=True)
+             stdev = inputs.std(dim=-1, keepdim=True, unbiased=False) + 1e-5
+             inputs = (inputs - means) / stdev
+         outputs = super().generate(inputs=inputs, generation_config=generation_config, logits_processor=logits_processor, stopping_criteria=stopping_criteria, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, synced_gpus=synced_gpus, assistant_model=assistant_model, streamer=streamer, negative_prompt_ids=negative_prompt_ids, negative_prompt_attention_mask=negative_prompt_attention_mask, num_samples=num_samples, **kwargs)
+         if revin:
+             stdev = stdev.unsqueeze(1).repeat(1, num_samples, 1)
+             means = means.unsqueeze(1).repeat(1, num_samples, 1)
+             outputs = (outputs * stdev) + means
+         return outputs
+
+     def _greedy_search(
+         self,
+         input_ids: torch.Tensor,
+         logits_processor: Optional[LogitsProcessorList] = None,
+         stopping_criteria: Optional[StoppingCriteriaList] = None,
+         max_length: Optional[int] = None,
+         pad_token_id: Optional[int] = None,
+         eos_token_id: Optional[Union[int, List[int]]] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         output_scores: Optional[bool] = None,
+         output_logits: Optional[bool] = None,
+         return_dict_in_generate: Optional[bool] = None,
+         synced_gpus: bool = False,
+         streamer: Optional["BaseStreamer"] = None,
+         **model_kwargs,
+     ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
+         input_ids = input_ids.to(self.device)
+         batch_size, cur_len = input_ids.shape
+         # init values
+         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+         if max_length is not None:
+             warnings.warn(
+                 "`max_length` is deprecated in this function, use"
+                 " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
+                 UserWarning,
+             )
+             stopping_criteria = validate_stopping_criteria(
+                 stopping_criteria, max_length)
+         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+         if eos_token_id is not None:
+             stopping_criteria.append(
+                 EosTokenCriteria(eos_token_id=eos_token_id))
+         else:
+             # remove when the method is totally private
+             # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+             eos_token_id = [
+                 criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+             ]
+             eos_token_id = eos_token_id[0] if eos_token_id else None
+             if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                 eos_token_id = self.generation_config.eos_token_id
+                 stopping_criteria.append(
+                     EosTokenCriteria(eos_token_id=eos_token_id))
+
+         if isinstance(eos_token_id, int):
+             eos_token_id = [eos_token_id]
+         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
+         output_attentions = (
+             output_attentions if output_attentions is not None else self.generation_config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
+         )
+         return_dict_in_generate = (
+             return_dict_in_generate
+             if return_dict_in_generate is not None
+             else self.generation_config.return_dict_in_generate
+         )
+
+         # init attention / hidden states / scores tuples
+         raw_logits = () if (return_dict_in_generate and output_logits) else None
+         scores = () if (return_dict_in_generate and output_scores) else None
+         decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+         cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+         decoder_hidden_states = () if (
+             return_dict_in_generate and output_hidden_states) else None
+
+         # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+         if return_dict_in_generate and self.config.is_encoder_decoder:
+             encoder_attentions = model_kwargs["encoder_outputs"].get(
+                 "attentions") if output_attentions else None
+             encoder_hidden_states = (
+                 model_kwargs["encoder_outputs"].get(
+                     "hidden_states") if output_hidden_states else None
+             )
+
+         # keep track of which sequences are already finished
+         if "inputs_embeds" in model_kwargs:
+             cur_len = model_kwargs["inputs_embeds"].shape[1]
+         this_peer_finished = False
+         unfinished_sequences = torch.ones(
+             batch_size, dtype=torch.long, device=input_ids.device)
+         model_kwargs["cache_position"] = torch.arange(
+             cur_len, device=input_ids.device)
+         true_seq_len = (cur_len + self.config.input_token_len - 1) // self.config.input_token_len
+         model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, -true_seq_len:]
+         max_length = stopping_criteria.max_length
+         generate_results = None
+         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
+             # prepare model inputs
+             model_inputs = self.prepare_inputs_for_generation(
+                 input_ids, **model_kwargs)
+
+             input_length = input_ids.shape[1]
+
+             # forward pass to get next token
+             outputs = self(
+                 **model_inputs,
+                 return_dict=True,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 max_output_length=max_length - input_length,
+             )
+
+             if synced_gpus and this_peer_finished:
+                 continue  # don't waste resources running the code we don't need
+             next_token_logits = outputs.logits
+
+             # pre-process distribution
+             next_tokens_scores = logits_processor(input_ids, next_token_logits)
+
+             # Store scores, attentions and hidden_states when required
+             if return_dict_in_generate:
+                 if output_scores:
+                     scores += (next_tokens_scores,)
+                 if output_logits:
+                     raw_logits += (next_token_logits,)
+                 if output_attentions:
+                     decoder_attentions += (
+                         (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (
+                             outputs.attentions,)
+                     )
+                     if self.config.is_encoder_decoder:
+                         cross_attentions += (outputs.cross_attentions,)
+
+                 if output_hidden_states:
+                     decoder_hidden_states += (
+                         (outputs.decoder_hidden_states,)
+                         if self.config.is_encoder_decoder
+                         else (outputs.hidden_states,)
+                     )
+
+             # argmax
+             # next_tokens = torch.argmax(next_tokens_scores, dim=-1)
+             next_tokens = next_tokens_scores
+
+             # finished sentences should have their next token be a padding token
+             if eos_token_id is not None:
+                 if pad_token_id is None:
+                     raise ValueError(
+                         "If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+                 next_tokens = next_tokens * unfinished_sequences + \
+                     pad_token_id * (1 - unfinished_sequences)
+
+             # update generated ids, model inputs, and length for next step
+             horizon_length = next_tokens.shape[-1] // self.config.input_token_len
+
+             past_key_values = model_kwargs.get("past_key_values")
+             if past_key_values is None:
+                 generate_results = next_tokens
+             else:
+                 generate_results = torch.cat([generate_results, next_tokens], dim=-1)
+             input_ids = torch.cat([input_ids, next_tokens.median(dim=1)[0]], dim=-1)
+
+             if streamer is not None:
+                 streamer.put(next_tokens.cpu())
+             model_kwargs = self._update_model_kwargs_for_generation(
+                 outputs,
+                 model_kwargs,
+                 horizon_length=horizon_length,
+                 is_encoder_decoder=self.config.is_encoder_decoder,
+             )
+             unfinished_sequences = unfinished_sequences & ~stopping_criteria(
+                 input_ids, scores)
+             this_peer_finished = unfinished_sequences.max() == 0
+
+             if input_ids.shape[-1] > max_length:
+                 input_ids = input_ids[:, :max_length]
+
+         if streamer is not None:
+             streamer.end()
+
+         if return_dict_in_generate:
+             if self.config.is_encoder_decoder:
+                 return GenerateEncoderDecoderOutput(
+                     sequences=input_ids,
+                     scores=scores,
+                     logits=raw_logits,
+                     encoder_attentions=encoder_attentions,
+                     encoder_hidden_states=encoder_hidden_states,
+                     decoder_attentions=decoder_attentions,
+                     cross_attentions=cross_attentions,
+                     decoder_hidden_states=decoder_hidden_states,
+                     past_key_values=model_kwargs.get("past_key_values"),
+                 )
+             else:
+                 return GenerateDecoderOnlyOutput(
+                     sequences=input_ids,
+                     scores=scores,
+                     logits=raw_logits,
+                     attentions=decoder_attentions,
+                     hidden_states=decoder_hidden_states,
+                     past_key_values=model_kwargs.get("past_key_values"),
+                 )
+         else:
+             return generate_results[:, :, :(max_length - cur_len)]
+
+     def _update_model_kwargs_for_generation(
+         self,
+         outputs: ModelOutput,
+         model_kwargs: Dict[str, Any],
+         horizon_length: int = 1,
+         is_encoder_decoder: bool = False,
+         standardize_cache_format: bool = False,
+     ) -> Dict[str, Any]:
+         # update past_key_values
+         model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+             outputs, standardize_cache_format=standardize_cache_format
+         )
+         if getattr(outputs, "state", None) is not None:
+             model_kwargs["state"] = outputs.state
+
+         # update token_type_ids with last value
+         if "token_type_ids" in model_kwargs:
+             token_type_ids = model_kwargs["token_type_ids"]
+             model_kwargs["token_type_ids"] = torch.cat(
+                 [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
+
+         if not is_encoder_decoder:
+             # update attention mask
+             if "attention_mask" in model_kwargs:
+                 attention_mask = model_kwargs["attention_mask"]
+                 model_kwargs["attention_mask"] = torch.cat(
+                     [attention_mask, attention_mask.new_ones((attention_mask.shape[0], horizon_length))], dim=-1
+                 )
+         else:
+             # update decoder attention mask
+             if "decoder_attention_mask" in model_kwargs:
+                 decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+                 model_kwargs["decoder_attention_mask"] = torch.cat(
+                     [decoder_attention_mask, decoder_attention_mask.new_ones(
+                         (decoder_attention_mask.shape[0], horizon_length))],
+                     dim=-1,
+                 )
+
+         if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
+             model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + horizon_length
+
+         return model_kwargs
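
Because `_greedy_search` returns the stacked sample trajectories (`[batch, num_samples, horizon]`) rather than argmax token ids, point and interval forecasts are usually obtained by reducing over the sample dimension after `generate`. A small post-processing sketch (the `forecast` tensor here stands in for a real `generate()` output):

```python
import torch

forecast = torch.randn(2, 20, 720)   # stand-in for model.generate(...) output: batch, samples, horizon

point = forecast.median(dim=1).values                                 # [2, 720] median point forecast
lo, hi = torch.quantile(forecast, torch.tensor([0.1, 0.9]), dim=1)    # 10%/90% bounds, each [2, 720]
print(point.shape, lo.shape, hi.shape)
```
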