# Spaces:
# Runtime error
# Runtime error
import collections | |
import importlib | |
import os | |
import sys | |
import math | |
import threading | |
import enum | |
import torch | |
import re | |
import safetensors.torch | |
from omegaconf import OmegaConf, ListConfig | |
from urllib import request | |
import gc | |
import contextlib | |
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches | |
from modules.shared import opts, cmd_opts | |
from modules.timer import Timer | |
import numpy as np | |
from backend.loader import forge_loader | |
from backend import memory_management | |
from backend.args import dynamic_args | |
from backend.utils import load_torch_file | |
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))

# Registries cleared and refilled by list_models():
checkpoints_list = dict()  # checkpoint title -> CheckpointInfo
checkpoint_aliases = dict()  # every id of a checkpoint (hashes, names, titles) -> CheckpointInfo
checkpoint_alisases = checkpoint_aliases  # for compatibility with old name
# CheckpointInfo -> cached state dict, oldest first; trimmed to opts.sd_checkpoint_cache entries.
checkpoints_loaded = collections.OrderedDict()
@enum.unique
class ModelType(enum.Enum):
    """Model families distinguished by this module; values are fixed integers."""

    SD1 = 1
    SD2 = 2
    SDXL = 3
    SSD = 4
    SD3 = 5
def replace_key(d, key, new_key, value):
    """Store ``value`` under ``new_key`` in ``d``, taking over ``key``'s position.

    ``d`` is modified in place and also returned. If ``key`` is present, it is removed
    and ``new_key`` occupies its slot in the iteration order. If ``key`` is absent,
    ``new_key`` is simply set (appended if new).
    """
    original_order = list(d)
    d[new_key] = value

    if key in original_order:
        original_order[original_order.index(key)] = new_key
        reordered = dict((entry, d[entry]) for entry in original_order)
        d.clear()
        d.update(reordered)

    return d
class CheckpointInfo:
    """A checkpoint file on disk plus every identifier it can be looked up by.

    ``name`` is the file's path relative to ``--ckpt-dir`` (when set and the file lives
    under it) or to ``model_path``; otherwise it is the bare file name. ``ids`` collects
    every string (legacy hash, names, titles and, once known, sha256 variants) that
    ``register`` maps to this object in ``checkpoint_aliases``.
    """

    def __init__(self, filename):
        self.filename = filename
        abspath = os.path.abspath(filename)
        abs_ckpt_dir = os.path.abspath(shared.cmd_opts.ckpt_dir) if shared.cmd_opts.ckpt_dir is not None else None

        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"

        # Prefer a name relative to --ckpt-dir, then to model_path, else the bare file name.
        name = self._relative_name(abspath, abs_ckpt_dir) if abs_ckpt_dir else None
        if name is None:
            name = self._relative_name(abspath, model_path)
        if name is None:
            name = os.path.basename(filename)

        if name.startswith(("\\", "/")):
            name = name[1:]

        def read_metadata():
            metadata = read_metadata_from_safetensors(filename)
            self.modelspec_thumbnail = metadata.pop('modelspec.thumbnail', None)

            return metadata

        # read_metadata is the only place that sets modelspec_thumbnail. It is skipped for
        # non-safetensors files and (presumably) when cache.cached_data_for_file serves a
        # cached result, so give the attribute a default that always exists.
        self.modelspec_thumbnail = None
        self.metadata = {}
        if self.is_safetensors:
            try:
                self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name, filename, read_metadata)
            except Exception as e:
                # Best effort: unreadable metadata must not stop the checkpoint from being listed.
                errors.display(e, f"reading metadata for {filename}")

        self.name = name
        self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
        self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
        self.hash = model_hash(filename)

        # May be None if the hash cache has no entry yet; calculate_shorthash() fills it in.
        self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
        self.shorthash = self.sha256[0:10] if self.sha256 else None

        self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
        self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'

        self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]']
        if self.shorthash:
            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']

    @staticmethod
    def _relative_name(path, directory):
        """Return ``path`` relative to ``directory``, or None if it is not inside it.

        Both arguments are expected to be absolute paths as produced by
        ``os.path.abspath``. Only whole path components match: a sibling directory
        that merely shares a textual prefix (``sd2`` vs ``sd``) is not a child. Only
        the leading prefix is removed, so a repeated directory name deeper in the
        path is preserved.
        """
        prefix = directory if directory.endswith(os.sep) else directory + os.sep
        if not path.startswith(prefix):
            return None
        return path[len(prefix):]

    def register(self):
        """Add this checkpoint to ``checkpoints_list`` (by title) and alias every id to it."""
        checkpoints_list[self.title] = self
        for alias in self.ids:
            checkpoint_aliases[alias] = self

    def calculate_shorthash(self):
        """Compute the full sha256 and refresh title/ids/registry entries if it changed.

        Returns:
            The 10-character short hash, or None if ``hashes.sha256`` returned None.
        """
        self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
        if self.sha256 is None:
            return None

        shorthash = self.sha256[0:10]
        if self.shorthash == shorthash:
            return self.shorthash

        self.shorthash = shorthash

        if self.shorthash not in self.ids:
            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']

        old_title = self.title
        self.title = f'{self.name} [{self.shorthash}]'
        self.short_title = f'{self.name_for_extra} [{self.shorthash}]'

        # Re-key the existing entry in place so checkpoints_list keeps its ordering.
        replace_key(checkpoints_list, old_title, self.title, self)
        self.register()

        return self.shorthash

    def __str__(self):
        return str(dict(filename=self.filename, hash=self.hash))

    def __repr__(self):
        return str(dict(filename=self.filename, hash=self.hash))
