import hashlib
import os
import torch
import nodes
import server
import folder_paths
import numpy as np
from typing import Iterable
from PIL import Image
BIGMIN = -(2**53-1)
BIGMAX = (2**53-1)
DIMMAX = 8192
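# 2**53-1 is the largest integer exactly representable as a float64, which is
# presumably why it bounds widget values here (the JS frontend stores numbers
# as doubles); DIMMAX caps image width/height widgets.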
def tensor_to_int(tensor, bits):
#TODO: investigate benefit of rounding by adding 0.5 before clip/cast
tensor = tensor.cpu().numpy() * (2**bits-1)
return np.clip(tensor, 0, (2**bits-1))
def tensor_to_shorts(tensor):
return tensor_to_int(tensor, 16).astype(np.uint16)
def tensor_to_bytes(tensor):
return tensor_to_int(tensor, 8).astype(np.uint8)
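# Hypothetical usage: quantizing a ComfyUI IMAGE tensor (float32 in [0, 1])
# to 8 bits for an encoder:
#   frame = tensor_to_bytes(images[0])  # np.uint8 array of values in [0, 255]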
def tensor2pil(x):
return Image.fromarray(np.clip(255. * x.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
def pil2tensor(image: Image.Image) -> torch.Tensor:
return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
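# Note: ComfyUI IMAGE tensors are float32 in [0, 1] with shape
# (batch, height, width, channels); pil2tensor adds the batch dimension and
# tensor2pil squeezes it back out.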
def is_url(url):
return url.split("://")[0] in ["http", "https"]
def strip_path(path):
#This leaves whitespace inside quotes and only a single "
#thus ' ""test"' -> '"test'
#consider path.strip(string.whitespace+"\"")
#or weightier re.fullmatch("[\\s\"]*(.+?)[\\s\"]*", path).group(1)
path = path.strip()
if path.startswith("\""):
path = path[1:]
if path.endswith("\""):
path = path[:-1]
return path
def hash_path(path):
if path is None:
return "input"
if is_url(path):
return "url"
return calculate_file_hash(strip_path(path))
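# hash_path returns constant sentinels for missing paths and URLs, so only
# local files ever invalidate a cache; it is presumably wired into nodes'
# IS_CHANGED checks.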
# modified from https://stackoverflow.com/questions/22058048/hashing-a-file-in-python
def calculate_file_hash(filename: str, hash_every_n: int = 1):
    #Larger video files were taking >.5 seconds to hash even when cached, so
    #the path and filesystem modification time are hashed instead of the file
    #contents. This misses edits that preserve the mtime, and hash_every_n is
    #now unused (apparently retained so existing call sites keep working).
h = hashlib.sha256()
h.update(filename.encode())
h.update(str(os.path.getmtime(filename)).encode())
return h.hexdigest()
def is_safe_path(path):
if "VHS_STRICT_PATHS" not in os.environ:
return True
basedir = os.path.abspath('.')
    try:
        common_path = os.path.commonpath([basedir, path])
    except ValueError:
        #commonpath raises ValueError when paths are on different drives
        #(Windows) or mix absolute and relative paths
        return False
return common_path == basedir
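# Hypothetical example: with strict paths enabled, only locations under the
# working directory pass:
#   os.environ["VHS_STRICT_PATHS"] = "1"
#   is_safe_path(os.path.abspath("input/clip.mp4"))  # True (under basedir)
#   is_safe_path("/etc/passwd")                      # False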
def validate_path(path, allow_none=False, allow_url=True):
if path is None:
return allow_none
if is_url(path):
#Probably not feasible to check if url resolves here
if not allow_url:
return "URLs are unsupported for this path"
return is_safe_path(path)
if not os.path.isfile(strip_path(path)):
return "Invalid file path: {}".format(path)
return is_safe_path(path)
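# Note the tri-state contract: True (valid), an error string, or the result of
# is_safe_path. This matches ComfyUI's VALIDATE_INPUTS convention, where any
# value other than True is treated as a validation failure.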
def common_annotator_call(model, tensor_image, input_batch=False, show_pbar=False, **kwargs):
if "detect_resolution" in kwargs:
del kwargs["detect_resolution"] #Prevent weird case?
if "resolution" in kwargs:
detect_resolution = kwargs["resolution"] if type(kwargs["resolution"]) == int and kwargs["resolution"] >= 64 else 512
del kwargs["resolution"]
else:
detect_resolution = 512
if input_batch:
np_images = np.asarray(tensor_image * 255., dtype=np.uint8)
np_results = model(np_images, output_type="np", detect_resolution=detect_resolution, **kwargs)
return torch.from_numpy(np_results.astype(np.float32) / 255.0)
batch_size = tensor_image.shape[0]
out_tensor = None
for i, image in enumerate(tensor_image):
np_image = np.asarray(image.cpu() * 255., dtype=np.uint8)
np_result = model(np_image, output_type="np", detect_resolution=detect_resolution, **kwargs)
out = torch.from_numpy(np_result.astype(np.float32) / 255.0)
if out_tensor is None:
out_tensor = torch.zeros(batch_size, *out.shape, dtype=torch.float32)
out_tensor[i] = out
return out_tensor
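# A minimal usage sketch (an assumption: any controlnet_aux-style detector
# whose __call__ accepts numpy images plus output_type/detect_resolution
# kwargs works here):
#   from controlnet_aux import CannyDetector  # hypothetical choice of model
#   edges = common_annotator_call(CannyDetector(), image, resolution=512)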
def create_node_input_types(**extra_kwargs):
return {
"required": {
"image": ("IMAGE",)
},
"optional": {
**extra_kwargs,
"resolution": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 64})
}
}
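# Hypothetical example: a preprocessor node adding one extra widget on top of
# the shared image/resolution inputs:
#   @classmethod
#   def INPUT_TYPES(cls):
#       return create_node_input_types(
#           low_threshold=("INT", {"default": 100, "min": 0, "max": 255}))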
prompt_queue = server.PromptServer.instance.prompt_queue
def requeue_workflow_unchecked():
"""Requeues the current workflow without checking for multiple requeues"""
currently_running = prompt_queue.currently_running
    print('requeue_workflow_unchecked >>>>>>')
(_, _, prompt, extra_data, outputs_to_execute) = next(iter(currently_running.values()))
#Ensure batch_managers are marked stale
prompt = prompt.copy()
for uid in prompt:
if prompt[uid]['class_type'] == 'BatchManager':
prompt[uid]['inputs']['requeue'] = prompt[uid]['inputs'].get('requeue',0)+1
#execution.py has guards for concurrency, but server doesn't.
#TODO: Check that this won't be an issue
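    #A negative number places the requeued prompt at the front of the queue
    #(entries are ordered by number), mirroring ComfyUI's queue-front behavior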
number = -server.PromptServer.instance.number
server.PromptServer.instance.number += 1
prompt_id = str(server.uuid.uuid4())
prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
print(f'requeue_workflow_unchecked <<<<<<<<<< prompt_id:{prompt_id}, number:{number}')
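
# requeue_guard holds per-run state: [run_number, reports received so far,
# total batch-managed outputs, {output_id: whether more batches are needed}]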
requeue_guard = [None, 0, 0, {}]
def requeue_workflow(requeue_required=(-1, True)):
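    """Requeues the workflow once every batch-managed output has reported in.

    requeue_required is an (output_id, more_batches_needed) pair; the requeue
    fires only if at least one output still needs another batch.
    """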
    assert len(prompt_queue.currently_running) == 1
global requeue_guard
(run_number, _, prompt, _, _) = next(iter(prompt_queue.currently_running.values()))
print(f'requeue_workflow >> run_number:{run_number}\n')
if requeue_guard[0] != run_number:
#Calculate a count of how many outputs are managed by a batch manager
        managed_outputs = 0
for bm_uid in prompt:
if prompt[bm_uid]['class_type'] == 'BatchManager':
for output_uid in prompt:
if prompt[output_uid]['class_type'] in ["VideoSaver"]:
for inp in prompt[output_uid]['inputs'].values():
if inp == [bm_uid, 0]:
                                managed_outputs += 1
requeue_guard = [run_number, 0, managed_outputs, {}]
    requeue_guard[1] += 1
requeue_guard[3][requeue_required[0]] = requeue_required[1]
if requeue_guard[1] == requeue_guard[2] and max(requeue_guard[3].values()):
requeue_workflow_unchecked()