import ast
import logging
from typing import Dict, List, Optional

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

import networks.lora as lora


# Module class names in the FramePack (HunyuanVideo) DiT that networks.lora
# targets when creating LoRA modules.
FRAMEPACK_TARGET_REPLACE_MODULES = ["HunyuanVideoTransformerBlock", "HunyuanVideoSingleTransformerBlock"]


def create_arch_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    vae: nn.Module,
    text_encoders: List[nn.Module],
    unet: nn.Module,
    neuron_dropout: Optional[float] = None,
    **kwargs,
):
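    """Create a FramePack LoRA network via networks.lora.create_network.

    Note: ``kwargs["exclude_patterns"]``, if present, is expected to be the
    string representation of a list of regex patterns (it is parsed with
    ``ast.literal_eval`` below).
    """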
    # parse user-supplied exclude patterns; they arrive as the string
    # representation of a Python list (e.g. when passed through from a
    # command-line option), so decode them with ast.literal_eval
    exclude_patterns = kwargs.get("exclude_patterns", None)
    if exclude_patterns is None:
        exclude_patterns = []
    else:
        exclude_patterns = ast.literal_eval(exclude_patterns)

    # always exclude normalization layers: skip any module whose name contains "norm"
    exclude_patterns.append(r".*(norm).*")

    kwargs["exclude_patterns"] = exclude_patterns

    return lora.create_network(
        FRAMEPACK_TARGET_REPLACE_MODULES,
        "lora_unet",
        multiplier,
        network_dim,
        network_alpha,
        vae,
        text_encoders,
        unet,
        neuron_dropout=neuron_dropout,
        **kwargs,
    )
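

# Usage sketch (illustration only; `vae`, `text_encoder`, and `transformer` are
# hypothetical placeholders, not objects defined in this module):
#
#   network = create_arch_network(
#       1.0,                 # multiplier
#       32,                  # network_dim: LoRA rank
#       32.0,                # network_alpha
#       vae,                 # VAE module (passed through to networks.lora)
#       [text_encoder],      # text encoder modules
#       transformer,         # the FramePack DiT, playing the "unet" role here
#       exclude_patterns="[r'.*attn.*']",  # note: string repr of a regex list
#   )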


def create_arch_network_from_weights(
    multiplier: float,
    weights_sd: Dict[str, torch.Tensor],
    text_encoders: Optional[List[nn.Module]] = None,
    unet: Optional[nn.Module] = None,
    for_inference: bool = False,
    **kwargs,
) -> lora.LoRANetwork:
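    """Rebuild a FramePack LoRA network from a saved LoRA state dict.

    Delegates to networks.lora.create_network_from_weights with the
    FramePack target modules.
    """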
    return lora.create_network_from_weights(
        FRAMEPACK_TARGET_REPLACE_MODULES, multiplier, weights_sd, text_encoders, unet, for_inference, **kwargs
    )
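

# Loading sketch (illustration only; "lora_file.safetensors" is a hypothetical
# path and `transformer` a placeholder model):
#
#   from safetensors.torch import load_file
#   weights_sd = load_file("lora_file.safetensors")
#   network = create_arch_network_from_weights(1.0, weights_sd, unet=transformer, for_inference=True)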