# try: | |
# # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. | |
# from transformers import logging, CLIPModel # noqa: F401 | |
# | |
# logging.set_verbosity_error() | |
# except Exception: | |
# pass | |
def setup_model():
    """Run one-time startup work for SD models.

    Ensures ``model_path`` exists, then runs the MiDaS auto-download and
    given-betas patch hooks in that order.
    """
    os.makedirs(model_path, exist_ok=True)
    for startup_hook in (enable_midas_autodownload, patch_given_betas):
        startup_hook()
def checkpoint_tiles(use_short=False):
    """List a display label for every registered checkpoint, in registry order.

    Uses ``short_title`` when ``use_short`` is truthy, otherwise ``name``.
    """
    label_attr = 'short_title' if use_short else 'name'
    return [getattr(info, label_attr) for info in checkpoints_list.values()]
def list_models():
    """Rebuild ``checkpoints_list``/``checkpoint_aliases`` from the files on disk.

    Registers the ``--ckpt`` file first (and makes it the selected checkpoint) when it
    exists, then every ``.ckpt``/``.safetensors`` file returned by
    ``modelloader.load_models`` for ``model_path`` and ``--ckpt-dir`` (VAE files
    excluded). A fallback model URL is passed to the loader only when downloads are
    allowed, ``--ckpt`` is the default ``shared.sd_model_file`` and that file is absent.
    """
    checkpoints_list.clear()
    checkpoint_aliases.clear()

    cmd_ckpt = shared.cmd_opts.ckpt
    # --ckpt may be None (checked below); os.path.exists(None) raises TypeError.
    if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or (cmd_ckpt is not None and os.path.exists(cmd_ckpt)):
        model_url = None
    else:
        model_url = "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/realisticVisionV51_v51VAE.safetensors"

    model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="realisticVisionV51_v51VAE.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])

    # Re-check after load_models, which may have downloaded files.
    if cmd_ckpt is not None and os.path.exists(cmd_ckpt):
        checkpoint_info = CheckpointInfo(cmd_ckpt)
        checkpoint_info.register()

        shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
        print(f"Checkpoint in --ckpt argument not found (possibly it was moved to {model_path}): {cmd_ckpt}", file=sys.stderr)

    for filename in model_list:
        checkpoint_info = CheckpointInfo(filename)
        checkpoint_info.register()
