Spaces:
Runtime error
Runtime error
import os | |
import torch | |
import logging | |
import importlib | |
import backend.args | |
import huggingface_guess | |
from diffusers import DiffusionPipeline | |
from transformers import modeling_utils | |
from backend import memory_management | |
from backend.utils import read_arbitrary_config | |
from backend.state_dict import try_filter_state_dict, load_state_dict | |
from backend.operations import using_forge_operations | |
from backend.nn.vae import IntegratedAutoencoderKL | |
from backend.nn.clip import IntegratedCLIP | |
from backend.nn.unet import IntegratedUNet2DConditionModel | |
from backend.diffusion_engine.sd15 import StableDiffusion | |
from backend.diffusion_engine.sd20 import StableDiffusion2 | |
from backend.diffusion_engine.sdxl import StableDiffusionXL | |
from backend.diffusion_engine.flux import Flux | |
# Diffusion engine classes that forge_loader checks, in this order, against
# the guessed model configuration.
possible_models = [
    StableDiffusion,
    StableDiffusion2,
    StableDiffusionXL,
    Flux,
]

# Only let error-level (and above) records through the diffusers logger.
logging.getLogger("diffusers").setLevel(logging.ERROR)

# Directory containing this module; forge_loader resolves its local
# 'huggingface/<repo>' folders relative to it.
dir_path = os.path.dirname(__file__)
def _load_vae(config_path, state_dict):
    """Build an IntegratedAutoencoderKL from ``config_path`` and load ``state_dict`` into it."""
    config = IntegratedAutoencoderKL.load_config(config_path)

    with using_forge_operations(device=memory_management.cpu, dtype=memory_management.vae_dtype()):
        model = IntegratedAutoencoderKL.from_config(config)

    # 'loss.' is handed to the loader as ignore_start (presumably so that
    # training-only loss weights are not reported -- confirm in load_state_dict).
    load_state_dict(model, state_dict, ignore_start='loss.')
    return model


def _load_clip_text_encoder(config_path, cls_name, state_dict):
    """Build an IntegratedCLIP text encoder from ``config_path`` and load ``state_dict`` into it."""
    from transformers import CLIPTextConfig, CLIPTextModel

    config = CLIPTextConfig.from_pretrained(config_path)
    to_args = dict(device=memory_management.cpu, dtype=memory_management.text_encoder_dtype())

    with modeling_utils.no_init_weights(), using_forge_operations(**to_args, manual_cast_enabled=True):
        model = IntegratedCLIP(CLIPTextModel, config, add_text_projection=True).to(**to_args)

    load_state_dict(
        model,
        state_dict,
        ignore_errors=[
            'transformer.text_projection.weight',
            'transformer.text_model.embeddings.position_ids',
            'logit_scale',
        ],
        log_name=cls_name,
    )
    return model


def _load_t5_text_encoder(config_path, cls_name, state_dict):
    """Build an IntegratedT5 encoder from ``config_path`` and load ``state_dict`` into it."""
    from backend.nn.t5 import IntegratedT5

    config = read_arbitrary_config(config_path)

    dtype = memory_management.text_encoder_dtype()
    sd_dtype = memory_management.state_dict_dtype(state_dict)

    # A float8 state dict overrides the default text encoder dtype.
    if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
        print(f'Using Detected T5 Data Type: {sd_dtype}')
        dtype = sd_dtype

    with modeling_utils.no_init_weights(), using_forge_operations(
        device=memory_management.cpu, dtype=dtype, manual_cast_enabled=True
    ):
        model = IntegratedT5(config)

    load_state_dict(
        model,
        state_dict,
        log_name=cls_name,
        ignore_errors=['transformer.encoder.embed_tokens.weight'],
    )
    return model


