import math |
|
import re |
|
import warnings |
|
from dataclasses import asdict, dataclass, field |
|
from enum import Enum |
|
from typing import List, Optional, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from transformers.pytorch_utils import Conv1D |
|
|
|
from ..import_utils import is_bnb_available |
|
from ..utils import ( |
|
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, |
|
ModulesToSaveWrapper, |
|
PeftConfig, |
|
PeftType, |
|
_freeze_adapter, |
|
_get_submodules, |
|
transpose, |
|
) |
|
|
|
|
|
if is_bnb_available(): |
|
import bitsandbytes as bnb |
|
|
|
|
|
@dataclass |
|
class LoraConfig(PeftConfig): |
|
""" |
|
This is the configuration class to store the configuration of a [`LoraModel`]. |
|
|
|
Args: |
|
r (`int`): Lora attention dimension. |
|
target_modules (`Union[List[str],str]`): The names of the modules to apply Lora to. |
|
lora_alpha (`float`): The alpha parameter for Lora scaling. |
|
lora_dropout (`float`): The dropout probability for Lora layers. |
|
        fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weights as (fan_in, fan_out).
            For example, gpt-2 uses `Conv1D`, which stores weights as (fan_in, fan_out), so this should be set to `True`.
|
bias (`str`): Bias type for Lora. Can be 'none', 'all' or 'lora_only' |
|
        modules_to_save (`List[str]`): List of modules apart from LoRA layers to be set as trainable
|
and saved in the final checkpoint. |
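
    Example (a minimal sketch; `get_peft_model` is the usual entry point, and "t5-base" is only an illustrative checkpoint):

        ```py
        >>> from transformers import AutoModelForSeq2SeqLM
        >>> from peft import LoraConfig, get_peft_model

        >>> config = LoraConfig(r=8, lora_alpha=32, target_modules=["q", "v"], lora_dropout=0.05, task_type="SEQ_2_SEQ_LM")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
        >>> model = get_peft_model(model, config)
        ```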
|
""" |
|
|
|
r: int = field(default=8, metadata={"help": "Lora attention dimension"}) |
|
target_modules: Optional[Union[List[str], str]] = field( |
|
default=None, |
|
metadata={ |
|
"help": "List of module names or regex expression of the module names to replace with Lora." |
|
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' " |
|
}, |
|
) |
|
    lora_alpha: int = field(default=8, metadata={"help": "Lora alpha"})

    lora_dropout: float = field(default=0.0, metadata={"help": "Lora dropout"})
|
fan_in_fan_out: bool = field( |
|
default=False, |
|
metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"}, |
|
) |
|
bias: str = field(default="none", metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"}) |
|
    modules_to_save: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. "
            "For example, in Sequence Classification or Token Classification tasks, "
            "the final layer `classifier/score` is randomly initialized and as such needs to be trainable and saved."
        },
    )
|
init_lora_weights: bool = field( |
|
default=True, |
|
metadata={"help": "Whether to initialize the weights of the Lora layers."}, |
|
) |
|
|
|
def __post_init__(self): |
|
self.peft_type = PeftType.LORA |
|
|
|
|
|
class LoraModel(torch.nn.Module): |
|
""" |
|
    Creates a Low Rank Adapter (LoRA) model from a pretrained transformers model.
|
|
|
Args: |
|
model ([`~transformers.PreTrainedModel`]): The model to be adapted. |
|
config ([`LoraConfig`]): The configuration of the Lora model. |
|
|
|
Returns: |
|
`torch.nn.Module`: The Lora model. |
|
|
|
Example: |
|
|
|
```py |
|
        >>> from transformers import AutoModelForSeq2SeqLM
|
>>> from peft import LoraModel, LoraConfig |
|
|
|
>>> config = LoraConfig( |
|
... peft_type="LORA", |
|
... task_type="SEQ_2_SEQ_LM", |
|
... r=8, |
|
... lora_alpha=32, |
|
... target_modules=["q", "v"], |
|
... lora_dropout=0.01, |
|
... ) |
|
|
|
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") |
|
>>> lora_model = LoraModel(config, model) |
|
``` |
|
|
|
**Attributes**: |
|
- **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted. |
|
    - **peft_config** ([`LoraConfig`]) -- The configuration of the LoRA model.
|
""" |
|
|
|
def __init__(self, model, config, adapter_name): |
|
super().__init__() |
|
self.model = model |
|
self.forward = self.model.forward |
|
self.peft_config = config |
|
self.add_adapter(adapter_name, self.peft_config[adapter_name]) |
|
|
|
def add_adapter(self, adapter_name, config=None): |
|
if config is not None: |
|
model_config = self.model.config.to_dict() if hasattr(self.model.config, "to_dict") else self.model.config |
|
config = self._prepare_lora_config(config, model_config) |
|
self.peft_config[adapter_name] = config |
|
self._find_and_replace(adapter_name) |
|
if len(self.peft_config) > 1 and self.peft_config[adapter_name].bias != "none": |
|
raise ValueError( |
|
"LoraModel supports only 1 adapter with bias. When using multiple adapters, set bias to 'none' for all adapters." |
|
) |
|
mark_only_lora_as_trainable(self.model, self.peft_config[adapter_name].bias) |
|
if self.peft_config[adapter_name].inference_mode: |
|
_freeze_adapter(self.model, adapter_name) |
|
|
|
def _find_and_replace(self, adapter_name): |
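        # Replace every module matching `target_modules` with a LoRA-wrapped version
        # (or add the adapter to an existing LoraLayer).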
|
lora_config = self.peft_config[adapter_name] |
|
loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False) |
|
if loaded_in_8bit and not is_bnb_available(): |
|
raise ImportError( |
|
"To use Lora with 8-bit quantization, please install the `bitsandbytes` package. " |
|
"You can install it with `pip install bitsandbytes`." |
|
) |
|
is_target_modules_in_base_model = False |
|
kwargs = { |
|
"r": lora_config.r, |
|
"lora_alpha": lora_config.lora_alpha, |
|
"lora_dropout": lora_config.lora_dropout, |
|
"fan_in_fan_out": lora_config.fan_in_fan_out, |
|
"init_lora_weights": lora_config.init_lora_weights, |
|
} |
|
key_list = [key for key, _ in self.model.named_modules()] |
|
for key in key_list: |
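            # A string `target_modules` is treated as a full regex; a list matches module names by suffix.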
|
if isinstance(lora_config.target_modules, str): |
|
target_module_found = re.fullmatch(lora_config.target_modules, key) |
|
else: |
|
target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules) |
|
if target_module_found: |
|
if not is_target_modules_in_base_model: |
|
is_target_modules_in_base_model = True |
|
parent, target, target_name = _get_submodules(self.model, key) |
|
bias = target.bias is not None |
|
if isinstance(target, LoraLayer): |
|
target.update_layer( |
|
adapter_name, |
|
lora_config.r, |
|
lora_config.lora_alpha, |
|
lora_config.lora_dropout, |
|
lora_config.init_lora_weights, |
|
) |
|
else: |
|
if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt): |
|
eightbit_kwargs = kwargs.copy() |
|
eightbit_kwargs.update( |
|
{ |
|
"has_fp16_weights": target.state.has_fp16_weights, |
|
"memory_efficient_backward": target.state.memory_efficient_backward, |
|
"threshold": target.state.threshold, |
|
"index": target.index, |
|
} |
|
) |
|
new_module = Linear8bitLt( |
|
adapter_name, target.in_features, target.out_features, bias=bias, **eightbit_kwargs |
|
) |
|
else: |
|
if isinstance(target, torch.nn.Linear): |
|
in_features, out_features = target.in_features, target.out_features |
|
if kwargs["fan_in_fan_out"]: |
|
warnings.warn( |
|
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " |
|
"Setting fan_in_fan_out to False." |
|
) |
|
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False |
|
elif isinstance(target, Conv1D): |
|
in_features, out_features = ( |
|
target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape |
|
) |
|
if not kwargs["fan_in_fan_out"]: |
|
warnings.warn( |
|
"fan_in_fan_out is set to False but the target module is `Conv1D`. " |
|
"Setting fan_in_fan_out to True." |
|
) |
|
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True |
|
else: |
|
raise ValueError( |
|
f"Target module {target} is not supported. " |
|
f"Currently, only `torch.nn.Linear` and `Conv1D` are supported." |
|
) |
|
new_module = Linear(adapter_name, in_features, out_features, bias=bias, **kwargs) |
|
|
|
self._replace_module(parent, target_name, new_module, target) |
|
        if not is_target_modules_in_base_model:
            raise ValueError(
                f"Target modules {lora_config.target_modules} not found in the base model. "
                f"Please check the target modules and try again."
            )

def _replace_module(self, parent_module, child_name, new_module, old_module): |
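        # Install the new module in place of the old one and carry over its weight, bias and quantization state.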
|
setattr(parent_module, child_name, new_module) |
|
new_module.weight = old_module.weight |
|
if old_module.bias is not None: |
|
new_module.bias = old_module.bias |
|
if getattr(old_module, "state", None) is not None: |
|
new_module.state = old_module.state |
|
new_module.to(old_module.weight.device) |
|
|
|
|
|
        # dispatch the LoRA sub-modules to the same device as the base weight
        for name, module in new_module.named_modules():
|
if "lora_" in name: |
|
module.to(old_module.weight.device) |
|
|
|
def __getattr__(self, name: str): |
|
"""Forward missing attributes to the wrapped module.""" |
|
try: |
|
return super().__getattr__(name) |
|
except AttributeError: |
|
return getattr(self.model, name) |
|
|
|
def get_peft_config_as_dict(self, inference: bool = False): |
|
config_dict = {} |
|
for key, value in self.peft_config.items(): |
|
config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()} |
|
if inference: |
|
config["inference_mode"] = True |
|
config_dict[key] = config |
|
        return config_dict
|
|
|
def _set_adapter_layers(self, enabled=True): |
|
for module in self.model.modules(): |
|
if isinstance(module, LoraLayer): |
|
                module.disable_adapters = not enabled
|
|
|
def enable_adapter_layers(self): |
|
self._set_adapter_layers(enabled=True) |
|
|
|
def disable_adapter_layers(self): |
|
self._set_adapter_layers(enabled=False) |
|
|
|
def set_adapter(self, adapter_name): |
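        # Make `adapter_name` the active adapter on every LoRA layer, unmerging first if weights are merged.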
|
for module in self.model.modules(): |
|
if isinstance(module, LoraLayer): |
|
if module.merged: |
|
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") |
|
module.unmerge() |
|
module.active_adapter = adapter_name |
|
|
|
def merge_adapter(self): |
|
for module in self.model.modules(): |
|
if isinstance(module, LoraLayer): |
|
module.merge() |
|
|
|
def unmerge_adapter(self): |
|
for module in self.model.modules(): |
|
if isinstance(module, LoraLayer): |
|
module.unmerge() |
|
|
|
@staticmethod |
|
def _prepare_lora_config(peft_config, model_config): |
|
if peft_config.target_modules is None: |
|
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING: |
|
raise ValueError("Please specify `target_modules` in `peft_config`") |
|
peft_config.target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]] |
|
if peft_config.inference_mode: |
|
peft_config.merge_weights = True |
|
return peft_config |
|
|
|
def merge_and_unload(self): |
|
r""" |
|
        This method merges the LoRA layers into the base model. This is needed if someone wants to use the base model
        as a standalone model.
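
        Example (a minimal sketch; the adapter path below is a placeholder):

        ```py
        >>> from transformers import AutoModelForSeq2SeqLM
        >>> from peft import PeftModel

        >>> base_model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
        >>> model = PeftModel.from_pretrained(base_model, "path/to/lora/adapter")  # placeholder adapter id/path
        >>> merged_model = model.merge_and_unload()
        ```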
|
""" |
|
if getattr(self.config, "model_type", None) == "gpt2": |
|
raise ValueError("GPT2 models are not supported for merging LORA layers") |
|
|
|
if getattr(self.model, "is_loaded_in_8bit", False): |
|
raise ValueError("Cannot merge LORA layers when the model is loaded in 8-bit mode") |
|
|
|
key_list = [key for key, _ in self.model.named_modules() if "lora" not in key] |
|
for key in key_list: |
|
try: |
|
parent, target, target_name = _get_submodules(self.model, key) |
|
except AttributeError: |
|
continue |
|
if isinstance(target, LoraLayer): |
|
bias = target.bias is not None |
|
new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias) |
|
target.merge() |
|
self._replace_module(parent, target_name, new_module, target) |
|
|
|
|
|
if isinstance(target, ModulesToSaveWrapper): |
|
setattr(parent, target_name, target.modules_to_save[target.active_adapter]) |
|
|
|
return self.model |
|
|
|
def add_weighted_adapter(self, adapters, weights, adapter_name): |
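        # Create a new adapter whose LoRA weights are a weighted combination of existing adapters (all sharing the same r).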
|
if len({self.peft_config[adapter].r for adapter in adapters}) != 1: |
|
raise ValueError("All adapters must have the same r value") |
|
self.peft_config[adapter_name] = self.peft_config[adapters[0]] |
|
self.peft_config[adapter_name].lora_alpha = self.peft_config[adapters[0]].r |
|
self._find_and_replace(adapter_name) |
|
mark_only_lora_as_trainable(self.model, self.peft_config[adapter_name].bias) |
|
_freeze_adapter(self.model, adapter_name) |
|
key_list = [key for key, _ in self.model.named_modules() if "lora" not in key] |
|
for key in key_list: |
|
_, target, _ = _get_submodules(self.model, key) |
|
if isinstance(target, LoraLayer): |
|
target.lora_A[adapter_name].weight.data = target.lora_A[adapter_name].weight.data * 0.0 |
|
target.lora_B[adapter_name].weight.data = target.lora_B[adapter_name].weight.data * 0.0 |
|
for adapter, weight in zip(adapters, weights): |
|
if adapter not in target.lora_A: |
|
continue |
|
target.lora_A[adapter_name].weight.data += ( |
|
target.lora_A[adapter].weight.data * weight * target.scaling[adapter] |
|
) |
|
target.lora_B[adapter_name].weight.data += target.lora_B[adapter].weight.data * weight |
|
|
def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None: |
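    # Freeze all non-LoRA parameters; the `bias` mode controls which bias terms stay trainable.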
|
for n, p in model.named_parameters(): |
|
if "lora_" not in n: |
|
p.requires_grad = False |
|
if bias == "none": |
|
return |
|
elif bias == "all": |
|
for n, p in model.named_parameters(): |
|
if "bias" in n: |
|
p.requires_grad = True |
|
elif bias == "lora_only": |
|
for m in model.modules(): |
|
if isinstance(m, LoraLayer) and hasattr(m, "bias") and m.bias is not None: |
|
m.bias.requires_grad = True |
|
else: |
|
raise NotImplementedError |
|
|
|
class LoraLayer: |
|
def __init__( |
|
self, |
|
in_features: int, |
|
out_features: int, |
|
): |
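        # Per-adapter hyperparameters and sub-modules, keyed by adapter name.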
|
self.r = {} |
|
self.lora_alpha = {} |
|
self.scaling = {} |
|
self.lora_dropout = nn.ModuleDict({}) |
|
self.lora_A = nn.ModuleDict({}) |
|
self.lora_B = nn.ModuleDict({}) |
|
|
|
self.merged = False |
|
self.disable_adapters = False |
|
self.in_features = in_features |
|
self.out_features = out_features |
|
|
|
def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): |
|
self.r[adapter_name] = r |
|
self.lora_alpha[adapter_name] = lora_alpha |
|
if lora_dropout > 0.0: |
|
lora_dropout_layer = nn.Dropout(p=lora_dropout) |
|
else: |
|
lora_dropout_layer = nn.Identity() |
|
|
|
self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer})) |
|
|
|
if r > 0: |
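            # Low-rank decomposition: lora_A maps in_features -> r, lora_B maps r -> out_features; scaling = lora_alpha / r.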
|
self.lora_A.update(nn.ModuleDict({adapter_name: nn.Linear(self.in_features, r, bias=False)})) |
|
self.lora_B.update(nn.ModuleDict({adapter_name: nn.Linear(r, self.out_features, bias=False)})) |
|
self.scaling[adapter_name] = lora_alpha / r |
|
if init_lora_weights: |
|
self.reset_lora_parameters(adapter_name) |
|
self.to(self.weight.device) |
|
|
|
def reset_lora_parameters(self, adapter_name): |
|
if adapter_name in self.lora_A.keys(): |
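            # initialize A the same way as the default for nn.Linear and B to zero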
|
|
|
nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5)) |
|
nn.init.zeros_(self.lora_B[adapter_name].weight) |
|
|
|
|
|
class Linear(nn.Linear, LoraLayer): |
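    # LoRA implemented in a standard dense (nn.Linear) layer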
|
|
|
def __init__( |
|
self, |
|
adapter_name: str, |
|
in_features: int, |
|
out_features: int, |
|
r: int = 0, |
|
lora_alpha: int = 1, |
|
lora_dropout: float = 0.0, |
|
fan_in_fan_out: bool = False, |
|
**kwargs, |
|
): |
|
init_lora_weights = kwargs.pop("init_lora_weights", True) |
|
|
|
nn.Linear.__init__(self, in_features, out_features, **kwargs) |
|
LoraLayer.__init__(self, in_features=in_features, out_features=out_features) |
|
|
|
self.weight.requires_grad = False |
|
|
|
self.fan_in_fan_out = fan_in_fan_out |
|
if fan_in_fan_out: |
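            # Store the weight as (fan_in, fan_out) to mirror Conv1D; forward() transposes it back before F.linear.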
|
self.weight.data = self.weight.data.T |
|
|
|
nn.Linear.reset_parameters(self) |
|
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) |
|
self.active_adapter = adapter_name |
|
|
|
def merge(self): |
|
if self.active_adapter not in self.lora_A.keys(): |
|
return |
|
if self.merged: |
|
warnings.warn("Already merged. Nothing to do.") |
|
return |
|
if self.r[self.active_adapter] > 0: |
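            # Fold the low-rank update (lora_B @ lora_A, scaled by lora_alpha / r) into the frozen base weight.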
|
self.weight.data += ( |
|
transpose( |
|
self.lora_B[self.active_adapter].weight @ self.lora_A[self.active_adapter].weight, |
|
self.fan_in_fan_out, |
|
) |
|
* self.scaling[self.active_adapter] |
|
) |
|
self.merged = True |
|
|
|
def unmerge(self): |
|
if self.active_adapter not in self.lora_A.keys(): |
|
return |
|
if not self.merged: |
|
warnings.warn("Already unmerged. Nothing to do.") |
|
return |
|
if self.r[self.active_adapter] > 0: |
|
self.weight.data -= ( |
|
transpose( |
|
self.lora_B[self.active_adapter].weight @ self.lora_A[self.active_adapter].weight, |
|
self.fan_in_fan_out, |
|
) |
|
* self.scaling[self.active_adapter] |
|
) |
|
self.merged = False |
|
|
|
def forward(self, x: torch.Tensor): |
|
previous_dtype = x.dtype |
|
|
|
if self.active_adapter not in self.lora_A.keys(): |
|
return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) |
|
if self.disable_adapters: |
|
if self.r[self.active_adapter] > 0 and self.merged: |
|
self.unmerge() |
|
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) |
|
elif self.r[self.active_adapter] > 0 and not self.merged: |
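            # Base layer output plus the adapter path: lora_B(lora_A(dropout(x))) * scaling.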
|
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) |
|
|
|
x = x.to(self.lora_A[self.active_adapter].weight.dtype) |
|
|
|
result += ( |
|
self.lora_B[self.active_adapter]( |
|
self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x)) |
|
) |
|
* self.scaling[self.active_adapter] |
|
) |
|
else: |
|
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) |
|
|
|
result = result.to(previous_dtype) |
|
|
|
return result |
|
|
|
|
|
if is_bnb_available(): |
|
|
|
class Linear8bitLt(bnb.nn.Linear8bitLt, LoraLayer): |
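        # LoRA on top of a bitsandbytes 8-bit quantized linear layer; the int8 base weights stay frozen.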
|
|
|
def __init__( |
|
self, |
|
adapter_name, |
|
in_features, |
|
out_features, |
|
r: int = 0, |
|
lora_alpha: int = 1, |
|
lora_dropout: float = 0.0, |
|
**kwargs, |
|
): |
|
bnb.nn.Linear8bitLt.__init__( |
|
self, |
|
in_features, |
|
out_features, |
|
bias=kwargs.get("bias", True), |
|
has_fp16_weights=kwargs.get("has_fp16_weights", True), |
|
memory_efficient_backward=kwargs.get("memory_efficient_backward", False), |
|
threshold=kwargs.get("threshold", 0.0), |
|
index=kwargs.get("index", None), |
|
) |
|
LoraLayer.__init__(self, in_features=in_features, out_features=out_features) |
|
|
|
|
|
self.weight.requires_grad = False |
|
|
|
init_lora_weights = kwargs.pop("init_lora_weights", True) |
|
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) |
|
self.active_adapter = adapter_name |
|
|
|
def forward(self, x: torch.Tensor): |
|
result = super().forward(x) |
|
|
|
if self.disable_adapters or self.active_adapter not in self.lora_A.keys(): |
|
return result |
|
elif self.r[self.active_adapter] > 0: |
|
if not torch.is_autocast_enabled(): |
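                    # Outside autocast, run the adapter path in fp32 and cast its output back to the base result's dtype.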
|
expected_dtype = result.dtype |
|
|
|
if x.dtype != torch.float32: |
|
x = x.float() |
|
output = ( |
|
self.lora_B[self.active_adapter]( |
|
self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x)) |
|
).to(expected_dtype) |
|
* self.scaling[self.active_adapter] |
|
) |
|
else: |
|
output = ( |
|
self.lora_B[self.active_adapter]( |
|
self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x)) |
|
) |
|
* self.scaling[self.active_adapter] |
|
) |
|
result += output |
|
return result |