import importlib
import re
import warnings
from dataclasses import dataclass, field
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.pytorch_utils import Conv1D

from ..utils import (
    TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
    PeftType,
    _freeze_adapter,
    _get_submodules,
    transpose,
)
from .lora import (
    LoraConfig,
    LoraLayer,
    LoraModel,
    mark_only_lora_as_trainable,
)


def is_bnb_available():
    return importlib.util.find_spec("bitsandbytes") is not None


if is_bnb_available():
    import bitsandbytes as bnb


@dataclass
class AdaLoraConfig(LoraConfig):
    """
    This is the configuration class to store the configuration of an [`~peft.AdaLoraModel`].

    Args:
        target_r (`int`): The target average rank of the incremental matrices.
        init_r (`int`): The initial rank of each incremental matrix.
        tinit (`int`): The number of steps of initial fine-tuning warmup.
        tfinal (`int`): The number of steps of final fine-tuning.
        deltaT (`int`): The time interval between two budget allocations.
        beta1 (`float`): The hyperparameter of EMA for sensitivity smoothing.
        beta2 (`float`): The hyperparameter of EMA for uncertainty quantification.
        orth_reg_weight (`float`): The coefficient of orthogonal regularization.
        total_step (`int`): The total number of training steps, which must be specified before training.
        rank_pattern (`list`): The rank allocated to each weight matrix by the RankAllocator.
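
    Example (a minimal sketch; the hyperparameter values below are illustrative, not recommendations)::

        >>> from peft import AdaLoraConfig
        >>> config = AdaLoraConfig(
        ...     target_r=8, init_r=12, tinit=200, tfinal=1000, deltaT=10, total_step=10000, target_modules=["q", "v"]
        ... )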
    """

    target_r: int = field(default=8, metadata={"help": "Target Lora matrix dimension."})
    init_r: int = field(default=12, metadata={"help": "Initial Lora matrix dimension."})
    tinit: int = field(default=0, metadata={"help": "The steps of initial warmup."})
    tfinal: int = field(default=0, metadata={"help": "The steps of final warmup."})
    deltaT: int = field(default=1, metadata={"help": "Step interval of rank allocation."})
    beta1: float = field(default=0.85, metadata={"help": "Hyperparameter of EMA."})
    beta2: float = field(default=0.85, metadata={"help": "Hyperparameter of EMA."})
    orth_reg_weight: float = field(default=0.5, metadata={"help": "The orthogonal regularization coefficient."})
    total_step: Optional[int] = field(default=None, metadata={"help": "The total training steps."})
    rank_pattern: Optional[dict] = field(default=None, metadata={"help": "The saved rank pattern."})

    def __post_init__(self):
        self.peft_type = PeftType.ADALORA


class AdaLoraModel(LoraModel):
    """
    Creates an AdaLoRA (Adaptive LoRA) model from a pretrained transformers model. Paper:
    https://openreview.net/pdf?id=lq62uWRJjiY

    Args:
        model ([`transformers.PreTrainedModel`]): The model to be adapted.
        config ([`AdaLoraConfig`]): The configuration of the AdaLora model.

    Returns:
        `torch.nn.Module`: The AdaLora model.

    Example::

        >>> from transformers import AutoModelForSeq2SeqLM
        >>> from peft import AdaLoraModel, AdaLoraConfig
        >>> config = AdaLoraConfig(
        ...     peft_type="ADALORA", task_type="SEQ_2_SEQ_LM", r=8, lora_alpha=32, target_modules=["q", "v"],
        ...     lora_dropout=0.01,
        ... )
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
        >>> model = AdaLoraModel(model, {"default": config}, "default")

    **Attributes**:
        - **model** ([`transformers.PreTrainedModel`]) -- The model to be adapted.
        - **peft_config** ([`AdaLoraConfig`]) -- The configuration of the AdaLora model.
    """

    def __init__(self, model, config, adapter_name):
        nn.Module.__init__(self)
        self.model = model
        self.peft_config = config
        self.add_adapter(adapter_name, self.peft_config[adapter_name])

    def add_adapter(self, adapter_name, config=None):
        if config is not None:
            model_config = self.model.config.to_dict() if hasattr(self.model.config, "to_dict") else self.model.config
            config = self._prepare_adalora_config(config, model_config)
            self.peft_config[adapter_name] = config
        self._find_and_replace(adapter_name)
        if len(self.peft_config) > 1 and self.peft_config[adapter_name].bias != "none":
            raise ValueError(
                "AdaLoraModel supports only 1 adapter with bias. When using multiple adapters, "
                "set bias to 'none' for all adapters."
            )
        trainable_mode_counter = 0
        for config in self.peft_config.values():
            if not config.inference_mode:
                trainable_mode_counter += 1

        if trainable_mode_counter > 1:
            raise ValueError(
                "AdaLoraModel supports only 1 trainable adapter. "
                "When using multiple adapters, set inference_mode to True for all adapters except the one you want to train."
            )

        mark_only_lora_as_trainable(self.model, self.peft_config[adapter_name].bias)
        if self.peft_config[adapter_name].inference_mode:
            _freeze_adapter(self.model, adapter_name)
        else:
            self.trainable_adapter_name = adapter_name
            self.rankallocator = RankAllocator(self.model, self.peft_config[adapter_name], self.trainable_adapter_name)

    def _find_and_replace(self, adapter_name):
        lora_config = self.peft_config[adapter_name]
        loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
        if loaded_in_8bit and not is_bnb_available():
            raise ImportError(
                "To use AdaLora with 8-bit quantization, please install the `bitsandbytes` package. "
                "You can install it with `pip install bitsandbytes`."
            )
        is_target_modules_in_base_model = False
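        # Every targeted module starts at rank `init_r`; the RankAllocator later prunes ranks toward `target_r`.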
        kwargs = {
            "r": lora_config.init_r,
            "lora_alpha": lora_config.lora_alpha,
            "lora_dropout": lora_config.lora_dropout,
            "fan_in_fan_out": lora_config.fan_in_fan_out,
            "init_lora_weights": lora_config.init_lora_weights,
        }
        key_list = [key for key, _ in self.model.named_modules()]
        for key in key_list:
            if isinstance(lora_config.target_modules, str):
                target_module_found = re.fullmatch(lora_config.target_modules, key)
            else:
                target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules)
            if target_module_found:
                if not is_target_modules_in_base_model:
                    is_target_modules_in_base_model = True
                parent, target, target_name = _get_submodules(self.model, key)
                bias = target.bias is not None
                if isinstance(target, LoraLayer):
                    target.update_layer(
                        adapter_name,
                        lora_config.init_r,
                        lora_config.lora_alpha,
                        lora_config.lora_dropout,
                        lora_config.init_lora_weights,
                    )
                else:
                    if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
                        kwargs.update(
                            {
                                "has_fp16_weights": target.state.has_fp16_weights,
                                "memory_efficient_backward": target.state.memory_efficient_backward,
                                "threshold": target.state.threshold,
                                "index": target.index,
                            }
                        )
                        new_module = SVDLinear8bitLt(
                            adapter_name, target.in_features, target.out_features, bias=bias, **kwargs
                        )
                    else:
                        if isinstance(target, torch.nn.Linear):
                            in_features, out_features = target.in_features, target.out_features
                            if kwargs["fan_in_fan_out"]:
                                warnings.warn(
                                    "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
                                    "Setting fan_in_fan_out to False."
                                )
                                kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
                        elif isinstance(target, Conv1D):
                            in_features, out_features = (
                                target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
                            )
                            if not kwargs["fan_in_fan_out"]:
                                warnings.warn(
                                    "fan_in_fan_out is set to False but the target module is `Conv1D`. "
                                    "Setting fan_in_fan_out to True."
                                )
                                kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
                        else:
                            raise ValueError(
                                f"Target module {target} is not supported. "
                                "Currently, only `torch.nn.Linear` and `Conv1D` are supported."
                            )
                        new_module = SVDLinear(adapter_name, in_features, out_features, bias=bias, **kwargs)

                self._replace_module(parent, target_name, new_module, target)
        if not is_target_modules_in_base_model:
            raise ValueError(
                f"Target modules {lora_config.target_modules} not found in the base model. "
                "Please check the target modules and try again."
            )

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.model, name)

    def forward(self, *args, **kwargs):
        outputs = self.model.forward(*args, **kwargs)

        orth_reg_weight = self.peft_config[self.trainable_adapter_name].orth_reg_weight
        assert orth_reg_weight > 0
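
        # Orthogonal regularization from the AdaLoRA paper: penalize ||A A^T - I||_F for lora_A and
        # ||B^T B - I||_F for lora_B of the trainable adapter, averaged over all regularized matrices
        # and added to the task loss with weight `orth_reg_weight`.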
        if hasattr(outputs, "loss"):
            regu_loss = 0
            num_param = 0
            for n, p in self.model.named_parameters():
                if ("lora_A" in n or "lora_B" in n) and self.trainable_adapter_name in n:
                    para_cov = p @ p.T if "lora_A" in n else p.T @ p
                    I = torch.eye(*para_cov.size(), out=torch.empty_like(para_cov))
                    I.requires_grad = False
                    num_param += 1
                    regu_loss += torch.norm(para_cov - I, p="fro")
            regu_loss = regu_loss / num_param
            outputs.loss += orth_reg_weight * regu_loss
        return outputs

    def resize_modules_by_rank_pattern(self, rank_pattern, adapter_name):
        lora_config = self.peft_config[adapter_name]
        for name, rank_idx in rank_pattern.items():
            if isinstance(rank_idx, list):
                rank = sum(rank_idx)
            elif isinstance(rank_idx, torch.Tensor):
                rank_idx = rank_idx.view(-1)
                rank = rank_idx.sum().item()
            else:
                raise ValueError("Unexpected type of rank_idx")
            key = ".".join(name.split(".")[0:-2]) if adapter_name in name else ".".join(name.split(".")[0:-1])
            _, target, _ = _get_submodules(self.model, key)
            lora_E_weights = target.lora_E[adapter_name][rank_idx]
            lora_A_weights = target.lora_A[adapter_name][rank_idx]
            lora_B_weights = target.lora_B[adapter_name][:, rank_idx]
            ranknum = target.ranknum[adapter_name]
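            # Re-create the layer's adapter parameters at the pruned rank, then copy back the retained
            # rows of lora_E / lora_A and columns of lora_B selected by `rank_idx`.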
            target.update_layer(
                adapter_name,
                rank,
                lora_config.lora_alpha,
                lora_config.lora_dropout,
                lora_config.init_lora_weights,
            )
            with torch.no_grad():
                if rank > 0:
                    target.lora_E[adapter_name].copy_(lora_E_weights)
                    target.lora_A[adapter_name].copy_(lora_A_weights)
                    target.lora_B[adapter_name].copy_(lora_B_weights)

                target.ranknum[adapter_name].copy_(ranknum)

    def resize_state_dict_by_rank_pattern(self, rank_pattern, state_dict, adapter_name):
        for name, rank_idx in rank_pattern.items():
            rank = sum(rank_idx)
            prefix = ".".join(name.split(".")[0:-2]) if adapter_name in name else ".".join(name.split(".")[0:-1])
            for layer in ["lora_E", "lora_A", "lora_B"]:
                key = f"base_model.model.{prefix}.{layer}.{adapter_name}"
                if layer != "lora_B":
                    state_dict[key] = (
                        state_dict[key][rank_idx] if rank != state_dict[key].shape[0] else state_dict[key]
                    )
                else:
                    state_dict[key] = (
                        state_dict[key][:, rank_idx] if rank != state_dict[key].shape[1] else state_dict[key]
                    )
        return state_dict

    def update_and_allocate(self, global_step):
        lora_config = self.peft_config[self.trainable_adapter_name]
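
        # Expected to be called once per optimization step, after the backward pass, so that parameter
        # gradients are available to `RankAllocator.update_ipt`. Three phases:
        #   1. step < total_step - tfinal: update importance scores and, on masking steps, store the new rank pattern;
        #   2. step == total_step - tfinal: force a final allocation and freeze the resulting rank pattern;
        #   3. step > total_step - tfinal: keep masking the pruned singular values with the stored pattern.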
        if global_step < lora_config.total_step - lora_config.tfinal:
            _, rank_pattern = self.rankallocator.update_and_allocate(self.model, global_step)
            if rank_pattern:
                lora_config.rank_pattern = rank_pattern
        elif global_step == lora_config.total_step - lora_config.tfinal:
            _, rank_pattern = self.rankallocator.update_and_allocate(self.model, global_step, force_mask=True)
            lora_config.rank_pattern = rank_pattern
            self.rankallocator.reset_ipt()
        elif global_step > lora_config.total_step - lora_config.tfinal:
            self.rankallocator.mask_using_rank_pattern(self.model, lora_config.rank_pattern)
        else:
            return None

    @staticmethod
    def _prepare_adalora_config(peft_config, model_config):
        if peft_config.target_modules is None:
            if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING:
                raise ValueError("Please specify `target_modules` in `peft_config`")
            peft_config.target_modules = TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING[
                model_config["model_type"]
            ]
        if peft_config.inference_mode:
            peft_config.merge_weights = True
        return peft_config


class AdaLoraLayer(LoraLayer):
    def __init__(
        self,
        in_features: int,
        out_features: int,
    ):
        super().__init__(in_features, out_features)
        self.lora_E = nn.ParameterDict({})
        self.lora_A = nn.ParameterDict({})
        self.lora_B = nn.ParameterDict({})
        self.ranknum = nn.ParameterDict({})

    def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
        self.r[adapter_name] = r
        self.lora_alpha[adapter_name] = lora_alpha
        if lora_dropout > 0.0:
            lora_dropout_layer = nn.Dropout(p=lora_dropout)
        else:
            # nn.Identity is used as the no-op dropout because nn.ModuleDict only accepts nn.Module values.
            lora_dropout_layer = nn.Identity()

        self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
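
        # SVD-style parameterization of the update: lora_E holds the diagonal (singular-value) entries as an
        # (r, 1) column, lora_A the right-singular directions (r, in_features), lora_B the left-singular
        # directions (out_features, r), and ranknum the current rank used only for scaling.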
        self.lora_A.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.zeros(r, self.in_features))}))
        self.lora_E.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.zeros(r, 1))}))
        self.lora_B.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.zeros(self.out_features, r))}))
        self.ranknum.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.zeros(1), requires_grad=False)}))
        self.ranknum[adapter_name].data.fill_(float(r))
        self.ranknum[adapter_name].requires_grad = False
        self.scaling[adapter_name] = lora_alpha if lora_alpha > 0 else float(r)
        if init_lora_weights:
            self.reset_lora_parameters(adapter_name)
        self.to(self.weight.device)

    def reset_lora_parameters(self, adapter_name):
        if adapter_name in self.lora_A.keys():
            nn.init.zeros_(self.lora_E[adapter_name])
            nn.init.normal_(self.lora_A[adapter_name], mean=0.0, std=0.02)
            nn.init.normal_(self.lora_B[adapter_name], mean=0.0, std=0.02)


class SVDLinear(nn.Linear, AdaLoraLayer):
    def __init__(
        self,
        adapter_name: str,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,
        **kwargs,
    ):
        init_lora_weights = kwargs.pop("init_lora_weights", True)
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        AdaLoraLayer.__init__(self, in_features=in_features, out_features=out_features)

        self.weight.requires_grad = False

        self.fan_in_fan_out = fan_in_fan_out
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

        nn.Linear.reset_parameters(self)
        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
        self.active_adapter = adapter_name

    def merge(self):
        if self.active_adapter not in self.lora_A.keys():
            return
        if self.merged:
            warnings.warn("Already merged. Nothing to do.")
            return
        if self.r[self.active_adapter] > 0:
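            # Fold the adapter into the frozen weight: W += transpose(B @ (A * E)) * scaling / (ranknum + 1e-5),
            # where transpose() accounts for the fan_in_fan_out layout.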
            self.weight.data += (
                transpose(
                    self.lora_B[self.active_adapter]
                    @ (self.lora_A[self.active_adapter] * self.lora_E[self.active_adapter]),
                    self.fan_in_fan_out,
                )
                * self.scaling[self.active_adapter]
                / (self.ranknum[self.active_adapter] + 1e-5)
            )
            self.merged = True

    def unmerge(self):
        if self.active_adapter not in self.lora_A.keys():
            return
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return
        if self.r[self.active_adapter] > 0:
            self.weight.data -= (
                transpose(
                    self.lora_B[self.active_adapter]
                    @ (self.lora_A[self.active_adapter] * self.lora_E[self.active_adapter]),
                    self.fan_in_fan_out,
                )
                * self.scaling[self.active_adapter]
                / (self.ranknum[self.active_adapter] + 1e-5)
            )
            self.merged = False

    def forward(self, x: torch.Tensor):
        if self.active_adapter not in self.lora_A.keys():
            return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        if self.disable_adapters:
            if self.r[self.active_adapter] > 0 and self.merged:
                self.unmerge()
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        elif self.r[self.active_adapter] > 0 and not self.merged:
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
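            # Add the low-rank update: dropout(x) @ (A * E)^T @ B^T, scaled by scaling / (ranknum + 1e-5).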
            result += (
                (
                    self.lora_dropout[self.active_adapter](x)
                    @ (self.lora_A[self.active_adapter] * self.lora_E[self.active_adapter]).T
                    @ self.lora_B[self.active_adapter].T
                )
                * self.scaling[self.active_adapter]
                / (self.ranknum[self.active_adapter] + 1e-5)
            )
        else:
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        return result


if is_bnb_available():

    class SVDLinear8bitLt(bnb.nn.Linear8bitLt, AdaLoraLayer):
        def __init__(
            self,
            adapter_name,
            in_features,
            out_features,
            r: int = 0,
            lora_alpha: int = 1,
            lora_dropout: float = 0.0,
            **kwargs,
        ):
            bnb.nn.Linear8bitLt.__init__(
                self,
                in_features,
                out_features,
                bias=kwargs.get("bias", True),
                has_fp16_weights=kwargs.get("has_fp16_weights", True),
                memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
                threshold=kwargs.get("threshold", 0.0),
                index=kwargs.get("index", None),
            )
            AdaLoraLayer.__init__(self, in_features=in_features, out_features=out_features)

            self.weight.requires_grad = False

            init_lora_weights = kwargs.pop("init_lora_weights", True)
            self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
            self.active_adapter = adapter_name

        def forward(self, x: torch.Tensor):
            result = super().forward(x)

            if self.disable_adapters or self.active_adapter not in self.lora_A.keys():
                return result
            elif self.r[self.active_adapter] > 0:
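                # Outside autocast, the adapter matmul is done in float32 for numerical stability and the
                # result is cast back to the dtype of the 8-bit layer's output before being added.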
                if not torch.is_autocast_enabled():
                    expected_dtype = result.dtype

                    if x.dtype != torch.float32:
                        x = x.float()
                    output = (
                        (
                            self.lora_dropout[self.active_adapter](x)
                            @ (self.lora_A[self.active_adapter] * self.lora_E[self.active_adapter]).T
                            @ self.lora_B[self.active_adapter].T
                        ).to(expected_dtype)
                        * self.scaling[self.active_adapter]
                        / (self.ranknum[self.active_adapter] + 1e-5)
                    )
                else:
                    output = (
                        (
                            self.lora_dropout[self.active_adapter](x)
                            @ (self.lora_A[self.active_adapter] * self.lora_E[self.active_adapter]).T
                            @ self.lora_B[self.active_adapter].T
                        )
                        * self.scaling[self.active_adapter]
                        / (self.ranknum[self.active_adapter] + 1e-5)
                    )
                result += output
            return result


class RankAllocator(object):
    """
    The RankAllocator for AdaLoraModel. Paper: https://openreview.net/pdf?id=lq62uWRJjiY

    Args:
        model: The model to which AdaLoRA is applied.
        peft_config ([`AdaLoraConfig`]): The configuration of the AdaLora model.
        adapter_name (`str`): The name of the adapter whose rank budget is managed.
    """

    def __init__(self, model, peft_config, adapter_name):
        self.peft_config = peft_config
        self.adapter_name = adapter_name
        self.beta1 = peft_config.beta1
        self.beta2 = peft_config.beta2
        assert self.beta1 > 0 and self.beta1 < 1
        assert self.beta2 > 0 and self.beta2 < 1

        self.reset_ipt()
        self._set_budget_scheduler(model)

    def set_total_step(self, total_step):
        self.peft_config.total_step = total_step

    def reset_ipt(self):
        self.ipt = {}
        self.exp_avg_ipt = {}
        self.exp_avg_unc = {}

    def _set_budget_scheduler(self, model):
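        # The initial budget is the total rank across all adapted layers (init_r each); the target budget
        # is target_r per layer. The schedule below interpolates between the two.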
        self.init_bgt = 0
        self.name_set = set()
        for n, p in model.named_parameters():
            if f"lora_A.{self.adapter_name}" in n:
                self.init_bgt += p.size(0)
                self.name_set.add(n.replace("lora_A", "%s"))
        self.name_set = sorted(self.name_set)
        self.target_bgt = self.peft_config.target_r * len(self.name_set)

    def budget_schedule(self, step: int):
        tinit = self.peft_config.tinit
        tfinal = self.peft_config.tfinal
        total_step = self.peft_config.total_step

        if step <= tinit:
            budget = self.init_bgt
            mask_ind = False
        elif step > total_step - tfinal:
            budget = self.target_bgt
            mask_ind = True
        else:
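            # Cubic schedule between the warmup and final phases: the budget decays from init_bgt to
            # target_bgt as (1 - progress)^3, and masking happens every `deltaT` steps.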
            mul_coeff = 1 - (step - tinit) / (total_step - tfinal - tinit)
            budget = int((self.init_bgt - self.target_bgt) * (mul_coeff**3) + self.target_bgt)
            mask_ind = True if step % self.peft_config.deltaT == 0 else False
        return budget, mask_ind

    def update_ipt(self, model):
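        # Sensitivity of each adapter parameter is estimated as |weight * grad|; exp_avg_ipt is its EMA
        # (beta1) and exp_avg_unc is an EMA (beta2) of how far the current estimate deviates from that
        # average, i.e. an uncertainty term, following the AdaLoRA importance score.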
        for n, p in model.named_parameters():
            if "lora_" in n and self.adapter_name in n:
                if n not in self.ipt:
                    self.ipt[n] = torch.zeros_like(p)
                    self.exp_avg_ipt[n] = torch.zeros_like(p)
                    self.exp_avg_unc[n] = torch.zeros_like(p)
                with torch.no_grad():
                    self.ipt[n] = (p * p.grad).abs().detach()
                    self.exp_avg_ipt[n] = self.beta1 * self.exp_avg_ipt[n] + (1 - self.beta1) * self.ipt[n]
                    self.exp_avg_unc[n] = (
                        self.beta2 * self.exp_avg_unc[n] + (1 - self.beta2) * (self.ipt[n] - self.exp_avg_ipt[n]).abs()
                    )

    def _element_score(self, n):
        return self.exp_avg_ipt[n] * self.exp_avg_unc[n]

    def _combine_ipt(self, ipt_E, ipt_AB):
        ipt_AB = ipt_AB.sum(dim=1, keepdim=False)
        sum_ipt = ipt_E.view(-1) + ipt_AB.view(-1)
        return sum_ipt

    def mask_to_budget(self, model, budget):
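        # Score each rank-1 triplet (a row of lora_A, a column of lora_B, an entry of lora_E) by combining
        # the element scores of its three parts, then zero out the lowest-scoring lora_E entries so that the
        # number of retained triplets across all layers matches `budget`.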
        value_ipt = {}
        vector_ipt = {}
        triplet_ipt = {}

        for n, p in model.named_parameters():
            if f"lora_A.{self.adapter_name}" in n:
                entry_ipt = self._element_score(n)
                comb_ipt = torch.mean(entry_ipt, dim=1, keepdim=True)
                name_m = n.replace("lora_A", "%s")
                if name_m not in vector_ipt:
                    vector_ipt[name_m] = [comb_ipt]
                else:
                    vector_ipt[name_m].append(comb_ipt)
            if f"lora_B.{self.adapter_name}" in n:
                entry_ipt = self._element_score(n)
                comb_ipt = torch.mean(entry_ipt, dim=0, keepdim=False).view(-1, 1)
                name_m = n.replace("lora_B", "%s")
                if name_m not in vector_ipt:
                    vector_ipt[name_m] = [comb_ipt]
                else:
                    vector_ipt[name_m].append(comb_ipt)
            if f"lora_E.{self.adapter_name}" in n:
                entry_ipt = self._element_score(n)
                name_m = n.replace("lora_E", "%s")
                value_ipt[name_m] = entry_ipt

        all_score = []

        for name_m in vector_ipt:
            ipt_E = value_ipt[name_m]
            ipt_AB = torch.cat(vector_ipt[name_m], dim=1)
            sum_ipt = self._combine_ipt(ipt_E, ipt_AB)
            name_E = name_m % "lora_E"
            triplet_ipt[name_E] = sum_ipt.view(-1, 1)
            all_score.append(sum_ipt.view(-1))

        mask_threshold = torch.kthvalue(
            torch.cat(all_score),
            k=self.init_bgt - budget,
        )[0].item()

        rank_pattern = {}

        with torch.no_grad():
            for n, p in model.named_parameters():
                if f"lora_E.{self.adapter_name}" in n:
                    p.masked_fill_(triplet_ipt[n] <= mask_threshold, 0.0)
                    rank_pattern[n] = (~(triplet_ipt[n] <= mask_threshold)).view(-1).tolist()
        return rank_pattern

    def update_and_allocate(self, model, global_step, force_mask=False):
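        # Update the importance statistics while reallocation is still active, compute the budget for this
        # step, and mask down to it on masking steps (or when `force_mask` is set); otherwise rank_pattern
        # is returned as None.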
        if global_step < self.peft_config.total_step - self.peft_config.tfinal:
            self.update_ipt(model)
        budget, mask_ind = self.budget_schedule(global_step)

        if mask_ind or force_mask:
            rank_pattern = self.mask_to_budget(model, budget)
        else:
            rank_pattern = None
        return budget, rank_pattern

    def mask_using_rank_pattern(self, model, rank_pattern):
        is_adapter_name_truncated = False
        if self.adapter_name not in next(iter(rank_pattern.keys())):
            is_adapter_name_truncated = True

        with torch.no_grad():
            for n, p in model.named_parameters():
                if f"lora_E.{self.adapter_name}" in n:
                    key = n if not is_adapter_name_truncated else n.replace(f".{self.adapter_name}", "")
                    mask = torch.Tensor(rank_pattern[key]).unsqueeze(-1).to(p.device)
                    p.masked_fill_(~mask.bool(), 0.0)