import torch
import os
import json
import safetensors.torch
import backend.misc.checkpoint_pickle


def read_arbitrary_config(directory):
    # Load and return the parsed config.json found in `directory`.
    config_path = os.path.join(directory, 'config.json')

    if not os.path.exists(config_path):
        raise FileNotFoundError(f"No config.json file found in the directory: {directory}")

    with open(config_path, 'rt', encoding='utf-8') as file:
        config_data = json.load(file)

    return config_data


def load_torch_file(ckpt, safe_load=False, device=None):
    # Load a checkpoint (.safetensors or pickled torch) and return its state dict.
    if device is None:
        device = torch.device("cpu")

    if ckpt.lower().endswith(".safetensors"):
        sd = safetensors.torch.load_file(ckpt, device=device.type)
    else:
        if safe_load and 'weights_only' not in torch.load.__code__.co_varnames:
            print("Warning: torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
            safe_load = False

        if safe_load:
            pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
        else:
            # Fall back to the restricted unpickler for legacy checkpoints.
            pl_sd = torch.load(ckpt, map_location=device, pickle_module=backend.misc.checkpoint_pickle)

        if "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")

        if "state_dict" in pl_sd:
            sd = pl_sd["state_dict"]
        else:
            sd = pl_sd

    return sd


def set_attr(obj, attr, value):
    # Set a dotted attribute path to a frozen Parameter wrapping `value`.
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))


def set_attr_raw(obj, attr, value):
    # Set a dotted attribute path to `value` without wrapping it in a Parameter.
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    setattr(obj, attrs[-1], value)


def copy_to_param(obj, attr, value):
    # Copy `value` in-place into an existing parameter, preserving its identity.
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    prev = getattr(obj, attrs[-1])
    prev.data.copy_(value)


def get_attr(obj, attr):
    # Resolve a dotted attribute path and return the final attribute.
    attrs = attr.split(".")
    for name in attrs:
        obj = getattr(obj, name)
    return obj


def calculate_parameters(sd, prefix=""):
    # Count the total number of tensor elements whose state-dict keys start with `prefix`.
    params = 0
    for k in sd.keys():
        if k.startswith(prefix):
            params += sd[k].nelement()
    return params
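

# Minimal usage sketch (not part of the module's API): loads a checkpoint state dict
# and counts the parameters under a key prefix. The checkpoint path and the
# "model.diffusion_model." prefix below are hypothetical examples; substitute values
# that match your own files and model layout.
#
#     sd = load_torch_file("/path/to/model.safetensors", safe_load=True)
#     n_params = calculate_parameters(sd, prefix="model.diffusion_model.")
#     print(f"parameters under prefix: {n_params / 1e9:.2f}B")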