def _load_diffusion_model(guess, cls_name, state_dict):
    """Build the UNet / Flux transformer named by ``cls_name`` and load ``state_dict`` into it.

    ``cls_name`` must be 'UNet2DConditionModel' or 'FluxTransformer2DModel'.
    The returned model is annotated with ``storage_dtype``,
    ``computation_dtype``, ``load_device``, ``initial_device`` and
    ``offload_device`` attributes.
    """
    if cls_name == 'FluxTransformer2DModel':
        from backend.nn.flux import IntegratedFluxTransformer2DModel

        def build_model(config):
            return IntegratedFluxTransformer2DModel(**config)
    else:
        build_model = IntegratedUNet2DConditionModel.from_config

    unet_config = guess.unet_config.copy()
    state_dict_size = memory_management.state_dict_size(state_dict)
    state_dict_dtype = memory_management.state_dict_dtype(state_dict)

    storage_dtype = memory_management.unet_dtype(
        model_params=state_dict_size,
        supported_dtypes=guess.supported_inference_dtypes,
    )

    # Precedence: explicit 'forge_unet_storage_dtype' override, then a
    # float8 / nf4 / fp4 dtype detected in the state dict, then the default
    # chosen by memory_management.unet_dtype above.
    storage_dtype_override = backend.args.dynamic_args.get('forge_unet_storage_dtype')
    if storage_dtype_override is not None:
        storage_dtype = storage_dtype_override
    elif state_dict_dtype in [torch.float8_e4m3fn, torch.float8_e5m2, 'nf4', 'fp4']:
        print(f'Using Detected UNet Type: {state_dict_dtype}')
        storage_dtype = state_dict_dtype
        if state_dict_dtype in ['nf4', 'fp4']:
            print('Using pre-quant state dict!')

    load_device = memory_management.get_torch_device()
    computation_dtype = memory_management.get_computation_dtype(
        load_device,
        supported_dtypes=guess.supported_inference_dtypes,
    )
    offload_device = memory_management.unet_offload_device()

    if storage_dtype in ['nf4', 'fp4']:
        # 4-bit storage: operations are created with the computation dtype and
        # the storage dtype is forwarded as bnb_dtype; the model is not cast.
        initial_device = memory_management.unet_inital_load_device(
            parameters=state_dict_size, dtype=computation_dtype
        )
        with using_forge_operations(
            device=initial_device,
            dtype=computation_dtype,
            manual_cast_enabled=False,
            bnb_dtype=storage_dtype,
        ):
            model = build_model(unet_config)
    else:
        initial_device = memory_management.unet_inital_load_device(
            parameters=state_dict_size, dtype=storage_dtype
        )
        # Manual casting is requested only when storage and computation dtypes differ.
        need_manual_cast = storage_dtype != computation_dtype
        to_args = dict(device=initial_device, dtype=storage_dtype)

        with using_forge_operations(**to_args, manual_cast_enabled=need_manual_cast):
            model = build_model(unet_config).to(**to_args)

    load_state_dict(model, state_dict)

    # Store the guessed config wherever this model class keeps its config.
    if hasattr(model, '_internal_dict'):
        model._internal_dict = unet_config
    else:
        model.config = unet_config

    model.storage_dtype = storage_dtype
    model.computation_dtype = computation_dtype
    model.load_device = load_device
    model.initial_device = initial_device
    model.offload_device = offload_device

    return model


def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_path, state_dict):
    """Load one pipeline component described by a ``[lib_name, cls_name]`` entry.

    Args:
        guess: Model guess object; its ``unet_config`` and
            ``supported_inference_dtypes`` are read for UNet / Flux components.
        component_name: Component key; also the sub-folder of ``repo_path``
            holding that component's config.
        lib_name: Library named by the pipeline config. Only 'transformers'
            and 'diffusers' are handled.
        cls_name: Class name from the pipeline config; selects the loader.
        repo_path: Directory of the local Hugging Face repo layout.
        state_dict: Weights for this component. Not used for schedulers and
            tokenizers.

    Returns:
        The loaded component, or ``None`` when the component is deliberately
        ignored ('feature_extractor', 'safety_checker') or not recognised
        (a 'Skipped: ...' line is printed in the latter case).
    """
    config_path = os.path.join(repo_path, component_name)

    if component_name in ['feature_extractor', 'safety_checker']:
        return None

    if lib_name in ['transformers', 'diffusers']:
        if component_name == 'scheduler' or component_name.startswith('tokenizer'):
            # Schedulers and tokenizers carry no checkpoint weights here; they
            # are instantiated straight from the repo folder.
            cls = getattr(importlib.import_module(lib_name), cls_name)
            return cls.from_pretrained(config_path)

        if cls_name == 'AutoencoderKL':
            return _load_vae(config_path, state_dict)

        if component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
            return _load_clip_text_encoder(config_path, cls_name, state_dict)

        if cls_name == 'T5EncoderModel':
            return _load_t5_text_encoder(config_path, cls_name, state_dict)

        if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel']:
            return _load_diffusion_model(guess, cls_name, state_dict)

    print(f'Skipped: {component_name} = {lib_name}.{cls_name}')
    return None
