farzadab committed
Commit 575a10c · verified · 1 Parent(s): e6347fb

Update ultravox_model.py

Files changed (1)
  1. ultravox_model.py +166 -59
ultravox_model.py CHANGED
@@ -2,6 +2,7 @@ import logging
 import re
 from typing import Any, Dict, Generator, Optional, Set, Tuple, TypeVar, Union
 
+import accelerate
 import peft
 import torch
 import torch.nn as nn
@@ -19,6 +20,15 @@ from .ultravox_config import LossConfig
 from .ultravox_config import LossFunction
 from .ultravox_config import UltravoxConfig
 
+FROM_PRETRAINED_KWARGS = {}
+SHARED_PRETRAINED_KWARGS = [
+    "tp_plan",
+    "device_map",
+    "torch_dtype",
+    "attn_implementation",
+    "use_flash_attention_2",
+]
+
 
 class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
     """
@@ -69,44 +79,29 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
         self.loss_config = LossConfig()
         self.post_init()
 
+    def _init_weights(self, module):
+        if module is self:
+            if self.config.text_model_id is not None:
+                self.language_model = self._create_language_model(self.config)
+            if self.config.audio_model_id is not None:
+                self.audio_tower = self._create_audio_tower(self.config)
+        elif module in self.language_model.modules():
+            pass
+        elif module in self.audio_tower.modules():
+            pass
+        else:
+            super()._init_weights(module)
+
     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
-        model = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
-        model._load_child_model_weights(*args, **kwargs)
+    def from_pretrained(cls, *args, **kwargs):
+        global FROM_PRETRAINED_KWARGS
+        FROM_PRETRAINED_KWARGS = {
+            k: v for k, v in kwargs.items() if k in SHARED_PRETRAINED_KWARGS
+        }
+        model = super().from_pretrained(*args, **kwargs)
+        FROM_PRETRAINED_KWARGS = {}
         return model
 
-    def _load_child_model_weights(self, *args, **kwargs) -> "UltravoxModel":
-        if "torch_dtype" in kwargs:
-            self.config.torch_dtype = kwargs.pop("torch_dtype")
-
-        kwargs.pop("config", None)
-
-        if (
-            self.config.text_model_id is not None
-            and self.language_model.device.type == "meta"
-        ):
-            # Load the language model weights
-            self.language_model = transformers.AutoModelForCausalLM.from_pretrained(
-                self.config.text_model_id,
-                torch_dtype=self.config.torch_dtype,
-                *args,
-                **kwargs,
-            )
-
-        if (
-            self.config.audio_model_id is not None
-            and self.audio_tower.device.type == "meta"
-        ):
-            # Load the audio tower weights
-            self.audio_tower = transformers.AutoModel.from_pretrained(
-                self.config.audio_model_id,
-                torch_dtype=self.config.torch_dtype,
-                *args,
-                **kwargs,
-            )
-
-        return self
-
     def get_input_embeddings(self):
         return self.language_model.get_input_embeddings()
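The reworked from_pretrained above stashes a whitelist of loader kwargs (SHARED_PRETRAINED_KWARGS) in the module-level FROM_PRETRAINED_KWARGS so that _create_language_model, which runs later during weight initialization, can reuse the same settings without threading them through every call, while _init_weights lazily swaps in the real child models when text_model_id / audio_model_id are set. Below is a minimal sketch of the capture-and-clear pattern only, assuming nothing beyond plain Python; load_model and build_submodel are illustrative names, not part of the commit.

SHARED_KWARGS = ("torch_dtype", "device_map")
CAPTURED_KWARGS = {}

def build_submodel(name):
    # Hypothetical inner builder: it sees only what the outer call captured.
    return f"{name} built with {CAPTURED_KWARGS}"

def load_model(name, **kwargs):
    global CAPTURED_KWARGS
    # Keep only the kwargs that child loaders should inherit.
    CAPTURED_KWARGS = {k: v for k, v in kwargs.items() if k in SHARED_KWARGS}
    try:
        return build_submodel(name)
    finally:
        CAPTURED_KWARGS = {}  # always clear, mirroring the reset after super().from_pretrained()

print(load_model("demo", torch_dtype="bfloat16", trust_remote_code=True))
# -> demo built with {'torch_dtype': 'bfloat16'}

One trade-off of the module-level global is that two concurrent from_pretrained calls in the same process would race on it.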
 
@@ -153,21 +148,29 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
         self.vocab_size = model_embeds.num_embeddings
         return model_embeds
 
-    def _get_prediction_mask(self, labels: Optional[torch.Tensor]) -> torch.Tensor:
-        """Get a boolean mask for positions where we want to compute KL divergence.
+    def _get_prediction_mask(
+        self, labels: Optional[torch.Tensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Get boolean masks for positions where we want to compute KL divergence.
 
         For each label position, we want the position before it since that's where
         the model makes the prediction for that label.
 
+        Additionally, we want to identify the position right before the EOT token
+        (the last token with label != -100).
+
         Args:
             labels: Tensor of shape (B, T) where B is batch size and T is sequence length,
                 with -100 for masked positions and token ids for label positions
 
         Returns:
-            Boolean tensor of shape (B, T) that's True for positions where we want to compute KL divergence
+            Tuple containing:
+            - pred_mask: Boolean tensor of shape (B, T) that's True for positions where we want to compute KL divergence
+            - eot_mask: Boolean tensor of shape (B, T) that's True only for the last prediction position in each sequence
         """
         if labels is None:
             raise ValueError("labels must be provided")
+
         # Shift the label mask right by 1 along the sequence dimension
         # This gives us positions where we make predictions for the next token
         label_mask = labels != -100
@@ -175,7 +178,19 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
         pred_mask[:, :-1] = label_mask[
             :, 1:
         ]  # shift right by 1 along sequence dimension
-        return pred_mask
+
+        # Create EOT mask - identify only the last prediction position in each sequence
+        eot_mask = torch.zeros_like(pred_mask)
+        batch_size = labels.shape[0]
+
+        for i in range(batch_size):
+            # Find positions where we make predictions
+            pred_positions = torch.where(pred_mask[i])[0]
+            if len(pred_positions) > 0:
+                # Only mark the last prediction position
+                eot_mask[i, pred_positions[-1]] = True
+
+        return pred_mask, eot_mask
 
     def _compute_kl_loss(
         self,
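A small worked example of the two masks built in this hunk, using only PyTorch: labels use -100 for ignored positions, pred_mask marks the position immediately before each labeled token, and eot_mask keeps only the last such position per row.

import torch

labels = torch.tensor([[-100, -100, 7, 8, -100],
                       [-100, 5, 6, -100, -100]])
label_mask = labels != -100
pred_mask = torch.zeros_like(label_mask)
pred_mask[:, :-1] = label_mask[:, 1:]      # prediction for token t is made at position t - 1
eot_mask = torch.zeros_like(pred_mask)
for i in range(labels.shape[0]):
    pred_positions = torch.where(pred_mask[i])[0]
    if len(pred_positions) > 0:
        eot_mask[i, pred_positions[-1]] = True

print(pred_mask)
# tensor([[False,  True,  True, False, False],
#         [ True,  True, False, False, False]])
print(eot_mask)
# tensor([[False, False,  True, False, False],
#         [False,  True, False, False, False]])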
@@ -198,21 +213,38 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
             past_key_values=past_key_values,
             **kwargs,
         )
-        # compute the KL divergence loss between the two models
+
+        # Get prediction masks for regular tokens and EOT tokens
+        pred_mask, eot_mask = self._get_prediction_mask(labels)
+        alt_pred_mask, alt_eot_mask = self._get_prediction_mask(alt_labels)
+
+        # compute the KL divergence loss between the two models for regular tokens
         kl_loss = F.kl_div(
             F.log_softmax(
-                lm_output.logits[self._get_prediction_mask(labels)]
-                / self.loss_config.kl_temperature,
+                lm_output.logits[pred_mask] / self.loss_config.kl_temperature,
+                dim=-1,
+            ),
+            F.softmax(
+                alt_lm_output.logits[alt_pred_mask] / self.loss_config.kl_temperature,
+                dim=-1,
+            ),
+            reduction="batchmean",
+        )
+
+        # Compute the KL divergence loss for EOT token positions if any exist
+        eot_loss = F.kl_div(
+            F.log_softmax(
+                lm_output.logits[eot_mask] / self.loss_config.kl_temperature,
                 dim=-1,
             ),
             F.softmax(
-                alt_lm_output.logits[self._get_prediction_mask(alt_labels)]
-                / self.loss_config.kl_temperature,
+                alt_lm_output.logits[alt_eot_mask] / self.loss_config.kl_temperature,
                 dim=-1,
             ),
             reduction="batchmean",
         )
-        return {"loss": kl_loss}
+
+        return {"loss": kl_loss + self.loss_config.eot_loss_weight * eot_loss}
 
     def _audio_iter(
         self, audio_batch_size: torch.Tensor
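The result is a temperature-scaled KL divergence between the speech-conditioned model's logits and the text-only ("alt") logits at the prediction positions, plus a separately weighted KL term on the final (EOT) prediction position. A toy, self-contained illustration with random logits; kl_temperature and eot_loss_weight come from LossConfig in the real code, and the values below are examples only.

import torch
import torch.nn.functional as F

temperature, eot_weight = 2.0, 1.0
speech_logits = torch.randn(4, 32)   # logits gathered at the masked prediction positions
text_logits = torch.randn(4, 32)     # logits from the text-only forward pass at the same positions

kl = F.kl_div(
    F.log_softmax(speech_logits / temperature, dim=-1),
    F.softmax(text_logits / temperature, dim=-1),
    reduction="batchmean",
)
eot = F.kl_div(
    F.log_softmax(speech_logits[-1:] / temperature, dim=-1),
    F.softmax(text_logits[-1:] / temperature, dim=-1),
    reduction="batchmean",
)
loss = kl + eot_weight * eot
print(float(loss))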
@@ -380,18 +412,27 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
         cls, config: UltravoxConfig
     ) -> "UltravoxProjector":
         projector = UltravoxProjector(config)
-        projector.to(config.torch_dtype)
+        dtype = config.torch_dtype
+        if isinstance(dtype, str):
+            dtype = getattr(torch, dtype)
+        projector.to(dtype)
         return projector
 
     @classmethod
     def _create_audio_tower(
         cls, config: UltravoxConfig
     ) -> Union[transformers.Wav2Vec2Model, "ModifiedWhisperEncoder"]:
-        with transformers.modeling_utils.no_init_weights():
-            # we only ever use from_config if the weights are retrained, hence initializing is not
-            # required. This makes the model quite creation faster since init on CPU is quite slow.
-            if "whisper" in config.audio_config._name_or_path.lower():
-                audio_tower = ModifiedWhisperEncoder(config.audio_config)
+        # We probably don't want to pass tp_plan or device_map to the audio tower
+        # But potentially other kwargs can be passed in. TODO
+        kwargs = {"torch_dtype": config.torch_dtype}
+        if (
+            transformers.modeling_utils._init_weights
+            and config.audio_model_id is not None
+        ):
+            if "whisper" in config.audio_model_id.lower():
+                audio_tower = ModifiedWhisperEncoder.from_pretrained(
+                    config.audio_model_id, **kwargs
+                )
                 audio_tower.init_latency_mask(
                     config.audio_latency_block_size, dtype=config.torch_dtype
                 )
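The projector change above guards against config.torch_dtype arriving as a string (which is how a JSON-serialized config usually stores it) rather than a torch.dtype object. The normalization is just an attribute lookup on the torch module:

import torch

dtype = "bfloat16"                 # as it might appear in a serialized config
if isinstance(dtype, str):
    dtype = getattr(torch, dtype)  # -> torch.bfloat16
assert dtype == torch.bfloat16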
@@ -400,7 +441,27 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
                     None,
                     0,
                 ), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
-                audio_tower = transformers.AutoModel.from_config(config.audio_config)
+                audio_tower = transformers.AutoModel.from_pretrained(
+                    config.audio_model_id, **kwargs
+                )
+        else:
+            with accelerate.init_empty_weights():
+                if "whisper" in config.audio_config._name_or_path.lower():
+                    audio_tower = ModifiedWhisperEncoder(config.audio_config)
+                    audio_tower.init_latency_mask(
+                        config.audio_latency_block_size,
+                        dtype=config.torch_dtype,
+                    )
+                else:
+                    assert config.audio_latency_block_size in (
+                        None,
+                        0,
+                    ), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
+                    # we only ever use from_config if the weights are retrained, hence initializing is not
+                    # required. This makes model creation quite a bit faster since init on CPU is quite slow.
+                    audio_tower = transformers.AutoModel.from_config(
+                        config.audio_config, **kwargs
+                    )
 
         if isinstance(
             audio_tower,
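When no pretrained audio_model_id is given (or weight init is disabled), the fallback branch keeps the old behavior of building an empty module skeleton, but now under accelerate.init_empty_weights() instead of transformers' no_init_weights(). Inside that context manager, newly created parameters live on the meta device, so no memory is allocated and no CPU-side init runs; the real weights are expected to be loaded afterwards from the Ultravox checkpoint. A minimal standalone illustration:

import accelerate
import torch.nn as nn

with accelerate.init_empty_weights():
    layer = nn.Linear(4096, 4096)   # the ~64 MB fp32 weight is never materialized

print(layer.weight.device)          # meta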
@@ -418,14 +479,27 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
     def _create_language_model(
         cls, config: UltravoxConfig
     ) -> transformers.LlamaForCausalLM:
-        with transformers.modeling_utils.no_init_weights():
-            # we only ever use from_config if the weights are retrained, hence initializing is not
-            # required. This makes the model quite creation faster since init on CPU is quite slow.
-            language_model = transformers.AutoModelForCausalLM.from_config(
-                config.text_config,
-                attn_implementation=config.text_config._attn_implementation,
-                torch_dtype=config.torch_dtype,
+        if (
+            transformers.modeling_utils._init_weights
+            and config.text_model_id is not None
+        ):
+            language_model = transformers.AutoModelForCausalLM.from_pretrained(
+                config.text_model_id,
+                **{
+                    "attn_implementation": config.text_config._attn_implementation,
+                    "torch_dtype": config.torch_dtype,
+                    **FROM_PRETRAINED_KWARGS,
+                },
             )
+        else:
+            with accelerate.init_empty_weights():
+                # we only ever use from_config if the weights are retrained, hence initializing is not
+                # required. This makes model creation quite a bit faster since init on CPU is quite slow.
+                language_model = transformers.AutoModelForCausalLM.from_config(
+                    config.text_config,
+                    attn_implementation=config.text_config._attn_implementation,
+                    torch_dtype=config.torch_dtype,
+                )
 
         language_model = apply_lora(language_model, config.text_model_lora_config)
         return language_model
@@ -525,6 +599,39 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
     )
 
 
+def get_checkpoint_files(
+    model_id: str,
+) -> tuple[list[str], dict | None, list[str]]:
+    resolved_archive_file = transformers.utils.cached_file(
+        model_id,
+        transformers.utils.SAFE_WEIGHTS_NAME,
+        _raise_exceptions_for_missing_entries=False,
+    )
+
+    if resolved_archive_file is not None:
+        # not sharded
+        sharded_metadata = None
+        state_dict = transformers.modeling_utils.load_state_dict(resolved_archive_file)
+        loaded_state_dict_keys = list(state_dict.keys())
+    else:
+        # sharded
+        resolved_archive_file = transformers.utils.cached_file(
+            model_id, transformers.utils.SAFE_WEIGHTS_INDEX_NAME
+        )
+        resolved_archive_file, sharded_metadata = (
+            transformers.modeling_utils.get_checkpoint_shard_files(
+                model_id,
+                resolved_archive_file,
+            )
+        )
+        loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
+
+    if isinstance(resolved_archive_file, str):
+        resolved_archive_file = [resolved_archive_file]
+
+    return resolved_archive_file, sharded_metadata, loaded_state_dict_keys
+
+
 # TODO: refactor common parts to a shared module
 def is_cache_empty(
     past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]],
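The new module-level helper get_checkpoint_files resolves a repo's safetensors checkpoint, whether it is a single file or a sharded index, into local shard paths, optional shard metadata, and the list of tensor keys. A hedged usage sketch; the repo id below is a placeholder, and the helper as written assumes safetensors weights exist in the repo.

# Hypothetical call; "some-org/some-model" stands in for a real Hugging Face repo id.
shard_paths, sharded_metadata, tensor_keys = get_checkpoint_files("some-org/some-model")

print(f"{len(shard_paths)} shard file(s), {len(tensor_keys)} tensors")
if sharded_metadata is None:
    print("single-file checkpoint")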
 