import os
import shutil
import tempfile
import warnings
from ast import literal_eval
from contextlib import suppress
from pathlib import Path

import cv2
import numpy as np
import torch
from huggingface_hub import constants, hf_hub_download

TEMP_DIR = tempfile.gettempdir()
ANNOTATOR_CKPTS_PATH = os.path.join(Path(__file__).parents[2], 'ckpts')
USE_SYMLINKS = False
BIGMIN = -(2**53-1)
BIGMAX = (2**53-1)
DIMMAX = 8192

# Environment overrides. A missing or malformed variable keeps the default
# above and emits a warning. Only the expected lookup/parse errors are caught,
# so KeyboardInterrupt/SystemExit are not swallowed at import time.
try:
    ANNOTATOR_CKPTS_PATH = os.environ['AUX_ANNOTATOR_CKPTS_PATH']
except KeyError:
    warnings.warn("Custom pressesor model path not set successfully.")

try:
    USE_SYMLINKS = literal_eval(os.environ['AUX_USE_SYMLINKS'])
except (KeyError, ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
    # KeyError: variable unset; the rest are what literal_eval raises on malformed input.
    warnings.warn("USE_SYMLINKS not set successfully. Using default value: False to download models.")

try:
    TEMP_DIR = os.environ['AUX_TEMP_DIR']
except KeyError:
    warnings.warn(f"custom temp dir not set successfully")
else:
    if not TEMP_DIR:
        # An empty value would make cache paths relative to the current
        # directory, and custom_hf_download later rmtree's <TEMP_DIR>/ckpts.
        warnings.warn("custom temp dir is empty. Using default")
        TEMP_DIR = tempfile.gettempdir()
    elif len(TEMP_DIR) >= 60:
        warnings.warn(f"custom temp dir is too long. Using default")
        TEMP_DIR = tempfile.gettempdir()

here = Path(__file__).parent.resolve()


def safer_memory(x):
    """Return a fresh C-contiguous copy of array *x*.

    The extra copies are deliberate; the original author notes they
    "fix many MAC/AMD problems".
    """
    # Fix many MAC/AMD problems
    return np.ascontiguousarray(x.copy()).copy()


UPSCALE_METHODS = ["INTER_NEAREST", "INTER_LINEAR", "INTER_AREA", "INTER_CUBIC", "INTER_LANCZOS4"]


def get_upscale_method(method_str):
    """Map an interpolation name from UPSCALE_METHODS to the matching cv2 flag."""
    assert method_str in UPSCALE_METHODS, f"Method {method_str} not found in {UPSCALE_METHODS}"
    return getattr(cv2, method_str)


def pad64(x):
    """Return how many units must be added to *x* to reach the next multiple of 64."""
    return int(np.ceil(float(x) / 64.0) * 64 - x)


def resize_image_with_pad(input_image, resolution, upscale_method = "", skip_hwc3=False, mode='edge'):
    """Resize so the shorter side equals *resolution*, then pad H and W to multiples of 64.

    Args:
        input_image: uint8 image array; normalized with HWC3 unless skip_hwc3 is True.
        resolution: target length of the shorter side. 0 returns the (HWC3-normalized)
            image unchanged together with an identity function.
        upscale_method: name from UPSCALE_METHODS used when enlarging. An empty string
            (the default) falls back to cv2.INTER_CUBIC. Shrinking always uses INTER_AREA.
        skip_hwc3: skip the HWC3 channel normalization.
        mode: np.pad mode used to fill the bottom/right padding.

    Returns:
        (padded_image, remove_pad): remove_pad crops an array back to the resized,
        unpadded height and width.
    """
    if skip_hwc3:
        img = input_image
    else:
        img = HWC3(input_image)
    H_raw, W_raw, _ = img.shape
    if resolution == 0:
        return img, lambda x: x
    k = float(resolution) / float(min(H_raw, W_raw))
    H_target = int(np.round(float(H_raw) * k))
    W_target = int(np.round(float(W_raw) * k))
    if k > 1:
        # Previously the default "" tripped the assert in get_upscale_method
        # whenever enlarging; fall back to cubic instead of crashing.
        interpolation = get_upscale_method(upscale_method) if upscale_method else cv2.INTER_CUBIC
    else:
        interpolation = cv2.INTER_AREA
    img = cv2.resize(img, (W_target, H_target), interpolation=interpolation)
    H_pad, W_pad = pad64(H_target), pad64(W_target)
    img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode)

    def remove_pad(x):
        return safer_memory(x[:H_target, :W_target, ...])

    return safer_memory(img_padded), remove_pad


def common_input_validate(input_image, output_type, **kwargs):
    """Normalize legacy arguments and coerce *input_image* to a numpy array.

    Handles the deprecated ``img=`` and ``return_pil=`` keywords and boolean
    ``output_type`` values (each with a warning).

    Returns:
        (input_image, output_type). When output_type is falsy it defaults to
        "pil" for non-ndarray input and "np" for ndarray input.

    Raises:
        ValueError: if no input image is supplied.
    """
    if "img" in kwargs:
        warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning)
        input_image = kwargs.pop("img")

    if "return_pil" in kwargs:
        warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning)
        output_type = "pil" if kwargs["return_pil"] else "np"

    if type(output_type) is bool:
        warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions")
        if output_type:
            output_type = "pil"

    if input_image is None:
        raise ValueError("input_image must be defined.")

    if not isinstance(input_image, np.ndarray):
        input_image = np.array(input_image, dtype=np.uint8)
        output_type = output_type or "pil"
    else:
        output_type = output_type or "np"

    return (input_image, output_type)


