# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import warnings
from typing import Any, Optional

import torch

# from torch import nn
from peft.import_utils import is_torchao_available
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge

from .config import LoraConfig
from .layer import Linear


class TorchaoLoraLinear(Linear):
    """LoRA layer implementation for Linear layers whose base weights are quantized with torchao."""

    def __init__(self, *args, get_apply_tensor_subclass, **kwargs):
        # this is not strictly necessary, as kwargs are stored either way, but we want to error early if
        # get_apply_tensor_subclass is missing.
        if kwargs.get("lora_bias", False):
            raise ValueError(f"{self.__class__.__name__} does not support lora_bias yet, set it to False")

        super().__init__(*args, **kwargs)
        self.get_apply_tensor_subclass = get_apply_tensor_subclass
        self._check_dtype_supported()

    def _check_dtype_supported(self):
        # TODO: Not required once int4_weight_only is properly supported by torchao
        base_layer = self.get_base_layer()
        weight = base_layer.weight
        if hasattr(weight, "layout_tensor") and (weight.layout_tensor.data.dtype != torch.int8):
            raise ValueError(f"{type(self).__name__} only supports int8 weights for now.")

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
        from torchao import quantize_

        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            # no adapter to merge
            return

        self._check_dtype_supported()

        base_layer = self.get_base_layer()
        weight = base_layer.weight

        for active_adapter in adapter_names:
            try:
                weight = weight.dequantize()
            except NotImplementedError as exc:
                msg = (
                    f"Weights of type {type(weight).__name__} do not support dequantization (yet), which is needed to "
                    "support merging."
                )
                raise NotImplementedError(msg) from exc

            if safe_merge and not torch.isfinite(weight).all():
                raise ValueError(
                    f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                )

            weight += self.get_delta_weight(active_adapter)
            # TODO: once (if) torchao supports directly mutating the data, use that instead.
            del base_layer.weight
            base_layer.weight = weight
            quantize_(base_layer, self.get_apply_tensor_subclass())
            del weight

            self.merged_adapters.append(active_adapter)

    def unmerge(self) -> None:
        from torchao import quantize_

        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return

        while len(self.merged_adapters) > 0:
            active_adapter = self.merged_adapters.pop()
            if active_adapter not in self.lora_A.keys():
                continue

            base_layer = self.get_base_layer()
            weight = base_layer.weight
            try:
                weight = weight.dequantize()
            except NotImplementedError as exc:
                msg = (
                    f"Weights of type {type(weight).__name__} do not support dequantization (yet), which is needed to "
                    "support unmerging."
                )
                raise NotImplementedError(msg) from exc

            weight -= self.get_delta_weight(active_adapter)
            # We go through a dummy module because overriding the weight.data does not work, the tensor retains the old
            # data. Therefore, we need to go through quantize_, which takes a module as input, and we need to delete and
            # re-assign the weight.
            # TODO: once (if) torchao supports directly mutating the data, use that instead.
            del base_layer.weight
            base_layer.weight = weight
            quantize_(base_layer, self.get_apply_tensor_subclass())
            del weight

    def __repr__(self) -> str:
        rep = super().__repr__()
        return rep.replace("lora.Linear", f"lora.{self.__class__.__name__}")


def dispatch_torchao(
    target: torch.nn.Module,
    adapter_name: str,
    lora_config: LoraConfig,
    **kwargs: Any,
) -> Optional[torch.nn.Module]:
    # Return a TorchaoLoraLinear wrapping `target` if its base weight is a torchao-quantized tensor,
    # otherwise return None so the next dispatcher can handle the module.
    new_module = None

    if isinstance(target, BaseTunerLayer):
        target_base_layer = target.get_base_layer()
    else:
        target_base_layer = target

    if not hasattr(target_base_layer, "weight"):
        return new_module

    if not is_torchao_available():
        return new_module

    from torchao.dtypes import AffineQuantizedTensor
    from torchao.quantization import LinearActivationQuantizedTensor

    if isinstance(target_base_layer.weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)):
        new_module = TorchaoLoraLinear(target, adapter_name, **kwargs)

    return new_module
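

# Usage sketch (illustrative only, not part of this module's public API): a minimal example, assuming a
# model quantized through transformers' TorchAoConfig, of how get_peft_model() reaches dispatch_torchao()
# and wraps the quantized Linear layers in TorchaoLoraLinear. The model id "facebook/opt-125m" and the
# target module names below are example choices, not requirements.
if __name__ == "__main__":  # pragma: no cover
    from transformers import AutoModelForCausalLM, TorchAoConfig

    from peft import get_peft_model

    # int8 weight-only quantization matches the dtype accepted by _check_dtype_supported above.
    base_model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-125m",
        quantization_config=TorchAoConfig(quant_type="int8_weight_only"),
    )
    peft_model = get_peft_model(base_model, LoraConfig(target_modules=["q_proj", "v_proj"]))
    print(peft_model)  # the targeted projections now show up as lora.TorchaoLoraLinear modules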