from collections.abc import Iterable
from typing import Optional

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .utils import maybe_prefix

logger = init_logger(__name__)
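

# Identity stand-ins for the norms that the EAGLE draft model disables (see
# the EAGLE class docstring below). DummyInputLayerNorm keeps the original
# norm's parameters, when provided, so checkpoints containing them still
# load, but returns its input unchanged. DummyOutputNorm skips the final
# normalization and simply folds the residual back into the hidden states.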
class DummyInputLayerNorm(nn.Module):

    def __init__(self, weight=None, bias=None):
        super().__init__()
        self.weight = nn.Parameter(weight) if weight is not None else None
        self.bias = nn.Parameter(bias) if bias is not None else None

    def forward(self, x):
        return x


class DummyOutputNorm(nn.Module):

    def forward(self, x, residual):
        if residual is None:
            return x
        else:
            return x + residual, None


class EAGLE(nn.Module):
    """This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
    Reference implementation: https://github.com/SafeAILab/EAGLE

    Differences from the reference implementation:
    1. In the reference, the LlamaDecoderLayer implementation has no
       input_layernorm for the first decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427).
       Following that approach, our implementation also disables the
       input_layernorm for the first decoder layer.
    2. We allow any decoder layer to be used in EAGLE, whereas the reference
       fixes the decoder layer to LlamaDecoderLayer.
    3. We support an optional token_map that reduces the draft vocabulary to
       the most frequently used tokens, giving some additional speed-up by
       reducing sampling overhead. This is disabled unless the checkpoint
       contains an explicit token_map tensor and the config has the optional
       attribute truncated_vocab_size < vocab_size. To use this technique,
       find the top-k most frequent tokens in the target dataset, add them as
       a tensor to the draft checkpoint (under the key token_map), and set
       truncated_vocab_size (= k) in the draft config.
    4. We allow an enhanced EAGLE architecture, similar to the DeepSeek MTP
       module, with regard to the use of additional RMS norms. The original
       EAGLE architecture 1) skips the pre-attention norm in its first
       transformer block and 2) skips the final output norm, both of which we
       found to be suboptimal. We also add support for separate norms applied
       to the token embedding and to the hidden states before projection, as
       in DeepSeek MTP, which we found to improve performance as well.
    """
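
    # Illustrative sketch (not part of this module): one way to build the
    # optional token_map tensor described in point 3 of the docstring above,
    # assuming `token_ids` is a 1-D tensor of token ids drawn from the target
    # dataset and `k` is the chosen truncated_vocab_size:
    #
    #     counts = torch.bincount(token_ids, minlength=vocab_size)
    #     token_map = torch.topk(counts, k).indices
    #
    # `token_map` is then saved in the draft checkpoint under the key
    # "token_map" and truncated_vocab_size is set to k in the draft config.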

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config

        architectures = getattr(self.config.model, "architectures", [])
        model_cls, _ = ModelRegistry.resolve_model_cls(architectures)

        self.model = model_cls(vllm_config=vllm_config,
                               prefix=maybe_prefix(prefix, "model"))

        self.fc = nn.Linear(config.model.hidden_size * 2,
                            config.model.hidden_size,
                            bias=getattr(self.config, "eagle_fc_bias", False))
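
        # By default (and whenever the draft config sets skip_prenorm /
        # skip_output_norm to True), replace the first decoder layer's
        # input_layernorm and the final output norm with the identity modules
        # defined above, matching the EAGLE architecture described in the
        # class docstring. The norms are kept only if the draft config
        # explicitly sets these flags to False.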
        if not hasattr(self.config.model,
                       "skip_prenorm") or self.config.model.skip_prenorm:
            self.model.model.layers[0].input_layernorm = DummyInputLayerNorm(
                weight=self.model.model.layers[0].input_layernorm.weight)

        if not hasattr(
                self.config.model,
                "skip_output_norm") or self.config.model.skip_output_norm:
            self.model.model.norm = DummyOutputNorm()

        self.add_para_norm = False
        if hasattr(self.config.model,
                   "add_para_norm") and self.config.model.add_para_norm:
            self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            self.add_para_norm = True

        self.orig_vocab_size = config.vocab_size
        self.truncated_vocab_size = config.truncated_vocab_size
        self.unpadded_vocab_size = self.truncated_vocab_size

        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=self.truncated_vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
        )

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                self.truncated_vocab_size,
                                                logit_scale)
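
        # Optional idx -> token mapping that shrinks the draft vocabulary to
        # its most frequent tokens (see point 3 of the class docstring). It
        # stays None unless the checkpoint provides a "token_map" tensor and
        # the config sets truncated_vocab_size < vocab_size.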
        self.token_map = None

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings(input_ids)
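
        # If previous_hidden_states is empty or its batch size does not match
        # the input embeddings, substitute zeros so the concatenation below
        # stays shape-consistent.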
        batch_size = inputs_embeds.size(0)
        if previous_hidden_states.size(0) == 0 or \
                previous_hidden_states.size(0) != batch_size:
            hidden_dim = self.config.model.hidden_size
            device = inputs_embeds.device

            previous_hidden_states = \
                torch.zeros(batch_size, hidden_dim, device=device)

        if self.add_para_norm:
            inputs_embeds = torch.cat([
                self.enorm(inputs_embeds),
                self.hnorm(previous_hidden_states)
            ],
                                      dim=-1)
        else:
            inputs_embeds = torch.cat([inputs_embeds, previous_hidden_states],
                                      dim=-1)

        inputs_embeds = self.fc(inputs_embeds)
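
        # Zero the fused embedding at position 0, where no previous hidden
        # state exists.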
        inputs_embeds[positions == 0] = 0

        hidden_states = self.model.model(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
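
        # With an active token_map, scatter the truncated-vocab logits back
        # into a full-vocab tensor filled with -inf so that tokens outside
        # the draft vocabulary can never be selected.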
        if self.token_map is not None:
            _logits = logits
            logits = -torch.inf * torch.ones(
                size=(*_logits.shape[:-1], self.orig_vocab_size),
                device=_logits.device,
                dtype=_logits.dtype)

            logits[..., self.token_map] = _logits

        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
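        # Route each checkpoint tensor: handle token_map, the fc projection
        # and the optional enorm/hnorm weights directly, and collect the
        # remaining weights (remapped onto the "lm_head." / "model." prefixes
        # expected by the wrapped model) for a bulk load at the end.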
        model_weights = {}
        for name, loaded_weight in weights:
            if name == "token_map":
                if self.config.truncated_vocab_size < self.config.vocab_size:
                    self.token_map = nn.Parameter(loaded_weight,
                                                  requires_grad=False)
            elif name.startswith("fc.weight"):
                weight_loader = getattr(self.fc.weight, "weight_loader",
                                        default_weight_loader)
                weight_loader(self.fc.weight, loaded_weight)
            elif name.startswith("fc.bias"):
                if self.fc.bias is not None:
                    weight_loader = getattr(self.fc.bias, "weight_loader",
                                            default_weight_loader)
                    weight_loader(self.fc.bias, loaded_weight)
                else:
                    logger.warning_once("Found bias in the loaded weights but "
                                        "the model config doesn't have bias.")
            elif name.startswith("enorm.weight"):
                weight_loader = getattr(self.enorm.weight, "weight_loader",
                                        default_weight_loader)
                weight_loader(self.enorm.weight, loaded_weight)
            elif name.startswith("hnorm.weight"):
                weight_loader = getattr(self.hnorm.weight, "weight_loader",
                                        default_weight_loader)
                weight_loader(self.hnorm.weight, loaded_weight)
            elif name.startswith("model.lm_head.") or name.startswith(
                    "model.model."):
                model_weights[name.split("model.", 1)[-1]] = loaded_weight
            elif name.startswith("lm_head.") or name.startswith("model."):
                model_weights[name] = loaded_weight
            else:
                model_weights[f"model.{name}"] = loaded_weight
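
        # lm_head needs special handling: trim its weight with token_map when
        # a truncated draft vocabulary is in use, or fall back to a zero
        # placeholder when the checkpoint ships no lm_head weight at all.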
if "lm_head.weight" in model_weights: |
|
lm_head_weight = model_weights.pop("lm_head.weight") |
|
|
|
if self.token_map is not None and\ |
|
lm_head_weight.shape[0] > self.token_map.shape[0]: |
|
|
|
lm_head_weight = lm_head_weight[self.token_map] |
|
|
|
else: |
|
|
|
lm_head_weight = torch.zeros( |
|
self.lm_head.org_vocab_size, |
|
self.lm_head.embedding_dim, |
|
dtype=self.config.torch_dtype, |
|
) |
|
|
|
weight_loader = getattr(self.lm_head.weight, "weight_loader", |
|
default_weight_loader) |
|
weight_loader(self.lm_head.weight, lm_head_weight) |
|
|
|
self.model.load_weights(model_weights.items()) |
|
|