# Matches a trailing " [hash]" suffix (plus surrounding whitespace) on a checkpoint title.
re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")


def get_closet_checkpoint_match(search_string):
    """Find the checkpoint that best matches ``search_string``.

    Order of attempts:
    1. exact lookup in ``checkpoint_aliases``;
    2. the shortest registered title containing ``search_string``;
    3. the same, after stripping a trailing ``[hash]`` from ``search_string``.
    Ties on title length go to the earliest registered checkpoint. Returns None if
    ``search_string`` is empty or nothing matches.
    """
    if not search_string:
        return None

    exact = checkpoint_aliases.get(search_string, None)
    if exact is not None:
        return exact

    def shortest_title_containing(needle):
        candidates = (info for info in checkpoints_list.values() if needle in info.title)
        return min(candidates, key=lambda info: len(info.title), default=None)

    best = shortest_title_containing(search_string)
    if best is None:
        best = shortest_title_containing(re.sub(re_strip_checksum, '', search_string))
    return best
def model_hash(filename):
    """Return the legacy 8-hex-digit checkpoint hash, or ``'NOFILE'`` if the file is missing.

    Only the 64 KiB starting at offset 1 MiB is hashed (sha256), so this is cheap but
    collision-prone; files shorter than 1 MiB all hash the empty byte string.
    """
    import hashlib

    sample_offset = 0x100000
    sample_size = 0x10000
    try:
        with open(filename, "rb") as handle:
            handle.seek(sample_offset)
            sample = handle.read(sample_size)
    except FileNotFoundError:
        return 'NOFILE'
    return hashlib.sha256(sample).hexdigest()[:8]
def select_checkpoint():
    """Return the ``CheckpointInfo`` for ``opts.sd_model_checkpoint``.

    Falls back to the first registered checkpoint (printing a warning to stderr when a
    specific checkpoint was requested).

    Raises:
        FileNotFoundError: if no checkpoints are registered at all; the message lists
            every location that was searched.
    """
    model_checkpoint = shared.opts.sd_model_checkpoint

    checkpoint_info = checkpoint_aliases.get(model_checkpoint, None)
    if checkpoint_info is not None:
        return checkpoint_info

    if not checkpoints_list:
        error_message = "No checkpoints found. When searching for checkpoints, looked at:"
        if shared.cmd_opts.ckpt is not None:
            error_message += f"\n - file {os.path.abspath(shared.cmd_opts.ckpt)}"
        error_message += f"\n - directory {model_path}"
        if shared.cmd_opts.ckpt_dir is not None:
            error_message += f"\n - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}"
        # Blank line before the advice; previously it was glued onto the last listed path.
        error_message += "\n\nCan't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations."
        raise FileNotFoundError(error_message)

    checkpoint_info = next(iter(checkpoints_list.values()))
    if model_checkpoint is not None:
        print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr)

    return checkpoint_info
def transform_checkpoint_dict_key(k, replacements):
    """No-op in this build: ignores its arguments and returns None."""
    return None


def get_state_dict_from_checkpoint(pl_sd):
    """No-op in this build: ignores ``pl_sd`` and returns None."""
    return None