def split_state_dict(sd, sd_vae=None):
    """Split a checkpoint state dict into one state dict per component.

    Args:
        sd: Full checkpoint state dict; handed to ``huggingface_guess.guess``
            to identify the model.
        sd_vae: Optional external VAE state dict. When given, it is used
            as-is instead of filtering VAE weights out of ``sd``.

    Returns:
        A ``(state_dict, guess)`` tuple: ``state_dict`` maps component target
        names (UNet, VAE and each CLIP target) to their state dicts, and
        ``guess`` is the model guess with ``clip_target`` resolved.
    """
    model_guess = huggingface_guess.guess(sd)
    # clip_target is called with the checkpoint and replaced by its result,
    # which is iterated as a mapping further down.
    model_guess.clip_target = model_guess.clip_target(sd)

    if sd_vae is not None:
        print(f'Using external VAE state dict: {len(sd_vae)}')

    unet_part = try_filter_state_dict(sd, model_guess.unet_key_prefix)
    if sd_vae is None:
        vae_part = try_filter_state_dict(sd, model_guess.vae_key_prefix)
    else:
        vae_part = sd_vae

    components = {
        model_guess.unet_target: unet_part,
        model_guess.vae_target: vae_part,
    }

    sd = model_guess.process_clip_state_dict(sd)

    for clip_prefix, clip_name in model_guess.clip_target.items():
        components[clip_name] = try_filter_state_dict(sd, [clip_prefix + '.'])

    # The processed dict is registered under 'ignore' only so that its size
    # appears in the log line; it is removed again before returning.
    components['ignore'] = sd
    key_counts = {name: len(part) for name, part in components.items()}
    print(f'StateDict Keys: {key_counts}')
    del components['ignore']

    return components, model_guess
def forge_loader(sd, sd_vae=None):
    """Build a diffusion engine from a checkpoint state dict.

    The state dict is split per component, the pipeline config of the guessed
    Hugging Face repo is read from the local 'huggingface' folder next to this
    module, each ``[library, class]`` entry is loaded, and the first class in
    ``possible_models`` whose ``matched_guesses`` contains the guess type is
    instantiated.

    Args:
        sd: Full checkpoint state dict.
        sd_vae: Optional external VAE state dict, forwarded to
            ``split_state_dict``.

    Returns:
        The engine instance, or ``None`` (after printing a message) when no
        entry of ``possible_models`` matches.
    """
    component_sds, estimated_config = split_state_dict(sd, sd_vae=sd_vae)

    local_path = os.path.join(dir_path, 'huggingface', estimated_config.huggingface_repo)
    config: dict = DiffusionPipeline.load_config(local_path)

    huggingface_components = {}
    for component_name, entry in config.items():
        # Only two-element list entries are treated as [library, class] components.
        is_component_entry = isinstance(entry, list) and len(entry) == 2
        if not is_component_entry:
            continue

        lib_name, cls_name = entry
        component_sd = component_sds.get(component_name, None)
        component = load_huggingface_component(
            estimated_config,
            component_name,
            lib_name,
            cls_name,
            local_path,
            component_sd,
        )

        # Drop this function's reference to the weights once the component is loaded.
        if component_sd is not None:
            del component_sds[component_name]

        if component is not None:
            huggingface_components[component_name] = component

    for engine_cls in possible_models:
        is_match = any(isinstance(estimated_config, guess_cls) for guess_cls in engine_cls.matched_guesses)
        if is_match:
            return engine_cls(estimated_config=estimated_config, huggingface_components=huggingface_components)

    print('Failed to recognize model type!')
    return None