# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy

import torch
import torch.nn.functional as F
from torch import nn

from peft.utils.integrations import dequantize_module_weight, gather_params_ctx
from peft.utils.other import transpose
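
# DoRA (Weight-Decomposed Low-Rank Adaptation, https://arxiv.org/abs/2402.09353) re-parametrizes a
# weight as W' = m * (W + scaling * B @ A) / ||W + scaling * B @ A||_c, where m is a learnable
# per-output-feature magnitude vector and ||.||_c is the column-wise norm. The layers below hold m
# (as `self.weight`) and compute the extra output that must be added on top of the base layer's
# result; they are normally wired up by peft's LoRA layers when `use_dora=True`.
#
# Minimal usage sketch for the Linear case (hypothetical shapes, shown for illustration only):
#
#     base = nn.Linear(16, 32)
#     lora_A = nn.Linear(16, 8, bias=False)  # rank r = 8
#     lora_B = nn.Linear(8, 32, bias=False)
#     dora = DoraLinearLayer(fan_in_fan_out=False)
#     dora.update_layer(base_layer=base, lora_A=lora_A.weight, lora_B=lora_B.weight, scaling=1.0)
#     x = torch.randn(4, 16)
#     delta = dora(x, lora_A=lora_A, lora_B=lora_B, scaling=1.0, base_layer=base)
#     output = base(x) + delta  # DoRA output; the bias is included once, unscaled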


class DoraLinearLayer(nn.Module):
    def __init__(self, fan_in_fan_out):
        super().__init__()
        self.fan_in_fan_out = fan_in_fan_out

    def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
        # calculate L2 norm of weight matrix, column-wise
        weight = transpose(weight, self.fan_in_fan_out)
        weight = weight + scaling * lora_weight
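        # `transpose` puts output features on dim 0, so reducing over dim=1 yields one norm per
        # output feature, i.e. the column-wise norm ||W + scaling * B @ A||_c from the paper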
        weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)
        return weight_norm

    def update_layer(self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=False) -> None:
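        """Compute the initial magnitude vector from the (dequantized) base weight and the LoRA
        weight tensors, and store it as the trainable ``self.weight`` parameter."""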
        # temporarily convert fp16 to fp32, as fp16 can cause trouble on CPU with PyTorch < 2.2
        dtype_is_fp16 = lora_A.dtype == torch.float16
        if dtype_is_fp16:
            lora_A = lora_A.float()
            lora_B = lora_B.float()

        with gather_params_ctx(base_layer.parameters()):
            if base_layer.__class__.__name__ == "Linear4bit":
                # We have to create a copy of the base layer; otherwise FSDP will throw an error. 8bit does not work
                # yet because Int8Params cannot be correctly deep-copied (attributes vanish).
                base_layer = deepcopy(base_layer)
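
            # `dequantize_module_weight` returns the de-quantized float weight for bnb-quantized
            # layers and the plain weight for regular layers (assumption based on its use here)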
            weight = dequantize_module_weight(base_layer)
            if weight.data.ndim >= 4:  # For handling LoRAs applied to Conv layers.
                lora_weight = torch.mm(lora_B.flatten(start_dim=1), lora_A.flatten(start_dim=1))
                lora_weight = lora_weight.reshape(weight.shape)
            else:
                lora_weight = lora_B @ lora_A

            if dtype_is_fp16:
                lora_weight = lora_weight.half()
            weight_norm = self.get_weight_norm(weight.to(lora_A.device), lora_weight, scaling)

        if place_on_cpu:
            weight_norm = weight_norm.to("cpu")
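        # initializing the magnitude to ||W + scaling * B @ A||_c means mag_norm_scale starts out
        # as 1 (as long as the LoRA weights are unchanged), so the DoRA-wrapped layer initially
        # reproduces the LoRA-augmented base output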
        self.weight = nn.Parameter(weight_norm, requires_grad=True)

    def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_result=None):
        """
        For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
        output.
        """
        # Don't use `lora_weight = lora_B.weight @ lora_A.weight` because this causes errors with FSDP. Instead,
        # calculate the same but using forward.
        x_eye = torch.eye(lora_A.weight.shape[1], device=lora_A.weight.device, dtype=x.dtype)
        lora_weight = lora_B(lora_A(x_eye)).T
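        # lora_weight now equals lora_B.weight @ lora_A.weight (the LoRA projections carry no
        # bias), with shape (out_features, in_features)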

        magnitude = self.weight
        weight = dequantize_module_weight(base_layer)
        weight = weight.to(x.dtype)
        weight_norm = self.get_weight_norm(weight, lora_weight.detach(), scaling)
        # see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353)
        # "[...] we suggest treating ||V +∆V ||_c in
        # Eq. (5) as a constant, thereby detaching it from the gradient
        # graph. This means that while ||V + ∆V ||_c dynamically
        # reflects the updates of ∆V , it won’t receive any gradient
        # during backpropagation"
        weight_norm = weight_norm.detach()
        mag_norm_scale = (magnitude / weight_norm).view(1, -1)

        lora_result = lora_B(lora_A(x))
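
        # if the caller already computed the base layer's output, remove its bias before scaling:
        # the magnitude rescaling must only apply to the weight term (and the LoRA delta), while
        # the bias is added exactly once, unscaled, by the caller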
        bias = None
        if base_result is not None:
            bias = base_layer.bias
            if bias is not None:
                base_result = base_result - bias
        else:
            base_result = F.linear(x, transpose(weight, self.fan_in_fan_out))
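
        # adding this to the bias-inclusive base output yields
        # (magnitude / ||W + scaling * B @ A||_c) * x @ (W + scaling * B @ A).T + bias,
        # i.e. the DoRA re-parametrization of the linear layer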
        result_dora = (mag_norm_scale - 1) * base_result + mag_norm_scale * lora_result * scaling

        return result_dora

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "lora.dora." + rep


class DoraEmbeddingLayer(DoraLinearLayer):
    def forward(self, x, *, lora_A, lora_B, scaling, base_layer, embed_fn):
        """
        For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
        output.
        """
        lora_weight = (lora_A @ lora_B).T
        magnitude = self.weight
        weight = base_layer.weight
        weight_norm = self.get_weight_norm(weight, lora_weight.detach(), scaling)
        # see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353)
        # "[...] we suggest treating ||V +∆V ||_c in
        # Eq. (5) as a constant, thereby detaching it from the gradient
        # graph. This means that while ||V + ∆V ||_c dynamically
        # reflects the updates of ∆V , it won’t receive any gradient
        # during backpropagation"
        weight_norm = weight_norm.detach()
        mag_norm_scale = magnitude / weight_norm
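        # only the scaled LoRA contribution is computed here; mag_norm_scale is returned as well
        # so that the caller can rescale the base embedding output itself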
        result_dora = mag_norm_scale * (embed_fn(x, lora_A) @ lora_B) * scaling
        return mag_norm_scale, result_dora

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "lora.dora." + rep


class _DoraConvNdLayer(DoraLinearLayer):
    def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
        # calculate L2 norm of weight matrix, column-wise
        weight = weight + scaling * lora_weight
        # the following is needed to have compatibility with the 4/5D weight tensors of Conv2D/3D
        dim = tuple(range(1, weight.dim()))
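        # reduce over every dim except dim 0 (output channels); keepdim + transpose(1, 0) gives
        # shape (1, out_channels, 1, ...) so the norm broadcasts against the (N, C_out, ...) output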
        weight_norm = weight.norm(p=2, dim=dim, keepdim=True).transpose(1, 0)
        return weight_norm

    def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
        """
        For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
        output.
        """
        weight = base_layer.weight
        lora_weight = torch.mm(lora_B.weight.flatten(start_dim=1), lora_A.weight.flatten(start_dim=1))
        lora_weight = lora_weight.reshape(weight.shape)
        magnitude = self.weight
        weight_norm = self.get_weight_norm(weight, lora_weight.detach(), scaling)
        # see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353)
        # "[...] we suggest treating ||V +∆V ||_c in
        # Eq. (5) as a constant, thereby detaching it from the gradient
        # graph. This means that while ||V + ∆V ||_c dynamically
        # reflects the updates of ∆V , it won’t receive any gradient
        # during backpropagation"
        weight_norm = weight_norm.detach()
        mag_norm_scale = magnitude / weight_norm
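        # as in the Linear case: the convolution is run without bias so that only the weight term
        # is rescaled; adding the result to the caller's bias-inclusive base output gives the DoRA
        # re-parametrization of the conv layer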
        result_dora = (mag_norm_scale - 1) * (
            self.conv_fn(
                x,
                weight,
                bias=None,
                stride=base_layer.stride,
                padding=base_layer.padding,
                dilation=base_layer.dilation,
                groups=base_layer.groups,
            )
        ) + mag_norm_scale * lora_B(lora_A(x)) * scaling

        return result_dora

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "lora.dora." + rep
class DoraConv2dLayer(_DoraConvNdLayer):
    def __init__(self, fan_in_fan_out):
        super().__init__(fan_in_fan_out)
        self.conv_fn = F.conv2d


class DoraConv3dLayer(_DoraConvNdLayer):
    def __init__(self, fan_in_fan_out):
        super().__init__(fan_in_fan_out)
        self.conv_fn = F.conv3d