from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from ..args import BaseArgs
    from ..parallel import ParallelBackendType


def perform_patches_for_training(args: "BaseArgs", parallel_backend: "ParallelBackendType") -> None:
    # Imported here rather than at module level to avoid circular imports
    from ..config import ModelType, TrainingType

    # LTX-Video needs a patched transformer forward; under tensor parallelism
    # the rotary embedding application must also be made TP-compatible.
    if args.model_name == ModelType.LTX_VIDEO:
        from .models.ltx_video import patch

        patch.patch_transformer_forward()
        if parallel_backend.tensor_parallel_enabled:
            patch.patch_apply_rotary_emb_for_tp_compatibility()

    # LoRA training with layerwise upcasting needs PEFT to keep adapter
    # weights on the device of their (possibly upcast) base layer.
    if args.training_type == TrainingType.LORA and len(args.layerwise_upcasting_modules) > 0:
        from .dependencies.peft import patch

        patch.patch_peft_move_adapter_to_device_of_base_layer()
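

# Illustrative usage sketch, not part of the original file (it assumes this
# module lives at finetrainers/patches/__init__.py, as the relative imports
# suggest). A training entrypoint would call perform_patches_for_training()
# once, after argument parsing and parallel-backend setup but before building
# the models, so every subsequent forward pass runs the patched code paths:
#
#     from finetrainers.patches import perform_patches_for_training
#
#     # `args` (a BaseArgs parsed from the CLI) and `parallel_backend` (an
#     # initialized ParallelBackendType) are assumed to exist at this point.
#     perform_patches_for_training(args, parallel_backend)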