def read_metadata_from_safetensors(filename):
    """Return the ``__metadata__`` mapping from a .safetensors file header.

    The file must start with an 8-byte little-endian header length followed by a JSON
    object. String values that start with ``{`` are decoded as JSON when possible;
    otherwise the raw string is kept. Any error while parsing the header JSON is
    reported via ``errors.report`` and whatever was collected so far is returned.

    Raises:
        AssertionError: if the file does not look like a safetensors file. This is
            raised explicitly (not with ``assert``) so the check still runs under
            ``python -O``; the exception type is unchanged for existing callers.
    """
    import json

    with open(filename, mode="rb") as file:
        metadata_len = int.from_bytes(file.read(8), "little")
        json_start = file.read(2)

        if not (metadata_len > 2 and json_start in (b'{"', b"{'")):
            raise AssertionError(f"{filename} is not a safetensors file")

        res = {}

        try:
            json_data = json_start + file.read(metadata_len - 2)
            json_obj = json.loads(json_data)
            for k, v in json_obj.get("__metadata__", {}).items():
                res[k] = v
                if isinstance(v, str) and v[0:1] == '{':
                    # Best effort: keep the raw string if it is not valid JSON.
                    try:
                        res[k] = json.loads(v)
                    except Exception:
                        pass

        except Exception:
            errors.report(f"Error reading metadata from file: {filename}", exc_info=True)

        return res
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
    """No-op in this build: nothing is read; always returns None.

    NOTE(review): callers expecting a state dict get None here; in this module,
    loading goes through ``load_torch_file`` instead.
    """
    return None
def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
    """Return the state dict for ``checkpoint_info``, preferring the in-memory cache.

    The short hash is (re)computed first. On a cache hit the entry is moved to the end
    of ``checkpoints_loaded`` (most recently used) and returned. Otherwise the file is
    read with ``load_torch_file``; the result is not added to the cache here.
    """
    sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    if checkpoint_info not in checkpoints_loaded:
        print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
        loaded = load_torch_file(checkpoint_info.filename)
        timer.record("load weights from disk")
        return loaded

    print(f"Loading weights [{sd_model_hash}] from cache")
    checkpoints_loaded.move_to_end(checkpoint_info)  # mark as most recently used
    return checkpoints_loaded[checkpoint_info]
def SkipWritingToConfig():
    """Return a context manager that does nothing (its ``as`` target is None)."""
    no_op = contextlib.nullcontext(None)
    return no_op
def check_fp8(model):
    """No-op in this build: ignores ``model`` and returns None."""
    return None


def set_model_type(model, state_dict):
    """No-op in this build: ignores its arguments and returns None."""
    return None


def set_model_fields(model):
    """No-op in this build: ignores ``model`` and returns None."""
    return None
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
    """No-op in this build: no weights are loaded and None is returned.

    NOTE(review): in this module, models are built by ``forge_loader`` inside
    ``forge_model_reload`` rather than through this function.
    """
    return None
def enable_midas_autodownload():
    """No-op in this build; returns None (still invoked from ``setup_model``)."""
    return None


def patch_given_betas():
    """No-op in this build; returns None (still invoked from ``setup_model``)."""
    return None


def repair_config(sd_config, state_dict=None):
    """No-op in this build: ignores its arguments and returns None."""
    return None
def rescale_zero_terminal_snr_abar(alphas_cumprod):
    """Rescale a cumulative-alpha schedule towards zero terminal SNR.

    ``sqrt(alphas_cumprod)`` is linearly remapped so its last entry becomes 0 while
    its first entry keeps its value, then squared back. The last element is then set
    to a tiny positive constant. Returns a new tensor; the input is left untouched.
    """
    root = alphas_cumprod.sqrt()
    first = root[0].clone()
    last = root[-1].clone()

    # Shift so the last value is zero, then scale so the first value is unchanged.
    rescaled = (root - last) * (first / (first - last))

    alphas_bar = rescaled ** 2  # back from sqrt(alpha_bar) to alpha_bar
    # NOTE(review): pinned to a tiny positive value instead of the exact 0 the formula
    # gives, presumably to avoid a degenerate zero downstream; confirm.
    alphas_bar[-1] = 4.8973451890853435e-08
    return alphas_bar
