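# NOTE: the following lines appear to be residue from the web page this file was copied from;
# they are kept as comments so the module remains importable.
# RiverZ's picture
# upd
# 3b609b9
# raw
# history blame
# 8.41 kB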
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import torch
from torch.nn import CrossEntropyLoss
from peft.utils.integrations import gather_params_ctx
class CPTEmbedding(torch.nn.Module):
    """
    CPTEmbedding is a custom embedding layer designed for Context-aware Prompt Tuning (CPT) in PEFT. It holds a frozen
    prompt embedding plus a trainable, norm-limited delta embedding, and provides the CPT loss computation.

    The per-token behavior is driven by `config.cpt_tokens_type_mask`, a sequence of integer type ids with one entry
    per virtual token. Based on the remainder of the id modulo 4:
        - remainder 1 ("input template") and 3 ("output template"): delta is trainable, norm limited by the format
          epsilon.
        - remainder 2 ("input"): delta is trainable, norm limited by the input epsilon.
        - remainder 0: the delta gradient is zeroed and the norm limit is `1e-10`.
    """

    def __init__(self, config, word_embeddings):
        """
        Initializes the CPTEmbedding module.

        Args:
            config (Namespace):
                Configuration object containing model hyperparameters and CPT-specific settings. A deep copy is stored
                on the module.
            word_embeddings (torch.nn.Embedding):
                The base word embedding layer used to initialize CPT embeddings. Only accessed when
                `config.inference_mode` is falsy.

        Raises:
            ValueError:
                If `config.inference_mode` is falsy and `config.num_virtual_tokens` differs from
                `len(config.cpt_token_ids)`.
        """
        super().__init__()
        self.config = copy.deepcopy(config)
        num_virtual_tokens = config.num_virtual_tokens

        # Initialize embeddings with virtual token dimensions
        self.embedding = torch.nn.Embedding(num_virtual_tokens, config.token_dim)

        # Initialize embeddings from the word embeddings of the prompt tokens, unless in inference mode
        if not config.inference_mode:
            # Explicit check instead of `assert`, which is removed when Python runs with -O.
            if num_virtual_tokens != len(config.cpt_token_ids):
                raise ValueError(
                    f"num_virtual_tokens ({num_virtual_tokens}) must be equal to the number of cpt_token_ids "
                    f"({len(config.cpt_token_ids)})."
                )

            init_token_ids = torch.as_tensor(
                config.cpt_token_ids, dtype=torch.long, device=word_embeddings.weight.device
            )
            with gather_params_ctx(word_embeddings.parameters()):
                word_embedding_weights = word_embeddings(init_token_ids).detach().clone()
            word_embedding_weights = word_embedding_weights.to(torch.float32)
            self.embedding.weight = torch.nn.Parameter(word_embedding_weights)

        # Initialize delta embedding with zero weights (float32)
        self.delta_embedding = torch.nn.Embedding(num_virtual_tokens, config.token_dim)
        self.delta_embedding.weight.data = torch.zeros_like(self.delta_embedding.weight, dtype=torch.float32)

        # Apply hook for backward gradient updates
        self.set_updated_tokens()

    def forward(self, indices):
        """
        Computes the prompt embeddings and applies delta adjustments.

        Args:
            indices (torch.Tensor):
                Indices of the tokens to be embedded.

        Returns:
            torch.Tensor:
                Sum of prompt embeddings and delta embeddings.
        """
        # The base prompt embedding is looked up without tracking gradients; only the delta is trained.
        with torch.no_grad():
            prompt_embeddings = self.embedding(indices)

        # Project the delta weights in place before they are used, so their norms respect the epsilon limits.
        self.delta_embedding.weight.data = self.get_projection()

        delta_prompt_embeddings = self.delta_embedding(indices)

        return prompt_embeddings + delta_prompt_embeddings

    def set_updated_tokens(self):
        """
        Sets up a backward hook to selectively update token gradients based on the CPT token type mask.

        Rows of the delta embedding whose type id is a multiple of 4 (including 0) get their gradient multiplied by
        zero; all other rows (remainder 1, 2 or 3) keep their gradient unchanged.
        """
        type_mask = torch.as_tensor(self.config.cpt_tokens_type_mask, dtype=torch.long)
        # torch.remainder with a positive divisor yields values in {0, 1, 2, 3}, so "is 1, 2 or 3" is "!= 0".
        trainable_rows = (torch.remainder(type_mask, 4) != 0).view(-1, 1)

        def backward_hook(grad):
            # Broadcast the (num_tokens, 1) mask over the embedding dimension.
            return grad * trainable_rows.to(grad.device)

        self.delta_embedding.weight.register_hook(backward_hook)

    def get_epsilon(self):
        """
        Builds the per-token norm limit (epsilon) used by `get_projection`.

        Both configured epsilons are scaled by `sqrt(config.token_dim / 2048)`.

        Returns:
            torch.Tensor:
                A float32 tensor with one entry per element of `config.cpt_tokens_type_mask`. Positive type ids with
                remainder 1 or 3 (modulo 4) get the scaled `opt_projection_format_epsilon`, positive type ids with
                remainder 2 get the scaled `opt_projection_epsilon`, and every other entry gets `1e-10`.
        """
        MIN_VALUE = 1e-10

        type_mask = torch.as_tensor(self.config.cpt_tokens_type_mask, dtype=torch.long)
        remainder = torch.remainder(type_mask, 4)
        is_positive = type_mask > 0

        # Shared normalization factor for the format and input epsilons
        scale = torch.sqrt(torch.tensor([self.config.token_dim / 2048], dtype=torch.float32))
        normalized_format_eps = self.config.opt_projection_format_epsilon * scale
        normalized_input_eps = self.config.opt_projection_epsilon * scale

        epsilon = torch.full_like(type_mask, MIN_VALUE, dtype=torch.float32)
        epsilon[is_positive & (remainder == 1)] = normalized_format_eps
        epsilon[is_positive & (remainder == 3)] = normalized_format_eps
        epsilon[is_positive & (remainder == 2)] = normalized_input_eps

        return epsilon

    def get_projection(self):
        """
        Applies epsilon-based projection to the delta embeddings to control their norm.

        Returns:
            torch.Tensor:
                A copy of the delta embedding weights in which every row whose L2 norm exceeds its epsilon has been
                rescaled to have norm equal to that epsilon. Rows with a smaller norm, and all-zero rows, are returned
                unchanged.
        """
        with torch.no_grad():
            new_embeddings_weights = self.delta_embedding.weight.clone()
            token_norm = torch.norm(new_embeddings_weights, p=2, dim=1)

            # Zero rows are skipped so they are never divided by their (zero) norm.
            projection_mask = token_norm > 0
            if torch.any(projection_mask):
                epsilon = self.get_epsilon().to(new_embeddings_weights.device)
                selected_eps = epsilon[projection_mask]
                # eps / max(norm, eps): equals 1 when the norm is within the limit, eps / norm otherwise.
                scale = selected_eps / token_norm[projection_mask].clamp(min=selected_eps)
                new_embeddings_weights[projection_mask] *= scale.view(-1, 1)
            return new_embeddings_weights

    @staticmethod
    def calculate_loss(base_model_output, labels, cpt_type_mask, config):
        """
        Computes the loss for CPT models with optional exponential decay.

        Args:
            base_model_output (ModelOutput):
                Output from the base model containing logits.
            labels (torch.Tensor):
                Ground-truth labels for the input tokens. Positions equal to -100 are ignored.
            cpt_type_mask (torch.Tensor):
                Token type mask used for filtering valid loss terms. Positive ids that are multiples of 4 identify
                label segments for the decay weighting.
            config (Namespace):
                Configuration object containing loss-related hyperparameters (`opt_weighted_loss_type` and
                `opt_loss_decay_factor`).

        Returns:
            ModelOutput:
                The same `base_model_output` object, with its `loss` attribute set to the mean of the (optionally
                weighted) per-token cross-entropy over all non-ignored positions. If no position is valid, the mean is
                taken over an empty selection and the loss is NaN.
        """
        lm_logits = base_model_output.logits
        device = lm_logits.device
        labels = labels.to(device)

        # Shift logits and labels for next-token prediction
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Keep the type mask on the same device as the tensors it indexes.
        shift_cpt_type_mask = cpt_type_mask[..., 1:].contiguous().to(device)

        shift_labels_bool = shift_labels != -100
        batch_size, seq_length, vocab_size = shift_logits.shape

        # Compute per-token cross-entropy loss
        loss_fct = CrossEntropyLoss(reduction="none", ignore_index=-100)
        loss = loss_fct(
            shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
        )
        loss = loss.view(batch_size, seq_length)

        # Per-token weights: 1 for valid labels, 0 for ignored positions
        shift_labels_weights = shift_labels_bool.float()

        # The decay table is only needed for the "decay" loss type, so it is not built otherwise.
        if config.opt_weighted_loss_type == "decay":
            for i in range(batch_size):
                row_types = shift_cpt_type_mask[i]
                idx_labels = (row_types > 0) & (row_types % 4 == 0)
                labels_ids = row_types[idx_labels].unique()

                exponential_decay = torch.ones_like(row_types, dtype=torch.float32)
                decay_value = 1
                # `unique()` is ascending; iterate from the largest id down so it keeps weight 1 and each
                # smaller id is scaled by one more power of the decay factor.
                for label_mask_idx in torch.flip(labels_ids, [0]):
                    exponential_decay[row_types == label_mask_idx] = decay_value
                    decay_value *= config.opt_loss_decay_factor
                shift_labels_weights[i] *= exponential_decay

        # Compute the weighted mean loss over valid positions
        loss = (loss[shift_labels_bool] * shift_labels_weights[shift_labels_bool]).mean()

        base_model_output.loss = loss

        return base_model_output