Update ultravox_model.py
ultravox_model.py  (+166, -59)  CHANGED
@@ -2,6 +2,7 @@ import logging
 import re
 from typing import Any, Dict, Generator, Optional, Set, Tuple, TypeVar, Union

+import accelerate
 import peft
 import torch
 import torch.nn as nn
@@ -19,6 +20,15 @@ from .ultravox_config import LossConfig
 from .ultravox_config import LossFunction
 from .ultravox_config import UltravoxConfig

+FROM_PRETRAINED_KWARGS = {}
+SHARED_PRETRAINED_KWARGS = [
+    "tp_plan",
+    "device_map",
+    "torch_dtype",
+    "attn_implementation",
+    "use_flash_attention_2",
+]
+

 class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
     """
@@ -69,44 +79,29 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
         self.loss_config = LossConfig()
         self.post_init()

+    def _init_weights(self, module):
+        if module is self:
+            if self.config.text_model_id is not None:
+                self.language_model = self._create_language_model(self.config)
+            if self.config.audio_model_id is not None:
+                self.audio_tower = self._create_audio_tower(self.config)
+        elif module in self.language_model.modules():
+            pass
+        elif module in self.audio_tower.modules():
+            pass
+        else:
+            super()._init_weights(module)
+
     @classmethod
-    def from_pretrained(cls,
-
-
+    def from_pretrained(cls, *args, **kwargs):
+        global FROM_PRETRAINED_KWARGS
+        FROM_PRETRAINED_KWARGS = {
+            k: v for k, v in kwargs.items() if k in SHARED_PRETRAINED_KWARGS
+        }
+        model = super().from_pretrained(*args, **kwargs)
+        FROM_PRETRAINED_KWARGS = {}
         return model

-    def _load_child_model_weights(self, *args, **kwargs) -> "UltravoxModel":
-        if "torch_dtype" in kwargs:
-            self.config.torch_dtype = kwargs.pop("torch_dtype")
-
-        kwargs.pop("config", None)
-
-        if (
-            self.config.text_model_id is not None
-            and self.language_model.device.type == "meta"
-        ):
-            # Load the language model weights
-            self.language_model = transformers.AutoModelForCausalLM.from_pretrained(
-                self.config.text_model_id,
-                torch_dtype=self.config.torch_dtype,
-                *args,
-                **kwargs,
-            )
-
-        if (
-            self.config.audio_model_id is not None
-            and self.audio_tower.device.type == "meta"
-        ):
-            # Load the audio tower weights
-            self.audio_tower = transformers.AutoModel.from_pretrained(
-                self.config.audio_model_id,
-                torch_dtype=self.config.torch_dtype,
-                *args,
-                **kwargs,
-            )
-
-        return self
-
     def get_input_embeddings(self):
         return self.language_model.get_input_embeddings()

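For context, a minimal usage sketch of the new loading path (the checkpoint id and import path below are placeholders, not taken from this diff): kwargs named in SHARED_PRETRAINED_KWARGS are stashed in the module-level FROM_PRETRAINED_KWARGS while super().from_pretrained() runs, so _create_language_model can forward them to the inner text model, and the global is reset afterwards.

import torch
from ultravox.model.ultravox_model import UltravoxModel  # import path is an assumption

# "some-org/some-ultravox-checkpoint" is a placeholder repo id.
model = UltravoxModel.from_pretrained(
    "some-org/some-ultravox-checkpoint",
    torch_dtype=torch.bfloat16,  # captured into FROM_PRETRAINED_KWARGS
    device_map="auto",           # also listed in SHARED_PRETRAINED_KWARGS
)
# After loading, FROM_PRETRAINED_KWARGS is reset to {} so later calls start clean.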
@@ -153,21 +148,29 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
         self.vocab_size = model_embeds.num_embeddings
         return model_embeds

-    def _get_prediction_mask(
-
+    def _get_prediction_mask(
+        self, labels: Optional[torch.Tensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Get boolean masks for positions where we want to compute KL divergence.

         For each label position, we want the position before it since that's where
         the model makes the prediction for that label.

+        Additionally, we want to identify the position right before the EOT token
+        (the last token with label != -100).
+
         Args:
             labels: Tensor of shape (B, T) where B is batch size and T is sequence length,
                 with -100 for masked positions and token ids for label positions

         Returns:
-
+            Tuple containing:
+            - pred_mask: Boolean tensor of shape (B, T) that's True for positions where we want to compute KL divergence
+            - eot_mask: Boolean tensor of shape (B, T) that's True only for the last prediction position in each sequence
         """
         if labels is None:
             raise ValueError("labels must be provided")
+
         # Shift the label mask right by 1 along the sequence dimension
         # This gives us positions where we make predictions for the next token
         label_mask = labels != -100
@@ -175,7 +178,19 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
         pred_mask[:, :-1] = label_mask[
             :, 1:
         ]  # shift right by 1 along sequence dimension
-
+
+        # Create EOT mask - identify only the last prediction position in each sequence
+        eot_mask = torch.zeros_like(pred_mask)
+        batch_size = labels.shape[0]
+
+        for i in range(batch_size):
+            # Find positions where we make predictions
+            pred_positions = torch.where(pred_mask[i])[0]
+            if len(pred_positions) > 0:
+                # Only mark the last prediction position
+                eot_mask[i, pred_positions[-1]] = True
+
+        return pred_mask, eot_mask

     def _compute_kl_loss(
         self,
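To illustrate the shift-by-one and the EOT selection, here is a small, hand-checked example (not part of the diff):

import torch

# labels: -100 marks positions that do not contribute to the loss.
labels = torch.tensor([[-100, -100, 7, 8, -100]])

label_mask = labels != -100              # [[F, F, T, T, F]]
pred_mask = torch.zeros_like(label_mask)
pred_mask[:, :-1] = label_mask[:, 1:]    # [[F, T, T, F, F]]: predictions happen one step earlier

# eot_mask keeps only the last prediction position per sequence: [[F, F, T, F, F]]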
@@ -198,21 +213,38 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
             past_key_values=past_key_values,
             **kwargs,
         )
-
+
+        # Get prediction masks for regular tokens and EOT tokens
+        pred_mask, eot_mask = self._get_prediction_mask(labels)
+        alt_pred_mask, alt_eot_mask = self._get_prediction_mask(alt_labels)
+
+        # compute the KL divergence loss between the two models for regular tokens
         kl_loss = F.kl_div(
             F.log_softmax(
-                lm_output.logits[self.
-
+                lm_output.logits[pred_mask] / self.loss_config.kl_temperature,
+                dim=-1,
+            ),
+            F.softmax(
+                alt_lm_output.logits[alt_pred_mask] / self.loss_config.kl_temperature,
+                dim=-1,
+            ),
+            reduction="batchmean",
+        )
+
+        # Compute the KL divergence loss for EOT token positions if any exist
+        eot_loss = F.kl_div(
+            F.log_softmax(
+                lm_output.logits[eot_mask] / self.loss_config.kl_temperature,
                 dim=-1,
             ),
             F.softmax(
-                alt_lm_output.logits[self.
-                / self.loss_config.kl_temperature,
+                alt_lm_output.logits[alt_eot_mask] / self.loss_config.kl_temperature,
                 dim=-1,
             ),
             reduction="batchmean",
         )
-
+
+        return {"loss": kl_loss + self.loss_config.eot_loss_weight * eot_loss}

     def _audio_iter(
         self, audio_batch_size: torch.Tensor
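As a rough, self-contained sketch of the mechanics used above (shapes and names are illustrative, not from the diff): F.kl_div expects log-probabilities as its first argument and probabilities as its second, and boolean-mask indexing flattens the selected positions before the softmax over the vocabulary. The eot_loss_weight used in the return value is assumed here to be a LossConfig field.

import torch
import torch.nn.functional as F

student_logits = torch.randn(2, 6, 32)   # (batch, seq, vocab), illustrative sizes
teacher_logits = torch.randn(2, 6, 32)
pred_mask = torch.zeros(2, 6, dtype=torch.bool)
pred_mask[:, 2:5] = True                 # pretend these are the prediction positions
temperature = 2.0

kl = F.kl_div(
    F.log_softmax(student_logits[pred_mask] / temperature, dim=-1),
    F.softmax(teacher_logits[pred_mask] / temperature, dim=-1),
    reduction="batchmean",
)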
@@ -380,18 +412,27 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
         cls, config: UltravoxConfig
     ) -> "UltravoxProjector":
         projector = UltravoxProjector(config)
-
+        dtype = config.torch_dtype
+        if isinstance(dtype, str):
+            dtype = getattr(torch, dtype)
+        projector.to(dtype)
         return projector

     @classmethod
     def _create_audio_tower(
         cls, config: UltravoxConfig
     ) -> Union[transformers.Wav2Vec2Model, "ModifiedWhisperEncoder"]:
-
-
-
-
-
+        # We probably don't want to pass tp_plan or device_map to the audio tower
+        # But potentially other kwargs can be passed in. TODO
+        kwargs = {"torch_dtype": config.torch_dtype}
+        if (
+            transformers.modeling_utils._init_weights
+            and config.audio_model_id is not None
+        ):
+            if "whisper" in config.audio_model_id.lower():
+                audio_tower = ModifiedWhisperEncoder.from_pretrained(
+                    config.audio_model_id, **kwargs
+                )
                 audio_tower.init_latency_mask(
                     config.audio_latency_block_size, dtype=config.torch_dtype
                 )
@@ -400,7 +441,27 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
                     None,
                     0,
                 ), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
-                audio_tower = transformers.AutoModel.
+                audio_tower = transformers.AutoModel.from_pretrained(
+                    config.audio_model_id, **kwargs
+                )
+        else:
+            with accelerate.init_empty_weights():
+                if "whisper" in config.audio_config._name_or_path.lower():
+                    audio_tower = ModifiedWhisperEncoder(config.audio_config)
+                    audio_tower.init_latency_mask(
+                        config.audio_latency_block_size,
+                        dtype=config.torch_dtype,
+                    )
+                else:
+                    assert config.audio_latency_block_size in (
+                        None,
+                        0,
+                    ), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
+                    # we only ever use from_config if the weights are retrained, hence initializing is not
+                    # required. This makes model creation quite a bit faster since init on CPU is quite slow.
+                    audio_tower = transformers.AutoModel.from_config(
+                        config.audio_config, **kwargs
+                    )

         if isinstance(
             audio_tower,
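The else-branches above lean on accelerate's init_empty_weights to skip CPU weight allocation entirely; a minimal illustration (the Whisper classes are chosen only to mirror the diff, any model class works):

import accelerate
import transformers

config = transformers.WhisperConfig()
with accelerate.init_empty_weights():
    encoder = transformers.WhisperModel(config).encoder

# Parameters live on the "meta" device: no memory was allocated and no init ran.
print(next(encoder.parameters()).device)  # meta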
@@ -418,14 +479,27 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
     def _create_language_model(
         cls, config: UltravoxConfig
     ) -> transformers.LlamaForCausalLM:
-
-
-
-
-
-
-
+        if (
+            transformers.modeling_utils._init_weights
+            and config.text_model_id is not None
+        ):
+            language_model = transformers.AutoModelForCausalLM.from_pretrained(
+                config.text_model_id,
+                **{
+                    "attn_implementation": config.text_config._attn_implementation,
+                    "torch_dtype": config.torch_dtype,
+                    **FROM_PRETRAINED_KWARGS,
+                },
            )
+        else:
+            with accelerate.init_empty_weights():
+                # we only ever use from_config if the weights are retrained, hence initializing is not
+                # required. This makes model creation quite a bit faster since init on CPU is quite slow.
+                language_model = transformers.AutoModelForCausalLM.from_config(
+                    config.text_config,
+                    attn_implementation=config.text_config._attn_implementation,
+                    torch_dtype=config.torch_dtype,
+                )

         language_model = apply_lora(language_model, config.text_model_lora_config)
         return language_model
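Both _create_* helpers branch on transformers.modeling_utils._init_weights, a private module-level flag that transformers turns off while from_pretrained is materializing a model, so freshly created submodules are not initialized twice. A hedged probe of that behaviour, assuming current transformers versions keep this flag and the no_init_weights context manager:

import transformers
from transformers.modeling_utils import no_init_weights

print(transformers.modeling_utils._init_weights)   # True in normal code
with no_init_weights():
    # from_pretrained enters a context like this one internally,
    # so _create_language_model takes the fast from_config path.
    print(transformers.modeling_utils._init_weights)  # False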
@@ -525,6 +599,39 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
         )


+def get_checkpoint_files(
+    model_id: str,
+) -> tuple[list[str], dict | None, list[str]]:
+    resolved_archive_file = transformers.utils.cached_file(
+        model_id,
+        transformers.utils.SAFE_WEIGHTS_NAME,
+        _raise_exceptions_for_missing_entries=False,
+    )
+
+    if resolved_archive_file is not None:
+        # not sharded
+        sharded_metadata = None
+        state_dict = transformers.modeling_utils.load_state_dict(resolved_archive_file)
+        loaded_state_dict_keys = list(state_dict.keys())
+    else:
+        # sharded
+        resolved_archive_file = transformers.utils.cached_file(
+            model_id, transformers.utils.SAFE_WEIGHTS_INDEX_NAME
+        )
+        resolved_archive_file, sharded_metadata = (
+            transformers.modeling_utils.get_checkpoint_shard_files(
+                model_id,
+                resolved_archive_file,
+            )
+        )
+        loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
+
+    if isinstance(resolved_archive_file, str):
+        resolved_archive_file = [resolved_archive_file]
+
+    return resolved_archive_file, sharded_metadata, loaded_state_dict_keys
+
+
 # TODO: refactor common parts to a shared module
 def is_cache_empty(
     past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]],
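A small usage sketch for the new helper (the repo id is a placeholder): it resolves either a single safetensors file or the shards listed in the safetensors index, plus the tensor names expected in the state dict.

# "some-org/some-model" is a placeholder Hugging Face repo id.
files, sharded_metadata, keys = get_checkpoint_files("some-org/some-model")

print(files)             # list of local paths to the .safetensors shard(s)
print(sharded_metadata)  # None for single-file checkpoints, index metadata otherwise
print(len(keys))         # number of tensors expected in the state dict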