def apply_alpha_schedule_override(sd_model, p=None):
    """
    Reset ``sd_model.alphas_cumprod`` from ``alphas_cumprod_original`` and apply overrides from settings.

    - ``opts.use_downcasted_alpha_bar``: downcast the schedule to half precision
    - ``opts.sd_noise_schedule == "Zero Terminal SNR"``: rescale it to zero terminal SNR

    Each applied override is recorded in ``p.extra_generation_params`` when ``p`` is given.
    Models lacking either attribute are left untouched.
    """
    has_schedule = hasattr(sd_model, 'alphas_cumprod') and hasattr(sd_model, 'alphas_cumprod_original')
    if not has_schedule:
        return

    sd_model.alphas_cumprod = sd_model.alphas_cumprod_original.to(shared.device)

    downcast = opts.use_downcasted_alpha_bar
    if downcast:
        if p is not None:
            p.extra_generation_params['Downcast alphas_cumprod'] = downcast
        sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device)

    schedule_name = opts.sd_noise_schedule
    if schedule_name == "Zero Terminal SNR":
        if p is not None:
            p.extra_generation_params['Noise Schedule'] = schedule_name
        sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)
# Dummy model kept for backward compatibility while no real model is loaded yet (needed by extensions such as Prompt All-in-One).
class FakeInitialModel:
    """Placeholder ``sd_model`` used before any real model has been loaded.

    Exposes ``cond_stage_model`` (always None), ``chunk_length`` and a rough
    word-count-based prompt length estimate.
    """

    # A run of any of these characters separates two words.
    _word_separators = re.compile(r"[ ,.!]+")

    def __init__(self):
        self.cond_stage_model = None
        self.chunk_length = 75

    def get_prompt_lengths_on_ui(self, prompt):
        """Return ``(word_count, padded_length)`` for ``prompt``.

        Words are separated by runs of spaces, commas, periods or exclamation marks
        (leading/trailing ones are ignored); an empty prompt counts as one word.
        ``padded_length`` is ``word_count`` rounded up to a multiple of ``chunk_length``.
        A single regex split collapses separator runs of any length; the previous
        chained ``.replace(',,', ',')`` calls only handled runs of up to 16.
        """
        r = len(self._word_separators.split(prompt.strip('!,. ')))
        return r, math.ceil(max(r, 1) / self.chunk_length) * self.chunk_length
class SdModelData:
    """Mutable holder for the current model and the parameters used to load it."""

    def __init__(self):
        # Placeholder until a real model is loaded.
        self.sd_model = FakeInitialModel()
        # forge_model_reload compares str(forge_loading_parameters) with forge_hash.
        self.forge_loading_parameters = {}
        self.forge_hash = ''

    def get_sd_model(self):
        """Return the currently held model object."""
        current = self.sd_model
        return current

    def set_sd_model(self, v):
        """Replace the held model object with ``v``."""
        self.sd_model = v
# Module-level holder for the active model and its loading parameters; forge_model_reload reads and updates it.
model_data = SdModelData()
def get_empty_cond(sd_model):
    """No-op in this build: ignores ``sd_model`` and returns None."""
    return None


def send_model_to_cpu(m):
    """No-op in this build: ignores ``m`` and returns None."""
    return None
def model_target_device(m):
    """Return ``devices.device``; the ``m`` argument is not used."""
    target = devices.device
    return target
def send_model_to_device(m):
    """No-op in this build: ignores ``m`` and returns None."""
    return None


def send_model_to_trash(m):
    """No-op in this build: ignores ``m`` and returns None."""
    return None


def instantiate_from_config(config, state_dict=None):
    """No-op in this build: nothing is instantiated; returns None."""
    return None


def get_obj_from_str(string, reload=False):
    """No-op in this build: nothing is imported or resolved; returns None."""
    return None
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    """No-op in this build: returns None without loading anything.

    NOTE(review): in this module, model loading happens in ``forge_model_reload``.
    """
    return None


def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
    """No-op in this build: ignores its arguments and returns None."""
    return None


def reload_model_weights(sd_model=None, info=None, forced_reload=False):
    """No-op in this build: ignores its arguments and returns None."""
    return None


def unload_model_weights(sd_model=None, info=None):
    """No-op in this build: ignores its arguments and returns None."""
    return None
