# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import copy
from contextlib import contextmanager
from functools import partial
from typing import Optional, Union

import torch
import torch.nn as nn

from peft.tuners.lora.layer import LoraLayer
from peft.tuners.lora.model import LoraModel
from peft.tuners.tuners_utils import BaseTuner
from peft.utils.constants import DUMMY_TARGET_MODULES
from peft.utils.save_and_load import set_peft_model_state_dict

from .. import lora
from .classifier import XLoraClassifier
from .config import XLoraConfig
from .layer import XLoraConv2dLayer, XLoraEmbeddingLayer, XLoraLinearLayer


def convert_layers_to_xlora(
    base: nn.Module,  # PeftModel
    xloramodel: nn.Module,  # XLoraModel
    config: XLoraConfig,
) -> tuple[int, torch.device | None]:
    """
    Returns the number of swapped layers and the device of the swapped layers' adapter weights (`None` if no layer
    was swapped).
    """
    total_swapped = 0
    all_layers = []

    device = None
    for module in base.modules():
        # Check the exact type because classes like OPTLearnedPositionalEmbedding inherit from nn.Embedding
        if isinstance(module, lora.Linear):
            device = module.lora_A[next(iter(module.lora_A))].weight.device
            new_layer = XLoraLinearLayer(
                model=xloramodel,
                target=module,
                target_forward=module.forward,
                layer_number=total_swapped,
                config=config,
            )
            all_layers.append(new_layer)
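            # Monkey-patch the bound `forward`: the wrapper keeps the original bound
            # method as `target_forward`, so the underlying LoRA computation still runs
            # inside the X-LoRA wrapper.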
            module.forward = new_layer.forward  # type: ignore[method-assign]
            total_swapped += 1
        elif isinstance(module, lora.Embedding):
            device = module.lora_embedding_A[next(iter(module.lora_embedding_A))].device
            new_layer = XLoraEmbeddingLayer(
                model=xloramodel,
                target=module,
                target_forward=module.forward,
                layer_number=total_swapped,
                config=config,
            )
            all_layers.append(new_layer)
            module.forward = new_layer.forward  # type: ignore[method-assign]
            total_swapped += 1
        elif isinstance(module, lora.Conv2d):
            device = module.lora_A[next(iter(module.lora_A))].weight.device
            new_layer = XLoraConv2dLayer(
                model=xloramodel,
                target=module,
                target_forward=module.forward,
                layer_number=total_swapped,
                config=config,
            )
            all_layers.append(new_layer)
            module.forward = new_layer.forward  # type: ignore[method-assign]
            total_swapped += 1

    return (total_swapped, device)


def _load_adapter_into_lora_model(
    lora_model: LoraModel,
    adapter_name: str,
    model_id: str,
    torch_device: Optional[str] = None,
    ephemeral_gpu_offload: bool = False,
    autocast_adapter_dtype: bool = True,
    subfolder: Optional[str] = None,
    **kwargs,
):
    """
    This method emulates the behavior of `PeftModel.from_pretrained`. Updates to `PeftModel.from_pretrained` may need
    to be reflected here.

    All parameters pertain to the adapter being loaded: `adapter_name` is the 0-indexed adapter number (as a string),
    and `model_id` is the location of the adapter checkpoint.
    """
    from peft.peft_model import PeftModel
    from peft.tuners.lora.config import LoraConfig
    from peft.utils.other import infer_device
    from peft.utils.save_and_load import load_peft_weights

    hf_hub_download_kwargs, kwargs = PeftModel._split_kwargs(kwargs)
    if torch_device is None:
        torch_device = infer_device()

    if adapter_name not in lora_model.peft_config:
        # load the config
        lora_peft_config = LoraConfig.from_pretrained(
            model_id,
            ephemeral_gpu_offload=ephemeral_gpu_offload,
            subfolder=subfolder,
            **hf_hub_download_kwargs,
        )
        lora_peft_config.inference_mode = False
        lora_model.peft_config[adapter_name] = lora_peft_config
        lora_model.inject_adapter(lora_model.model, adapter_name)

    adapter_weights = load_peft_weights(model_id, device=torch_device, subfolder=subfolder, **hf_hub_download_kwargs)

    new_adapter_weights = {}
    # Rework the keys to contain the adapter numbers
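    # For example (illustrative; actual key names depend on the base model):
    #     "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"
    #  -> "model.model.layers.0.self_attn.q_proj.lora_A.weight"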
    for old_key in adapter_weights.keys():
        key: str = old_key
        # Remove all the prefixes until we have model.<...>
        while not (key.startswith("model.") and not key.startswith("model.model.")):
            key = key[key.find(".") + 1 :]
        # We always want model.model
        key = "model." + key
        new_adapter_weights[key] = adapter_weights[old_key]

    # load the weights into the model
    ignore_mismatched_sizes = kwargs.get("ignore_mismatched_sizes", False)
    load_result = set_peft_model_state_dict(
        lora_model,
        new_adapter_weights,
        adapter_name=adapter_name,
        ignore_mismatched_sizes=ignore_mismatched_sizes,
    )
    if len(load_result.unexpected_keys) > 0:
        raise ValueError(
            f"Got unexpected keys! Please raise an issue and tag @EricLBuehler.\n\nunexpected_keys={load_result.unexpected_keys}"
        )

    if hasattr(lora_model, "_cast_adapter_dtype"):
        lora_model._cast_adapter_dtype(adapter_name=adapter_name, autocast_adapter_dtype=autocast_adapter_dtype)


class XLoraModel(BaseTuner):
    """
    Creates an X-LoRA (Mixture of LoRA Experts) model from a pretrained transformers model. Currently, this X-LoRA
    implementation only works with models with a transformer architecture.

    The method is described in detail in https://arxiv.org/abs/2402.07148.

    Args:
        model ([`torch.nn.Module`]): The model to be adapted.
        config ([`XLoraConfig`]): The configuration of the X-LoRA model.
        adapter_name (`str`): The name of the adapter, does not affect the LoRA adapter names.

    Returns:
        `torch.nn.Module`: The X-LoRA model.

    Example:
        ```py
        >>> import torch
        >>> from transformers import AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
        >>> from peft import XLoraConfig, get_peft_model, prepare_model_for_kbit_training

        >>> model_config = AutoConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
        >>> config = XLoraConfig(
        ...     task_type="CAUSAL_LM",
        ...     hidden_size=model_config.hidden_size,
        ...     xlora_depth=4,
        ...     adapters={
        ...         "adapter_1": "./path/to/the/checkpoint/",
        ...         "adapter_2": "./path/to/the/checkpoint/",
        ...         "adapter_n": "./path/to/the/checkpoint/",
        ...     },
        ... )
        >>> int8_config = BitsAndBytesConfig(load_in_8bit=True)
        >>> model = AutoModelForCausalLM.from_pretrained(
        ...     "mistralai/Mistral-7B-Instruct-v0.1",
        ...     trust_remote_code=True,
        ...     attn_implementation="flash_attention_2",
        ...     device_map="cuda:0",
        ...     torch_dtype=torch.bfloat16,
        ...     quantization_config=int8_config,
        ... )
        >>> model = prepare_model_for_kbit_training(model)
        >>> xlora_model = get_peft_model(model, config)
        ```
    """
    def __init__(
        self,
        model: nn.Module,
        config: Union[dict[str, XLoraConfig], XLoraConfig],
        adapter_name: str,
        torch_device: Optional[str] = None,
        ephemeral_gpu_offload: bool = False,
        autocast_adapter_dtype: bool = True,
        **kwargs,
    ) -> None:
        """
        Create a new X-LoRA model.

        Args:
            model (`nn.Module`):
                Base model to apply X-LoRA to.
            config ([`XLoraConfig`]):
                X-LoRA configuration object.
            adapter_name (`str`):
                Adapter name for the X-LoRA adapter.
            torch_device (`str`, *optional*, defaults to `None`):
                (For loading the LoRA adapters) The device to load the adapters on. If `None`, the device will be
                inferred.
            ephemeral_gpu_offload (`bool`, *optional*, defaults to `False`):
                (For loading the LoRA adapters) Whether to use ephemeral GPU offloading for partially loaded modules.
            autocast_adapter_dtype (`bool`, *optional*, defaults to `True`):
                (For loading the LoRA adapters) Whether to autocast the adapter dtype. Right now, this will only cast
                adapter weights using float16 and bfloat16 to float32, as this is typically required for stable
                training, and it only affects select PEFT tuners.
            kwargs (*optional*):
                (For loading the LoRA adapters) Additional arguments to modify the way the adapters are loaded, e.g.
                the token for the Hugging Face Hub.
        """
        nn.Module.__init__(self)

        if isinstance(config, dict):
            conf = config[adapter_name]
        else:
            conf = config

        # Create an empty LoraModel
        base_lora_config = copy.copy(conf)
        base_lora_config.target_modules = DUMMY_TARGET_MODULES
        # Imitate a LoraConfig; fields might need to be updated if LoraConfig is updated
        base_lora_config.layer_replication = None
        base_lora_config.bias = "none"
        lora_model = LoraModel(model, base_lora_config, adapter_name)

        self.xlora_config = conf
        self.lora_model = lora_model
        peft_config = conf

        if hasattr(model.config, "use_cache") and model.config.use_cache:
            raise ValueError("`use_cache` must be False")
        if hasattr(self.xlora_config, "_subfolders"):
            adapters_items = zip(peft_config.adapters.items(), self.xlora_config._subfolders)
            for i, ((_adapter_name, model_id), subfolder) in enumerate(adapters_items):
                _load_adapter_into_lora_model(
                    lora_model=self.lora_model,
                    adapter_name=str(i),
                    model_id=model_id,
                    torch_device=torch_device,
                    ephemeral_gpu_offload=ephemeral_gpu_offload,
                    autocast_adapter_dtype=autocast_adapter_dtype,
                    subfolder=subfolder,
                    **kwargs,
                )
        else:
            adapters_items = peft_config.adapters.items()
            for i, (_adapter_name, model_id) in enumerate(adapters_items):
                _load_adapter_into_lora_model(
                    lora_model=self.lora_model,
                    adapter_name=str(i),
                    model_id=model_id,
                    torch_device=torch_device,
                    ephemeral_gpu_offload=ephemeral_gpu_offload,
                    autocast_adapter_dtype=autocast_adapter_dtype,
                    subfolder=None,
                    **kwargs,
                )

        self.lora_model.set_adapter(list(peft_config.adapters.keys()))
        self._maybe_freeze_all_adapters()

        total_swapped, device = convert_layers_to_xlora(
            model,
            self,
            peft_config,
        )

        n_classes = len(peft_config.adapters)
        xlora_classifier = XLoraClassifier(model, peft_config, n_classes, total_swapped, device)

        # Setup the model internal state
        self.internal_xlora_classifier = xlora_classifier
        self.internal_xlora_scalings = None  # type: ignore
        # Controlled by enable_adapter_layers or disable_adapter_layers
        self.disabled = False

    def _maybe_freeze_all_adapters(self):
        self.eval()
        if not self.xlora_config.use_trainable_adapters:
            for name, param in self.named_parameters():
                if "lora_" in name:
                    param.requires_grad = False

    def generate(self, *args, **kwargs):
        res = self.lora_model.generate(*args, **kwargs)  # type: ignore
        # This is necessary because we use PeftModel.disable_adapter(), which re-enables the adapters
        self._maybe_freeze_all_adapters()
        return res

    @contextmanager
    def _enable_peft_forward_hooks(self, *generate_args, **generate_kwargs):
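        # Overview of the mechanism: hooks injecting uniform "dummy" scalings are used for a
        # scaling pass through the base model with the adapter layers disabled, collecting
        # hidden states; the X-LoRA classifier then predicts the real scalings from that
        # output, and a second wave of hooks injects those scalings into every LoRA layer for
        # the actual forward pass.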
        def scalings_injection_hook(target, args, kwargs, scalings):
            # pre-forward hook to inject the scalings argument into each LoRA layer
            kwargs["scalings"] = scalings
            return args, kwargs

        handles_to_remove = None

        def pre_forward(module, *args, **kwargs):
            nonlocal handles_to_remove

            # =========================== Forward pass with "dummy" scalings ==================

            args_real = args[0]
            kwargs_real = args[1]
            kwargs_real.update(kwargs)

            dummy_scalings = self.internal_xlora_classifier.make_dummy_scalings(*args_real, **kwargs_real)

            hook_handles = []
            for layer in self.modules():
                if isinstance(layer, LoraLayer):
                    hook = partial(scalings_injection_hook, scalings=dummy_scalings)
                    handle = layer.register_forward_pre_hook(hook, with_kwargs=True)
                    hook_handles.append(handle)

            with torch.no_grad():
                self.lora_model.disable_adapter_layers()
                try:
                    scaling_pass_kwargs = kwargs_real.copy()
                    scaling_pass_kwargs["output_hidden_states"] = True
                    scaling_pass_kwargs["return_dict"] = True
                    try:
                        base_output = self.lora_model.model.forward(*args_real, **scaling_pass_kwargs)
                    finally:
                        # Clean everything up
                        for handle in hook_handles:
                            handle.remove()
                finally:
                    self.lora_model.enable_adapter_layers()

            xlora_scalings = self.internal_xlora_classifier(result=base_output, *args_real, **kwargs_real)

            # =========================== Real forward pass with calculated scalings ==================

            hook_handles = []
            for layer in self.modules():
                if isinstance(layer, LoraLayer):
                    hook = partial(scalings_injection_hook, scalings=xlora_scalings)
                    handle = layer.register_forward_pre_hook(hook, with_kwargs=True)
                    hook_handles.append(handle)

            handles_to_remove = hook_handles

        if not self.disabled:
            forward_handle = self.lora_model.model.register_forward_pre_hook(pre_forward, with_kwargs=True)

        # Run the forward pass: first the scaling pass in the hook, and then with the base model
        yield

        if not self.disabled:
            # TODO(EricLBuehler): If we get a forward exception, we may have multiple forward hooks.
            for handle in handles_to_remove:
                handle.remove()
            forward_handle.remove()

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            if name == "lora_model":  # see #1892: prevent infinite recursion if class is not initialized
                raise
            return getattr(self.lora_model, name)

    @staticmethod
    def _prepare_adapter_config(peft_config, _model_config):
        # Handle X-LoRA case
        return peft_config
""" | |
Does nothing. X-LoRA needs adapters to be frozen. | |
""" | |
def _mark_only_adapters_as_trainable(self) -> None: ... | |
""" | |
This enables the X-LoRA adapter. | |
""" | |
def enable_adapter_layers(self) -> None: | |
self.disabled = False | |
""" | |
This diasables the X-LoRA adapter. | |
""" | |
def disable_adapter_layers(self) -> None: | |
self.disabled = True | |

    def _create_and_replace(
        self,
        lora_config,
        adapter_name,
        target,
        target_name,
        parent,
        current_key,
    ):
        # Does nothing because XLoraModel has no target modules
        pass

    @staticmethod
    def _check_target_module_exists(lora_config, key):
        # Does nothing because XLoraModel has no target modules
        return False

    def forward(self, *args, **kwargs):
        return self.lora_model.model(*args, **kwargs)

    def set_topk_lora(self, value: Optional[int]):
        """
        Sparsely select the specified top_k LoRA experts instead of the default dense method. Set to `None` to use the
        dense method. This is reflected in the config.
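
        Example (illustrative; assumes `xlora_model` is the model returned by `get_peft_model`):

            >>> xlora_model.set_topk_lora(2)  # sparse top-2 expert selection per token
            >>> xlora_model.set_topk_lora(None)  # back to dense mixing over all adapters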
""" | |
classifier: XLoraClassifier = self.internal_xlora_classifier # type: ignore | |
classifier.config.top_k_lora = value | |

    def set_global_scaling_weight(self, weight: float):
        """
        Set the global LoRA weight, a scalar to multiply the output of each LoRA adapter by. This is 1 by default, and
        is reflected in the config.
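
        Example (illustrative):

            >>> xlora_model.set_global_scaling_weight(1.5)  # amplify every adapter's output by 1.5x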
""" | |
classifier: XLoraClassifier = self.internal_xlora_classifier # type: ignore | |
classifier.config.global_scaling_weight = weight | |
def set_scaling_pass_value(self, value: float | None): | |
""" | |
Set the scaling pass value, the value to set the scalings to during the scaling pass. If the value is None, the | |
scaling pass value will be 1/n where n is the number of adapters. | |
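
        Example (illustrative):

            >>> xlora_model.set_scaling_pass_value(0.0)  # use zero scalings during the scaling pass
            >>> xlora_model.set_scaling_pass_value(None)  # default: 1/n_adapters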
""" | |
classifier: XLoraClassifier = self.internal_xlora_classifier # type: ignore | |
classifier._set_override_scaling_pass_value(value) | |
def get_global_scaling_weight(self) -> float: | |
""" | |
Get the global LoRA weight. | |
""" | |
classifier: XLoraClassifier = self.internal_xlora_classifier # type: ignore | |
return classifier.config.global_scaling_weight | |
def get_latest_scalings(self) -> Optional[torch.Tensor]: | |
""" | |
Returns the latest scalings prediction, or None if no scalings have been predicted. The tensor is of shape | |
(batch_size, seq_len, n_layers, n_classes). | |
""" | |
return self.internal_xlora_scalings | |
def get_scalings_log(self) -> list[torch.Tensor]: | |
""" | |
Returns a shallow (only copying the list itself not the tensors) copy of the list containing the scalings log. | |
Editing the list does not change the underlying log. The tensors are of shape (batch_size, seq_len, n_layers, | |
n_classes). The seq_len dim may vary with input dimension. | |
""" | |
classifier: XLoraClassifier = self.internal_xlora_classifier # type: ignore | |
return classifier.log_scalings.copy() | |

    def enable_scalings_logging(self):
        """
        Enable scalings logging.
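
        Example (illustrative; assumes `inputs` is a dict produced by a tokenizer):

            >>> xlora_model.enable_scalings_logging()
            >>> _ = xlora_model.generate(**inputs)
            >>> log = xlora_model.get_scalings_log()  # list of (batch_size, seq_len, n_layers, n_classes) tensors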
""" | |
classifier: XLoraClassifier = self.internal_xlora_classifier # type: ignore | |
classifier.scalings_logging = True | |
def disable_scalings_logging(self): | |
""" | |
Disable scalings logging, without clearing the log. | |
""" | |
classifier: XLoraClassifier = self.internal_xlora_classifier # type: ignore | |
classifier.scalings_logging = False | |
def clear_scalings_log(self): | |
""" | |
Clear the scalings log. | |
""" | |
classifier: XLoraClassifier = self.internal_xlora_classifier # type: ignore | |
classifier.log_scalings.clear() | |

    def get_bucketed_scalings_log(self) -> dict[int, tuple[list[int], list[torch.Tensor]]]:
        """
        Returns bucketed scalings, bucketed by seq_len. Each value is a tuple of the positions (the first element) and
        the associated tensors; the positions give the index of each tensor in the scalings log.
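
        Illustrative shape of the returned mapping (assuming three logged steps, two with seq_len 10 and one with
        seq_len 7):

            {10: ([0, 2], [scalings_0, scalings_2]), 7: ([1], [scalings_1])}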
""" | |
classifier: XLoraClassifier = self.internal_xlora_classifier # type: ignore | |
return classifier._get_bucketed_scalings() | |