File size: 772 Bytes
80ebcb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from typing import TYPE_CHECKING


if TYPE_CHECKING:
    from ..args import BaseArgs
    from ..parallel import ParallelBackendType


def perform_patches_for_training(args: "BaseArgs", parallel_backend: "ParallelBackendType") -> None:
    # To avoid circular imports
    from ..config import ModelType, TrainingType

    if args.model_name == ModelType.LTX_VIDEO:
        from .models.ltx_video import patch

        patch.patch_transformer_forward()
        if parallel_backend.tensor_parallel_enabled:
            patch.patch_apply_rotary_emb_for_tp_compatibility()

    if args.training_type == TrainingType.LORA and len(args.layerwise_upcasting_modules) > 0:
        from dependencies.peft import patch

        patch.patch_peft_move_adapter_to_device_of_base_layer()