def apply_token_merging(sd_model, token_merging_ratio):
    """Patch ``sd_model.forge_objects.unet`` with ``TomePatcher`` at the given ratio.

    Does nothing when ``token_merging_ratio <= 0``. Returns None in all cases.
    """
    if token_merging_ratio <= 0:
        return

    print(f'token_merging_ratio = {token_merging_ratio}')

    # Imported lazily so the backend module is only loaded when merging is enabled.
    from backend.misc.tomesd import TomePatcher

    patcher = TomePatcher()
    sd_model.forge_objects.unet = patcher.patch(model=sd_model.forge_objects.unet, ratio=token_merging_ratio)
def forge_model_reload():
    """Load (or reuse) the model described by ``model_data.forge_loading_parameters``.

    ``str(forge_loading_parameters)`` serves as a cache key. If it equals
    ``model_data.forge_hash``, the current model is returned as is. Otherwise:
    1. the existing model is dropped and memory is freed;
    2. the checkpoint (and optional separate VAE) state dicts are read from disk;
    3. the model is built with ``forge_loader``;
    4. it is stored via ``model_data.set_sd_model`` and announced to
       ``script_callbacks.model_loaded_callback``.

    Returns:
        tuple: ``(sd_model, reloaded)``; ``reloaded`` is False when the current model
        was reused and True when a new one was loaded.
    """
    current_hash = str(model_data.forge_loading_parameters)

    if model_data.forge_hash == current_hash:
        return model_data.sd_model, False

    print('Loading Model: ' + str(model_data.forge_loading_parameters))

    timer = Timer()

    # Invalidate the cache key before discarding the old model. Otherwise a failed load
    # below leaves forge_hash naming the model we just dropped, and requesting those
    # parameters again would return (None, False) instead of reloading.
    model_data.forge_hash = ''

    if model_data.sd_model:
        model_data.sd_model = None
        memory_management.unload_all_models()
        memory_management.soft_empty_cache()
        gc.collect()

    timer.record("unload existing model")

    checkpoint_info = model_data.forge_loading_parameters['checkpoint_info']
    state_dict = load_torch_file(checkpoint_info.filename)
    timer.record("load state dict")

    # Optional separate VAE: a filename in the parameters, replaced by its loaded state dict.
    state_dict_vae = model_data.forge_loading_parameters.get('vae_filename', None)

    if state_dict_vae is not None:
        state_dict_vae = load_torch_file(state_dict_vae)

    timer.record("load vae state dict")

    if shared.opts.sd_checkpoint_cache > 0:
        # cache newly loaded model; a shallow copy, presumably so changes forge_loader
        # makes to the dict it receives do not affect the cached one (confirm)
        checkpoints_loaded[checkpoint_info] = state_dict.copy()

    timer.record("cache state dict")

    dynamic_args['forge_unet_storage_dtype'] = model_data.forge_loading_parameters.get('unet_storage_dtype', None)
    dynamic_args['embedding_dir'] = cmd_opts.embeddings_dir
    dynamic_args['emphasis_name'] = opts.emphasis
    sd_model = forge_loader(state_dict, sd_vae=state_dict_vae)
    del state_dict  # drop this function's reference to the raw weights
    timer.record("forge model load")

    sd_model.extra_generation_params = {}
    sd_model.comments = []
    sd_model.sd_checkpoint_info = checkpoint_info
    sd_model.filename = checkpoint_info.filename
    sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    # clean up cache if limit is reached (oldest entries are evicted first)
    while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
        checkpoints_loaded.popitem(last=False)

    shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256

    model_data.set_sd_model(sd_model)

    script_callbacks.model_loaded_callback(sd_model)

    timer.record("scripts callbacks")

    print(f"Model loaded in {timer.summary()}.")

    # Only mark these parameters as loaded once everything above has succeeded.
    model_data.forge_hash = current_hash

    return sd_model, True