Spaces:
Running
on
Zero
Running
on
Zero
# Copyright 2024-present the HuggingFace Inc. team. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import copy | |
import torch | |
from torch.nn import CrossEntropyLoss | |
from peft.utils.integrations import gather_params_ctx | |
class CPTEmbedding(torch.nn.Module):
    """
    CPTEmbedding is a custom embedding layer designed for Context-aware Prompt Tuning (CPT) in PEFT. It initializes
    embeddings, applies prompt-specific projections, and computes loss using label masks.

    Token roles are read from ``config.cpt_tokens_type_mask``. For a positive entry ``t``, ``t % 4`` selects the
    role used throughout this class:

    - 1: input template
    - 2: input
    - 3: output template
    - 0: label

    Entries with role 0, or with a non-positive type, are never updated and get a negligible projection radius.
    """

    def __init__(self, config, word_embeddings):
        """
        Initializes the CPTEmbedding module.

        Args:
            config (Namespace):
                Configuration object containing model hyperparameters and CPT-specific settings.
            word_embeddings (torch.nn.Embedding):
                The base word embedding layer used to initialize CPT embeddings. It is only read when
                ``config.inference_mode`` is false.

        Raises:
            ValueError: If, outside inference mode, ``config.num_virtual_tokens`` differs from
                ``len(config.cpt_token_ids)``.
        """
        super().__init__()
        self.config = copy.deepcopy(config)
        num_virtual_tokens = config.num_virtual_tokens

        # Base prompt embeddings, one row per virtual token. forward() looks these up under no_grad.
        self.embedding = torch.nn.Embedding(num_virtual_tokens, config.token_dim)

        # Text-based initialization: copy the base model's embeddings of the CPT token ids. In inference mode the
        # weights are left as initialized (presumably restored from a checkpoint elsewhere -- not shown here).
        if not config.inference_mode:
            if num_virtual_tokens != len(config.cpt_token_ids):
                raise ValueError(
                    f"num_virtual_tokens ({num_virtual_tokens}) must equal the number of cpt_token_ids "
                    f"({len(config.cpt_token_ids)})."
                )
            init_token_ids = torch.as_tensor(
                config.cpt_token_ids, dtype=torch.long, device=word_embeddings.weight.device
            )
            # gather_params_ctx presumably materializes sharded parameters (e.g. DeepSpeed ZeRO-3) for the lookup.
            with gather_params_ctx(word_embeddings.parameters()):
                word_embedding_weights = word_embeddings(init_token_ids).detach().clone()
            self.embedding.weight = torch.nn.Parameter(word_embedding_weights.to(torch.float32))

        # Trainable offsets added to the prompt embeddings. They start at zero so the initial prompt is unchanged.
        self.delta_embedding = torch.nn.Embedding(num_virtual_tokens, config.token_dim)
        self.delta_embedding.weight.data = torch.zeros_like(self.delta_embedding.weight, dtype=torch.float32)

        # Restrict which delta rows receive gradient updates.
        self.set_updated_tokens()

    def forward(self, indices):
        """
        Computes the prompt embeddings and applies delta adjustments.

        As a side effect, the weights of ``delta_embedding`` are overwritten with their projected values
        (see ``get_projection``) before the lookup.

        Args:
            indices (torch.Tensor):
                Indices of the tokens to be embedded.

        Returns:
            torch.Tensor:
                Sum of prompt embeddings and delta embeddings.
        """
        with torch.no_grad():
            # The base prompt embeddings receive no gradient; only the deltas are trained.
            prompt_embeddings = self.embedding(indices)
            # Keep each delta row inside its epsilon ball before it is used.
            self.delta_embedding.weight.data = self.get_projection()
        delta_prompt_embeddings = self.delta_embedding(indices)
        return prompt_embeddings + delta_prompt_embeddings

    def set_updated_tokens(self):
        """
        Sets up a backward hook to selectively update token gradients based on the CPT token type mask.

        Rows of ``delta_embedding.weight`` whose ``type % 4`` is 1, 2 or 3 (input template, input, output template)
        keep their gradient. Every other row gets a zero gradient.
        """
        token_roles = torch.remainder(torch.as_tensor(self.config.cpt_tokens_type_mask, dtype=torch.long), 4)
        # torch.remainder with a positive divisor lies in [0, 4), so "!= 0" selects exactly roles 1, 2 and 3.
        # The column shape broadcasts the per-row flag over the embedding dimension.
        mask = (token_roles != 0).view(-1, 1)

        def backward_hook(grad):
            # Zero the gradient of rows that must stay fixed.
            return grad * mask.to(grad.device)

        self.delta_embedding.weight.register_hook(backward_hook)

    def get_epsilon(self):
        """
        Returns the per-token projection radius for the delta embeddings.

        Radii are assigned by token role, for positive types only:

        - format tokens (``type % 4`` in {1, 3}) get ``opt_projection_format_epsilon``;
        - input tokens (``type % 4 == 2``) get ``opt_projection_epsilon``;
        - all other tokens get 1e-10.

        The two configured radii are scaled by ``sqrt(token_dim / 2048)``.

        Returns:
            torch.Tensor: float32 tensor with the same shape as ``config.cpt_tokens_type_mask``.
        """
        min_value = 1e-10
        type_mask = torch.as_tensor(self.config.cpt_tokens_type_mask, dtype=torch.long)
        token_roles = torch.remainder(type_mask, 4)
        positive = type_mask > 0

        # Scale the configured radii with the embedding width, relative to a 2048-dim reference.
        dim_scale = torch.sqrt(torch.Tensor([self.config.token_dim / 2048]))
        normalized_format_eps = self.config.opt_projection_format_epsilon * dim_scale
        normalized_input_eps = self.config.opt_projection_epsilon * dim_scale

        epsilon = torch.full(type_mask.shape, min_value, dtype=torch.float32)
        epsilon[positive & ((token_roles == 1) | (token_roles == 3))] = normalized_format_eps
        epsilon[positive & (token_roles == 2)] = normalized_input_eps
        return epsilon

    def get_projection(self):
        """
        Applies epsilon-based projection to the delta embeddings to control their norm.

        Each non-zero row ``d`` is rescaled by ``eps / max(||d||_2, eps)``, using the per-token ``eps`` from
        ``get_epsilon``. This projects the row onto the L2 ball of radius ``eps``. Zero rows are left as they are.
        The module's own weights are not modified.

        Returns:
            torch.Tensor: A projected copy of ``delta_embedding.weight``.
        """
        with torch.no_grad():
            weights = self.delta_embedding.weight.clone()
            token_norm = torch.norm(weights, p=2, dim=1)
            projection_mask = token_norm > 0
            if torch.any(projection_mask):
                eps = self.get_epsilon().to(weights.device)[projection_mask]
                # Rows already inside the ball are scaled by eps / eps == 1, i.e. left unchanged.
                weights[projection_mask] *= (eps / token_norm[projection_mask].clamp(min=eps)).view(-1, 1)
        return weights

    @staticmethod
    def calculate_loss(base_model_output, labels, cpt_type_mask, config):
        """
        Computes the loss for CPT models with optional exponential decay.

        Logits are shifted against labels for next-token prediction, and a per-token cross-entropy is computed.

        When ``config.opt_weighted_loss_type == "decay"``, tokens are reweighted per sample using the label type ids
        (``type > 0`` and ``type % 4 == 0``):

        - the largest label id gets weight 1;
        - the next largest gets ``opt_loss_decay_factor``;
        - and so on, multiplying by the factor each step.

        Non-label tokens keep weight 1. The result is the mean of ``loss * weight`` over the positions whose label
        is not -100. If there are no such positions, this mean is NaN.

        Args:
            base_model_output (ModelOutput):
                Output from the base model containing logits.
            labels (torch.Tensor):
                Ground-truth labels for the input tokens (-100 marks ignored positions).
            cpt_type_mask (torch.Tensor):
                Token type mask used for filtering valid loss terms.
            config (Namespace):
                Configuration object containing loss-related hyperparameters.

        Returns:
            ModelOutput:
                ``base_model_output`` with its ``loss`` attribute set (mutated in place and returned).
        """
        lm_logits = base_model_output.logits
        device = lm_logits.device
        labels = labels.to(device)

        # Shift so that the logits at position t are scored against the label at position t + 1.
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        shift_cpt_type_mask = cpt_type_mask[..., 1:].contiguous().to(device)
        shift_labels_bool = shift_labels != -100
        batch_size, seq_length, vocab_size = shift_logits.shape

        # Per-token cross-entropy; ignored positions produce 0 and are excluded from the mean below.
        loss_fct = CrossEntropyLoss(reduction="none", ignore_index=-100)
        loss = loss_fct(
            shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
        ).view(batch_size, seq_length)

        shift_labels_weights = shift_labels_bool.float()
        if config.opt_weighted_loss_type == "decay":
            for i in range(batch_size):
                sample_types = shift_cpt_type_mask[i]
                is_label = (sample_types > 0) & (sample_types % 4 == 0)
                exponential_decay = torch.ones_like(sample_types, dtype=torch.float32)
                decay_value = 1
                # unique() sorts ascending; flip so the latest (largest) label id gets weight 1.
                for label_id in torch.flip(sample_types[is_label].unique(), [0]):
                    exponential_decay[sample_types == label_id] = decay_value
                    decay_value *= config.opt_loss_decay_factor
                shift_labels_weights[i] *= exponential_decay

        base_model_output.loss = (loss[shift_labels_bool] * shift_labels_weights[shift_labels_bool]).mean()
        return base_model_output