def custom_hf_download(pretrained_model_or_path, filename, cache_dir=TEMP_DIR, ckpts_dir=ANNOTATOR_CKPTS_PATH, subfolder="", use_symlinks=USE_SYMLINKS, repo_type="model"):
    """Return the local path of *filename* from HF repo *pretrained_model_or_path*.

    The file is expected at ``ckpts_dir/<repo>/<subfolder>/<filename>``. If it
    is missing it is downloaded there with hf_hub_download. When use_symlinks
    is truthy, a symlink probe is run first; on failure symlinks are disabled
    and a temporary cache under ``cache_dir/ckpts`` is used (and removed after
    the download).
    """
    print(f'cache_dir: {cache_dir}')
    print(f'ckpts_dir: {ckpts_dir}')
    print(f'use_symlinks: {use_symlinks}')
    local_dir = os.path.join(ckpts_dir, pretrained_model_or_path)
    model_path = os.path.join(local_dir, *subfolder.split('/'), filename)

    if len(str(model_path)) >= 255:
        warnings.warn(f"Path {model_path} is too long, \n please change annotator_ckpts_path in config.yaml")

    if not os.path.exists(model_path):
        print(f"Failed to find {model_path}.\n Downloading from huggingface.co")
        print(f"cacher folder is {cache_dir}, you can change it by custom_tmp_path in config.yaml")
        if use_symlinks:
            cache_dir_d = constants.HF_HUB_CACHE    # use huggingface newer env variables `HF_HUB_CACHE`
            if cache_dir_d is None:
                # Path.home() resolves USERPROFILE on Windows and HOME on POSIX,
                # without crashing when the variable is unset.
                cache_dir_d = os.path.join(Path.home(), ".cache", "huggingface", "hub")
            # Compute probe paths up front: the except branch below rebinds
            # cache_dir_d, and cleanup must still target the original files.
            probe_name = f"linktest_{filename}.txt"
            probe_target = os.path.join(cache_dir_d, probe_name)
            probe_link = os.path.join(ckpts_dir, probe_name)
            try:
                # test_link
                Path(cache_dir_d).mkdir(parents=True, exist_ok=True)
                Path(ckpts_dir).mkdir(parents=True, exist_ok=True)
                Path(probe_target).touch()
                # symlink instead of link avoid `invalid cross-device link` error.
                os.symlink(probe_target, probe_link)
                print("Using symlinks to download models. \n",
                      "Make sure you have enough space on your cache folder. \n",
                      "And do not purge the cache folder after downloading.\n",
                      "Otherwise, you will have to re-download the models every time you run the script.\n",
                      "You can use USE_SYMLINKS: False in config.yaml to avoid this behavior.")
            except Exception:
                # Best-effort probe: any failure just disables symlinks.
                print("Maybe not able to create symlink. Disable using symlinks.")
                use_symlinks = False
                cache_dir_d = os.path.join(cache_dir, "ckpts", pretrained_model_or_path)
            finally:    # always remove test link files
                # Separate suppress per file so a missing link does not skip
                # removing the touched target file.
                for probe in (probe_link, probe_target):
                    with suppress(FileNotFoundError):
                        os.remove(probe)
        else:
            cache_dir_d = os.path.join(cache_dir, "ckpts", pretrained_model_or_path)

        # NOTE(review): local_dir_use_symlinks/resume_download are deprecated in
        # newer huggingface_hub releases; kept for compatibility with older ones.
        model_path = hf_hub_download(repo_id=pretrained_model_or_path,
            cache_dir=cache_dir_d,
            local_dir=local_dir,
            subfolder=subfolder,
            filename=filename,
            local_dir_use_symlinks=use_symlinks,
            resume_download=True,
            etag_timeout=100,
            repo_type=repo_type
        )
        if not use_symlinks:
            # The temporary cache under cache_dir/ckpts is not needed once the
            # file has been copied into local_dir.
            try:
                shutil.rmtree(os.path.join(cache_dir, "ckpts"))
            except Exception as e :
                print(e)

    print(f"model_path is {model_path}")

    return model_path


def HWC3(x):
    """Convert a uint8 image to 3 channels (H, W, 3).

    2-D input is treated as single-channel. 1 channel is replicated to 3;
    3 channels are returned as-is. 4 channels are alpha-composited over a white
    background: color * alpha + 255 * (1 - alpha).
    """
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    _, _, C = x.shape
    assert C == 1 or C == 3 or C == 4
    if C == 3:
        return x
    if C == 1:
        return np.concatenate([x, x, x], axis=2)
    if C == 4:
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        y = color * alpha + 255.0 * (1.0 - alpha)
        y = y.clip(0, 255).astype(np.uint8)
        return y