Commit f2dbf59 (parent: 6a1c163): custom_nodes

This view is limited to 50 files because it contains too many changes.
- .gitattributes +9 -0
- .gitignore +1 -1
- custom_nodes/ComfyUI-tbox/.gitignore +3 -0
- custom_nodes/ComfyUI-tbox/README.md +5 -0
- custom_nodes/ComfyUI-tbox/__init__.py +82 -0
- custom_nodes/ComfyUI-tbox/config.yaml +4 -0
- custom_nodes/ComfyUI-tbox/nodes/face/__init__.py +8 -0
- custom_nodes/ComfyUI-tbox/nodes/face/face_enhance_node.py +81 -0
- custom_nodes/ComfyUI-tbox/nodes/image/load_node.py +81 -0
- custom_nodes/ComfyUI-tbox/nodes/image/save_node.py +79 -0
- custom_nodes/ComfyUI-tbox/nodes/image/size_node.py +121 -0
- custom_nodes/ComfyUI-tbox/nodes/image/watermark_node.py +58 -0
- custom_nodes/ComfyUI-tbox/nodes/mask/mask_node.py +86 -0
- custom_nodes/ComfyUI-tbox/nodes/other/vram_node.py +45 -0
- custom_nodes/ComfyUI-tbox/nodes/preprocessor/canny_node.py +22 -0
- custom_nodes/ComfyUI-tbox/nodes/preprocessor/densepose_node.py +22 -0
- custom_nodes/ComfyUI-tbox/nodes/preprocessor/dwpose_node.py +158 -0
- custom_nodes/ComfyUI-tbox/nodes/preprocessor/lineart_node.py +44 -0
- custom_nodes/ComfyUI-tbox/nodes/preprocessor/midas_node.py +25 -0
- custom_nodes/ComfyUI-tbox/nodes/utils.py +165 -0
- custom_nodes/ComfyUI-tbox/nodes/video/batch_node.py +69 -0
- custom_nodes/ComfyUI-tbox/nodes/video/ffmpeg.py +129 -0
- custom_nodes/ComfyUI-tbox/nodes/video/info_node.py +39 -0
- custom_nodes/ComfyUI-tbox/nodes/video/load_node.py +261 -0
- custom_nodes/ComfyUI-tbox/nodes/video/save_node.py +415 -0
- custom_nodes/ComfyUI-tbox/nodes/video/video_formats/h264-mp4.json +10 -0
- custom_nodes/ComfyUI-tbox/nodes/video/video_formats/h265-mp4.json +13 -0
- custom_nodes/ComfyUI-tbox/nodes/video/video_formats/nvenc_h264-mp4.json +12 -0
- custom_nodes/ComfyUI-tbox/nodes/video/video_formats/nvenc_h265-mp4.json +10 -0
- custom_nodes/ComfyUI-tbox/nodes/video/video_formats/webm.json +11 -0
- custom_nodes/ComfyUI-tbox/requirements.txt +11 -0
- custom_nodes/ComfyUI-tbox/src/canny/__init__.py +17 -0
- custom_nodes/ComfyUI-tbox/src/common.py +186 -0
- custom_nodes/ComfyUI-tbox/src/densepose/__init__.py +67 -0
- custom_nodes/ComfyUI-tbox/src/densepose/densepose.py +347 -0
- custom_nodes/ComfyUI-tbox/src/dwpose/LICENSE +108 -0
- custom_nodes/ComfyUI-tbox/src/dwpose/__init__.py +328 -0
- custom_nodes/ComfyUI-tbox/src/dwpose/animalpose.py +273 -0
- custom_nodes/ComfyUI-tbox/src/dwpose/body.py +261 -0
- custom_nodes/ComfyUI-tbox/src/dwpose/dw_onnx/__init__.py +1 -0
- custom_nodes/ComfyUI-tbox/src/dwpose/dw_onnx/cv_ox_det.py +129 -0
- custom_nodes/ComfyUI-tbox/src/dwpose/dw_onnx/cv_ox_pose.py +363 -0
- custom_nodes/ComfyUI-tbox/src/dwpose/dw_onnx/cv_ox_yolo_nas.py +60 -0
- custom_nodes/ComfyUI-tbox/src/dwpose/dw_torchscript/__init__.py +1 -0
- custom_nodes/ComfyUI-tbox/src/dwpose/dw_torchscript/jit_det.py +125 -0
- custom_nodes/ComfyUI-tbox/src/dwpose/dw_torchscript/jit_pose.py +363 -0
- custom_nodes/ComfyUI-tbox/src/dwpose/face.py +362 -0
- custom_nodes/ComfyUI-tbox/src/dwpose/hand.py +94 -0
- custom_nodes/ComfyUI-tbox/src/dwpose/model.py +218 -0
- custom_nodes/ComfyUI-tbox/src/dwpose/types.py +30 -0
.gitattributes
CHANGED
@@ -15,3 +15,12 @@ disaster_girl.png filter=lfs diff=lfs merge=lfs -text
 istambul.jpg filter=lfs diff=lfs merge=lfs -text
 mona.png filter=lfs diff=lfs merge=lfs -text
 natasha.png filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.ttf filter=lfs diff=lfs merge=lfs -text
+*.whl filter=lfs diff=lfs merge=lfs -text
+*.task filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.woff2 filter=lfs diff=lfs merge=lfs -text
.gitignore
CHANGED
@@ -5,7 +5,7 @@ __pycache__/
 !/input/example.png
 /models/
 /temp/
-
+#/custom_nodes/
 !custom_nodes/example_node.py.example
 extra_model_paths.yaml
 /.vs
custom_nodes/ComfyUI-tbox/.gitignore
ADDED
@@ -0,0 +1,3 @@
.idea
**/__pycache__
custom_nodes/ComfyUI-tbox/README.md
ADDED
@@ -0,0 +1,5 @@
## Image Node:

## Video Node

## ControlNet PreProcessor
custom_nodes/ComfyUI-tbox/__init__.py
ADDED
@@ -0,0 +1,82 @@
import sys
from pathlib import Path
from .utils import here
import platform

sys.path.insert(0, str(Path(here, "src").resolve()))

from .nodes.image.load_node import LoadImageNode
from .nodes.image.save_node import SaveImageNode
from .nodes.image.save_node import SaveImagesNode
from .nodes.image.size_node import ImageResizeNode
from .nodes.image.size_node import ImageSizeNode
from .nodes.image.size_node import ConstrainImageNode
from .nodes.image.watermark_node import WatermarkNode
from .nodes.mask.mask_node import MaskAddNode
from .nodes.video.load_node import LoadVideoNode
from .nodes.video.save_node import SaveVideoNode
from .nodes.video.info_node import VideoInfoNode
from .nodes.video.batch_node import BatchManagerNode
from .nodes.preprocessor.canny_node import Canny_Preprocessor
from .nodes.preprocessor.lineart_node import LineArt_Preprocessor
from .nodes.preprocessor.lineart_node import Lineart_Standard_Preprocessor
from .nodes.preprocessor.midas_node import MIDAS_Depth_Map_Preprocessor
from .nodes.preprocessor.dwpose_node import DWPose_Preprocessor, AnimalPose_Preprocessor
from .nodes.preprocessor.densepose_node import DensePose_Preprocessor
from .nodes.face.face_enhance_node import GFPGANNode
from .nodes.other.vram_node import PurgeVRAMNode

NODE_CLASS_MAPPINGS = {
    "PurgeVRAMNode": PurgeVRAMNode,
    "GFPGANNode": GFPGANNode,
    "MaskAddNode": MaskAddNode,
    "ImageLoader": LoadImageNode,
    "ImageSaver": SaveImageNode,
    "ImagesSaver": SaveImagesNode,
    "ImageResize": ImageResizeNode,
    "ImageSize": ImageSizeNode,
    "WatermarkNode": WatermarkNode,
    "VideoLoader": LoadVideoNode,
    "VideoSaver": SaveVideoNode,
    "VideoInfo": VideoInfoNode,
    "BatchManager": BatchManagerNode,
    "ConstrainImageNode": ConstrainImageNode,
    "DensePosePreprocessor": DensePose_Preprocessor,
    "DWPosePreprocessor": DWPose_Preprocessor,
    "AnimalPosePreprocessor": AnimalPose_Preprocessor,
    "MiDaSDepthPreprocessor": MIDAS_Depth_Map_Preprocessor,
    "CannyPreprocessor": Canny_Preprocessor,
    "LineArtPreprocessor": LineArt_Preprocessor,
    "LineartStandardPreprocessor": Lineart_Standard_Preprocessor,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "PurgeVRAMNode": "PurgeVRAMNode",
    "GFPGANNode": "GFPGANNode",
    "MaskAddNode": "MaskAddNode",
    "ImageLoader": "Image Load",
    "ImageSaver": "Image Save",
    "ImagesSaver": "Image List Save",
    "ImageResize": "Image Resize",
    "ImageSize": "Image Size",
    "WatermarkNode": "Watermark",
    "VideoLoader": "Video Load",
    "VideoSaver": "Video Save",
    "VideoInfo": "Video Info",
    "BatchManager": "Batch Manager",
    "ConstrainImageNode": "Image Constrain",
    "DensePosePreprocessor": "DensePose Estimator",
    "DWPosePreprocessor": "DWPose Estimator",
    "AnimalPosePreprocessor": "AnimalPose Estimator",
    "MiDaSDepthPreprocessor": "MiDaS Depth Estimator",
    "CannyPreprocessor": "Canny Edge Estimator",
    "LineArtPreprocessor": "Realistic Lineart",
    "LineartStandardPreprocessor": "Standard Lineart",
}


if platform.system() == "Darwin":
    WEB_DIRECTORY = "./web"
    __all__ = ["WEB_DIRECTORY", "NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
else:
    __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
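For orientation: ComfyUI imports each package under custom_nodes and reads its NODE_CLASS_MAPPINGS and NODE_DISPLAY_NAME_MAPPINGS, so the two dictionaries above are what expose the tbox nodes in the UI. A minimal sketch of how one more node could be registered, using a hypothetical ExampleNode (not part of this commit) that simply passes images through:

class ExampleNode:  # hypothetical node, for illustration only
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"images": ("IMAGE",)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "run"
    CATEGORY = "tbox/Image"

    def run(self, images):
        # pass the batch through unchanged
        return (images,)

NODE_CLASS_MAPPINGS["ExampleNode"] = ExampleNode
NODE_DISPLAY_NAME_MAPPINGS["ExampleNode"] = "Example"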
custom_nodes/ComfyUI-tbox/config.yaml
ADDED
@@ -0,0 +1,4 @@
annotator_ckpts_path: "../../models/annotator"
custom_temp_path: "../../temp"
USE_SYMLINKS: False
EP_list: ["CUDAExecutionProvider", "DirectMLExecutionProvider", "OpenVINOExecutionProvider", "ROCMExecutionProvider", "CPUExecutionProvider"]
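As a side note, the file is plain YAML; a minimal sketch of reading it with PyYAML (an assumption here — the package may consume it differently):

import yaml  # assumes pyyaml is installed

with open("custom_nodes/ComfyUI-tbox/config.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["annotator_ckpts_path"])  # ../../models/annotator
print(cfg["EP_list"][0])            # CUDAExecutionProvider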
custom_nodes/ComfyUI-tbox/nodes/face/__init__.py
ADDED
@@ -0,0 +1,8 @@
import os
import folder_paths


model_path = os.path.join(folder_paths.models_dir, "facefusion")
folder_paths.add_model_folder_path('facefusion', model_path)
custom_nodes/ComfyUI-tbox/nodes/face/face_enhance_node.py
ADDED
@@ -0,0 +1,81 @@
import os
import torch
import cv2
import numpy as np
from PIL import Image
from facefusion.gfpgan_onnx import GFPGANOnnx
from facefusion.yoloface_onnx import YoloFaceOnnx
from facefusion.affine import create_box_mask, warp_face_by_landmark, paste_back

import folder_paths
from ..utils import tensor2pil, pil2tensor

# class GFPGANProvider:
#     @classmethod
#     def INPUT_TYPES(s):
#         return {
#             "required": {
#                 "model_name": ("IMAGE", ["gfpgan_1.4.onnx"]),
#             },
#         }
#
#     RETURN_TYPES = ("GFPGAN_MODEL",)
#     RETURN_NAMES = ("model",)
#     FUNCTION = "load_model"
#     CATEGORY = "tbox/facefusion"
#
#     def load_model(self, model_name):
#         return (model_name,)


class GFPGANNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "model_name": (['gfpgan_1.3', 'gfpgan_1.4'], {"default": 'gfpgan_1.4'}),
                "device": (['CPU', 'CUDA', 'CoreML', 'ROCM'], {"default": 'CPU'}),
                "weight": ("FLOAT", {"default": 0.8}),
            }
        }

    RETURN_TYPES = ("IMAGE", )
    FUNCTION = "process"
    CATEGORY = "tbox/FaceFusion"

    def process(self, images, model_name, device='CPU', weight=0.8):
        providers = ['CPUExecutionProvider']
        if device == 'CUDA':
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        elif device == 'CoreML':
            providers = ['CoreMLExecutionProvider', 'CPUExecutionProvider']
        elif device == 'ROCM':
            providers = ['ROCMExecutionProvider', 'CPUExecutionProvider']

        gfpgan_path = folder_paths.get_full_path("facefusion", f'{model_name}.onnx')
        yolo_path = folder_paths.get_full_path("facefusion", 'yoloface_8n.onnx')

        detector = YoloFaceOnnx(model_path=yolo_path, providers=providers)
        enhancer = GFPGANOnnx(model_path=gfpgan_path, providers=providers)

        image_list = []
        for i, img in enumerate(images):
            pil = tensor2pil(img)
            image = np.ascontiguousarray(pil)
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            output = image
            face_list = detector.detect(image=image, conf=0.7)
            for index, face in enumerate(face_list):
                cropped, affine_matrix = warp_face_by_landmark(image, face.landmarks, enhancer.input_size)
                box_mask = create_box_mask(enhancer.input_size, 0.3, (0, 0, 0, 0))
                crop_mask = np.minimum.reduce([box_mask]).clip(0, 1)
                result = enhancer.run(cropped)
                output = paste_back(output, result, crop_mask, affine_matrix)
            image = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
            pil = Image.fromarray(image)
            image_list.append(pil2tensor(pil))
        image_list = torch.stack([tensor.squeeze() for tensor in image_list])
        return (image_list,)
ADDED
@@ -0,0 +1,81 @@
|
import os
import cv2
import torch
import requests
import itertools
import folder_paths
import psutil
import numpy as np
from comfy.utils import common_upscale
from io import BytesIO
from PIL import Image, ImageSequence, ImageOps


def pil2tensor(img):
    output_images = []
    output_masks = []
    for i in ImageSequence.Iterator(img):
        i = ImageOps.exif_transpose(i)
        if i.mode == 'I':
            i = i.point(lambda i: i * (1 / 255))
        image = i.convert("RGB")
        image = np.array(image).astype(np.float32) / 255.0
        image = torch.from_numpy(image)[None,]
        if 'A' in i.getbands():
            mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
            mask = 1. - torch.from_numpy(mask)
        else:
            mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
        output_images.append(image)
        output_masks.append(mask.unsqueeze(0))

    if len(output_images) > 1:
        output_image = torch.cat(output_images, dim=0)
        output_mask = torch.cat(output_masks, dim=0)
    else:
        output_image = output_images[0]
        output_mask = output_masks[0]

    return (output_image, output_mask)


def load_image(image_source):
    if image_source.startswith('http'):
        print(image_source)
        response = requests.get(image_source)
        img = Image.open(BytesIO(response.content))
        file_name = image_source.split('/')[-1]
    else:
        img = Image.open(image_source)
        file_name = os.path.basename(image_source)
    return img, file_name


class LoadImageNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "path": ("STRING", {"multiline": True, "dynamicPrompts": False})
            }
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "load_image"
    CATEGORY = "tbox/Image"

    def load_image(self, path):
        filepath = path.split('\n')[0]
        img, name = load_image(filepath)
        img_out, mask_out = pil2tensor(img)
        return (img_out, mask_out)


if __name__ == "__main__":
    img, name = load_image("https://creativestorage.blob.core.chinacloudapi.cn/test/bird.png")
    img_out, mask_out = pil2tensor(img)
custom_nodes/ComfyUI-tbox/nodes/image/save_node.py
ADDED
@@ -0,0 +1,79 @@
import os
import time
import numpy as np
from PIL import Image, ImageSequence, ImageOps

#from load_node import load_image, pil2tensor

def save_image(img, filepath, format, quality):
    try:
        if format in ["jpg", "jpeg"]:
            img.convert("RGB").save(filepath, format="JPEG", quality=quality, subsampling=0)
        elif format == "webp":
            img.save(filepath, format="WEBP", quality=quality, method=6)
        elif format == "bmp":
            img.save(filepath, format="BMP")
        else:
            img.save(filepath, format="PNG", optimize=True)
    except Exception as e:
        print(f"Error saving {filepath}: {str(e)}")

class SaveImageNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "path": ("STRING", {"multiline": True, "dynamicPrompts": False}),
                "quality": ([100, 95, 90, 85, 80, 75, 70, 60, 50], {"default": 100}),
            }
        }
    RETURN_TYPES = ()
    FUNCTION = "save_image"
    CATEGORY = "tbox/Image"
    OUTPUT_NODE = True

    def save_image(self, images, path, quality):
        filepath = path.split('\n')[0]
        format = os.path.splitext(filepath)[1][1:]
        image = images[0]
        img = Image.fromarray((255. * image.cpu().numpy()).astype(np.uint8))
        save_image(img, filepath, format, quality)
        return {}

class SaveImagesNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "path": ("STRING", {"multiline": False, "dynamicPrompts": False}),
                "prefix": ("STRING", {"default": "image"}),
                "format": (["PNG", "JPG", "WEBP", "BMP"],),
                "quality": ([100, 95, 90, 85, 80, 75, 70, 60, 50], {"default": 100}),
            }
        }
    RETURN_TYPES = ()
    FUNCTION = "save_image"
    CATEGORY = "tbox/Image"
    OUTPUT_NODE = True

    def save_image(self, images, path, prefix, format, quality):
        format = format.lower()
        for i, image in enumerate(images):
            img = Image.fromarray((255. * image.cpu().numpy()).astype(np.uint8))
            filepath = self.generate_filename(path, prefix, i, format)
            save_image(img, filepath, format, quality)
        return {}

    def IS_CHANGED(s, images):
        return time.time()

    def generate_filename(self, save_dir, prefix, index, format):
        base_filename = f"{prefix}_{index+1}.{format}"
        filename = os.path.join(save_dir, base_filename)
        counter = 1
        while os.path.exists(filename):
            filename = os.path.join(save_dir, f"{prefix}_{index+1}_{counter}.{format}")
            counter += 1
        return filename
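The module-level save_image helper above can also be exercised on its own; a small sketch with placeholder output paths:

from PIL import Image

img = Image.new("RGB", (64, 64), "white")
save_image(img, "/tmp/example.jpg", "jpg", 90)   # JPEG branch: quality 90, chroma subsampling disabled
save_image(img, "/tmp/example.png", "png", 100)  # any other extension falls through to the optimized PNG branch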
custom_nodes/ComfyUI-tbox/nodes/image/size_node.py
ADDED
@@ -0,0 +1,121 @@
import torch
import comfy.utils
import numpy as np
from PIL import Image, ImageSequence, ImageOps

class ConstrainImageNode:
    """
    A node that constrains an image to a maximum and minimum size while maintaining aspect ratio.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "max_width": ("INT", {"default": 1024, "min": 0}),
                "max_height": ("INT", {"default": 1024, "min": 0}),
                "min_width": ("INT", {"default": 0, "min": 0}),
                "min_height": ("INT", {"default": 0, "min": 0}),
                "crop_if_required": (["yes", "no"], {"default": "no"}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "constrain_image"
    CATEGORY = "tbox/Image"
    OUTPUT_IS_LIST = (True,)

    def constrain_image(self, images, max_width, max_height, min_width, min_height, crop_if_required):
        crop_if_required = crop_if_required == "yes"
        results = []
        for image in images:
            i = 255. * image.cpu().numpy()
            img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)).convert("RGB")

            current_width, current_height = img.size
            aspect_ratio = current_width / current_height

            constrained_width = max(min(current_width, min_width), max_width)
            constrained_height = max(min(current_height, min_height), max_height)

            if constrained_width / constrained_height > aspect_ratio:
                constrained_width = max(int(constrained_height * aspect_ratio), min_width)
                if crop_if_required:
                    constrained_height = int(current_height / (current_width / constrained_width))
            else:
                constrained_height = max(int(constrained_width / aspect_ratio), min_height)
                if crop_if_required:
                    constrained_width = int(current_width / (current_height / constrained_height))

            resized_image = img.resize((constrained_width, constrained_height), Image.LANCZOS)

            if crop_if_required and (constrained_width > max_width or constrained_height > max_height):
                left = max((constrained_width - max_width) // 2, 0)
                top = max((constrained_height - max_height) // 2, 0)
                right = min(constrained_width, max_width) + left
                bottom = min(constrained_height, max_height) + top
                resized_image = resized_image.crop((left, top, right, bottom))

            resized_image = np.array(resized_image).astype(np.float32) / 255.0
            resized_image = torch.from_numpy(resized_image)[None,]
            results.append(resized_image)

        return (results,)

# https://github.com/bronkula/comfyui-fitsize
class ImageSizeNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE", ),
            }
        }

    RETURN_TYPES = ("INT", "INT", "INT")
    RETURN_NAMES = ("width", "height", "count")
    FUNCTION = "get_size"
    CATEGORY = "tbox/Image"
    def get_size(self, image):
        print(f'shape of image:{image.shape}')
        return (image.shape[2], image.shape[1], image.shape[0])


class ImageResizeNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE", ),
                "method": (["nearest", "bilinear", "bicubic", "area", "nearest-exact", "lanczos"],),
            },
            "optional": {
                "width": ("INT,FLOAT", { "default": 0.0, "step": 0.1 }),
                "height": ("INT,FLOAT", { "default": 0.0, "step": 0.1 }),
            },
        }

    RETURN_TYPES = ("IMAGE",)

    FUNCTION = "resize"
    CATEGORY = "tbox/Image"

    def resize(self, image, method, width, height):
        print(f'shape of image:{image.shape}, resolution:{width}x{height} type: {type(width)}, {type(height)}')
        if width == 0 and height == 0:
            s = image
        else:
            samples = image.movedim(-1, 1)
            if width == 0:
                width = max(1, round(samples.shape[3] * height / samples.shape[2]))
            elif height == 0:
                height = max(1, round(samples.shape[2] * width / samples.shape[3]))

            s = comfy.utils.common_upscale(samples, width, height, method, True)
            s = s.movedim(1, -1)
        return (s,)
ADDED
@@ -0,0 +1,58 @@
|
import torch
import numpy as np
from PIL import Image, ImageSequence, ImageOps
from ..utils import tensor2pil, pil2tensor

PADDING = 4

class WatermarkNode:

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "logo_list": ("IMAGE",),
            },
            "optional": {
                "logo_mask": ("MASK",),
                "enabled": ("BOOLEAN", {"default": True}),
            }
        }
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "watermark"
    CATEGORY = "tbox/Image"

    def watermark(self, images, logo_list, logo_mask=None, enabled=True):
        outputs = []
        if not enabled:
            return (images,)
        print(f'logo shape: {logo_list.shape}')
        print(f'images shape: {images.shape}')
        logo = tensor2pil(logo_list[0])
        if logo_mask is not None:
            logo_mask = tensor2pil(logo_mask)
        for i, image in enumerate(images):
            img = tensor2pil(image)  # Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
            dst = self.add_watermark2(img, logo, logo_mask, 85)
            result = pil2tensor(dst)
            outputs.append(result)
        base_image = torch.stack([tensor.squeeze() for tensor in outputs])
        return (base_image,)

    def add_watermark2(self, image, logo, logo_mask, opacity=None):
        logo_width, logo_height = logo.size
        image_width, image_height = image.size
        if image_height <= logo_height + PADDING * 2 or image_width <= logo_width + PADDING * 2:
            return image
        y = image_height - logo_height - PADDING * 1
        x = PADDING
        logo = logo.convert('RGBA')
        opacity = int(opacity / 100 * 255)
        logo.putalpha(Image.new("L", logo.size, opacity))
        if logo_mask is not None:
            logo.putalpha(ImageOps.invert(logo_mask))

        position = (x, y)
        image.paste(logo, position, logo)
        return image
custom_nodes/ComfyUI-tbox/nodes/mask/mask_node.py
ADDED
@@ -0,0 +1,86 @@
import torch
import numpy as np

class MaskSubNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mask": ("MASK",),
            },
            "optional": {
                "src1": ("MASK",),
                "src2": ("MASK",),
                "src3": ("MASK",),
                "src4": ("MASK",),
                "src5": ("MASK",),
                "src6": ("MASK",),
            }
        }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "sub"
    CATEGORY = "tbox/Mask"

    def sub_mask(self, dst, src):
        if src is not None:
            mask = src.reshape((-1, src.shape[-2], src.shape[-1]))
            return dst - mask
        return dst

    def sub(self, mask, src1=None, src2=None, src3=None, src4=None, src5=None, src6=None):
        print(f'mask shape: {mask.shape}')
        output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()
        output[:, :, :] = self.sub_mask(output, src1)
        output[:, :, :] = self.sub_mask(output, src2)
        output[:, :, :] = self.sub_mask(output, src3)
        output[:, :, :] = self.sub_mask(output, src4)
        output[:, :, :] = self.sub_mask(output, src5)
        output[:, :, :] = self.sub_mask(output, src6)
        output = torch.clamp(output, 0.0, 1.0)
        return (output, )

class MaskAddNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mask": ("MASK",),
            },
            "optional": {
                "src1": ("MASK",),
                "src2": ("MASK",),
                "src3": ("MASK",),
                "src4": ("MASK",),
                "src5": ("MASK",),
                "src6": ("MASK",),
            }
        }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "add"
    CATEGORY = "tbox/Mask"

    def add_mask(self, dst, src):
        if src is not None:
            mask = src.reshape((-1, src.shape[-2], src.shape[-1]))
            return dst + mask
        return dst

    def add(self, mask, src1=None, src2=None, src3=None, src4=None, src5=None, src6=None):
        print(f'mask shape: {mask.shape}')
        output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()
        output[:, :, :] = self.add_mask(output, src1)
        output[:, :, :] = self.add_mask(output, src2)
        output[:, :, :] = self.add_mask(output, src3)
        output[:, :, :] = self.add_mask(output, src4)
        output[:, :, :] = self.add_mask(output, src5)
        output[:, :, :] = self.add_mask(output, src6)
        output = torch.clamp(output, 0.0, 1.0)
        return (output, )
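A quick standalone check of the clamped-sum behaviour of MaskAddNode (ComfyUI MASK tensors are B x H x W), a small sketch using toy tensors:

import torch

node = MaskAddNode()
a = torch.full((1, 4, 4), 0.8)
b = torch.full((1, 4, 4), 0.5)
(out,) = node.add(a, src1=b)
print(out.max().item())  # 1.0, because 0.8 + 0.5 is clamped to the [0, 1] range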
custom_nodes/ComfyUI-tbox/nodes/other/vram_node.py
ADDED
@@ -0,0 +1,45 @@
import gc
import torch.cuda
import comfy.model_management

class AnyType(str):
    """A special class that is always equal in not equal comparisons. Credit to pythongosssss"""
    def __eq__(self, __value: object) -> bool:
        return True
    def __ne__(self, __value: object) -> bool:
        return False

any = AnyType("*")

class PurgeVRAMNode:

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "anything": (any, {}),
                "purge_cache": ("BOOLEAN", {"default": True}),
                "purge_models": ("BOOLEAN", {"default": True}),
            },
            "optional": {
            }
        }

    RETURN_TYPES = ()
    FUNCTION = "purge_vram"
    CATEGORY = "tbox/other"
    OUTPUT_NODE = True

    def purge_vram(self, anything, purge_cache, purge_models):
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        if purge_models:
            comfy.model_management.unload_all_models()
            comfy.model_management.soft_empty_cache()
        return (None,)
custom_nodes/ComfyUI-tbox/nodes/preprocessor/canny_node.py
ADDED
@@ -0,0 +1,22 @@
from ..utils import common_annotator_call, create_node_input_types
import comfy.model_management as model_management
import nodes

class Canny_Preprocessor:
    @classmethod
    def INPUT_TYPES(s):
        return create_node_input_types(
            low_threshold=("INT", {"default": 100, "min": 0, "max": 255}),
            high_threshold=("INT", {"default": 100, "min": 0, "max": 255}),
            resolution=("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 64})
        )

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"

    CATEGORY = "tbox/ControlNet Preprocessors"

    def execute(self, image, low_threshold=100, high_threshold=200, resolution=512, **kwargs):
        from canny import CannyDetector

        return (common_annotator_call(CannyDetector(), image, low_threshold=low_threshold, high_threshold=high_threshold, resolution=resolution), )
custom_nodes/ComfyUI-tbox/nodes/preprocessor/densepose_node.py
ADDED
@@ -0,0 +1,22 @@
from ..utils import common_annotator_call, create_node_input_types
import comfy.model_management as model_management

class DensePose_Preprocessor:
    @classmethod
    def INPUT_TYPES(s):
        return create_node_input_types(
            model=(["densepose_r50_fpn_dl.torchscript", "densepose_r101_fpn_dl.torchscript"], {"default": "densepose_r50_fpn_dl.torchscript"}),
            cmap=(["Viridis (MagicAnimate)", "Parula (CivitAI)"], {"default": "Viridis (MagicAnimate)"})
        )

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"

    CATEGORY = "tbox/ControlNet Preprocessors"

    def execute(self, image, model, cmap, resolution=512):
        from densepose import DenseposeDetector
        model = DenseposeDetector \
            .from_pretrained(filename=model) \
            .to(model_management.get_torch_device())
        return (common_annotator_call(model, image, cmap="viridis" if "Viridis" in cmap else "parula", resolution=resolution), )
custom_nodes/ComfyUI-tbox/nodes/preprocessor/dwpose_node.py
ADDED
@@ -0,0 +1,158 @@
from ..utils import common_annotator_call, create_node_input_types
import comfy.model_management as model_management
import numpy as np
import warnings
from dwpose import DwposeDetector, AnimalposeDetector
import os
import json


DWPOSE_MODEL_NAME = "yzd-v/DWPose"
#Trigger startup caching for onnxruntime
GPU_PROVIDERS = ["CUDAExecutionProvider", "DirectMLExecutionProvider", "OpenVINOExecutionProvider", "ROCMExecutionProvider", "CoreMLExecutionProvider"]

def check_ort_gpu():
    try:
        import onnxruntime as ort
        for provider in GPU_PROVIDERS:
            if provider in ort.get_available_providers():
                return True
        return False
    except:
        return False

if not os.environ.get("DWPOSE_ONNXRT_CHECKED"):
    if check_ort_gpu():
        print("DWPose: Onnxruntime with acceleration providers detected")
    else:
        warnings.warn("DWPose: Onnxruntime not found or doesn't come with acceleration providers, switch to OpenCV with CPU device. DWPose might run very slowly")
        os.environ['AUX_ORT_PROVIDERS'] = ''
    os.environ["DWPOSE_ONNXRT_CHECKED"] = '1'

class DWPose_Preprocessor:
    @classmethod
    def INPUT_TYPES(s):
        input_types = create_node_input_types(
            detect_hand=(["enable", "disable"], {"default": "enable"}),
            detect_body=(["enable", "disable"], {"default": "enable"}),
            detect_face=(["enable", "disable"], {"default": "enable"})
        )
        input_types["optional"] = {
            **input_types["optional"],
            "bbox_detector": (
                ["yolox_l.torchscript.pt", "yolox_l.onnx", "yolo_nas_l_fp16.onnx", "yolo_nas_m_fp16.onnx", "yolo_nas_s_fp16.onnx"],
                {"default": "yolox_l.onnx"}
            ),
            "pose_estimator": (["dw-ll_ucoco_384_bs5.torchscript.pt", "dw-ll_ucoco_384.onnx", "dw-ll_ucoco.onnx"], {"default": "dw-ll_ucoco_384_bs5.torchscript.pt"})
        }
        return input_types

    RETURN_TYPES = ("IMAGE", "POSE_KEYPOINT")
    FUNCTION = "estimate_pose"

    CATEGORY = "tbox/ControlNet Preprocessors"

    def estimate_pose(self, image, detect_hand, detect_body, detect_face, resolution=512, bbox_detector="yolox_l.onnx", pose_estimator="dw-ll_ucoco_384.onnx", **kwargs):
        if bbox_detector == "yolox_l.onnx":
            yolo_repo = DWPOSE_MODEL_NAME
        elif "yolox" in bbox_detector:
            yolo_repo = "hr16/yolox-onnx"
        elif "yolo_nas" in bbox_detector:
            yolo_repo = "hr16/yolo-nas-fp16"
        else:
            raise NotImplementedError(f"Download mechanism for {bbox_detector}")

        if pose_estimator == "dw-ll_ucoco_384.onnx":
            pose_repo = DWPOSE_MODEL_NAME
        elif pose_estimator.endswith(".onnx"):
            pose_repo = "hr16/UnJIT-DWPose"
        elif pose_estimator.endswith(".torchscript.pt"):
            pose_repo = "hr16/DWPose-TorchScript-BatchSize5"
        else:
            raise NotImplementedError(f"Download mechanism for {pose_estimator}")

        model = DwposeDetector.from_pretrained(
            pose_repo,
            yolo_repo,
            det_filename=bbox_detector, pose_filename=pose_estimator,
            torchscript_device=model_management.get_torch_device()
        )
        detect_hand = detect_hand == "enable"
        detect_body = detect_body == "enable"
        detect_face = detect_face == "enable"
        self.openpose_dicts = []
        def func(image, **kwargs):
            pose_img, openpose_dict = model(image, **kwargs)
            self.openpose_dicts.append(openpose_dict)
            return pose_img

        out = common_annotator_call(func, image, include_hand=detect_hand, include_face=detect_face, include_body=detect_body, image_and_json=True, resolution=resolution)
        del model
        return {
            'ui': { "openpose_json": [json.dumps(self.openpose_dicts, indent=4)] },
            "result": (out, self.openpose_dicts)
        }

class AnimalPose_Preprocessor:
    @classmethod
    def INPUT_TYPES(s):
        return create_node_input_types(
            bbox_detector = (
                ["yolox_l.torchscript.pt", "yolox_l.onnx", "yolo_nas_l_fp16.onnx", "yolo_nas_m_fp16.onnx", "yolo_nas_s_fp16.onnx"],
                {"default": "yolox_l.torchscript.pt"}
            ),
            pose_estimator = (["rtmpose-m_ap10k_256_bs5.torchscript.pt", "rtmpose-m_ap10k_256.onnx"], {"default": "rtmpose-m_ap10k_256_bs5.torchscript.pt"})
        )

    RETURN_TYPES = ("IMAGE", "POSE_KEYPOINT")
    FUNCTION = "estimate_pose"

    CATEGORY = "tbox/ControlNet Preprocessors"

    def estimate_pose(self, image, resolution=512, bbox_detector="yolox_l.onnx", pose_estimator="rtmpose-m_ap10k_256.onnx", **kwargs):
        if bbox_detector == "yolox_l.onnx":
            yolo_repo = DWPOSE_MODEL_NAME
        elif "yolox" in bbox_detector:
            yolo_repo = "hr16/yolox-onnx"
        elif "yolo_nas" in bbox_detector:
            yolo_repo = "hr16/yolo-nas-fp16"
        else:
            raise NotImplementedError(f"Download mechanism for {bbox_detector}")

        if pose_estimator == "dw-ll_ucoco_384.onnx":
            pose_repo = DWPOSE_MODEL_NAME
        elif pose_estimator.endswith(".onnx"):
            pose_repo = "hr16/UnJIT-DWPose"
        elif pose_estimator.endswith(".torchscript.pt"):
            pose_repo = "hr16/DWPose-TorchScript-BatchSize5"
        else:
            raise NotImplementedError(f"Download mechanism for {pose_estimator}")

        model = AnimalposeDetector.from_pretrained(
            pose_repo,
            yolo_repo,
            det_filename=bbox_detector, pose_filename=pose_estimator,
            torchscript_device=model_management.get_torch_device()
        )

        self.openpose_dicts = []
        def func(image, **kwargs):
            pose_img, openpose_dict = model(image, **kwargs)
            self.openpose_dicts.append(openpose_dict)
            return pose_img

        out = common_annotator_call(func, image, image_and_json=True, resolution=resolution)
        del model
        return {
            'ui': { "openpose_json": [json.dumps(self.openpose_dicts, indent=4)] },
            "result": (out, self.openpose_dicts)
        }

NODE_CLASS_MAPPINGS = {
    "DWPreprocessor": DWPose_Preprocessor,
    "AnimalPosePreprocessor": AnimalPose_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "DWPreprocessor": "DWPose Estimator",
    "AnimalPosePreprocessor": "AnimalPose Estimator (AP10K)"
}
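The check_ort_gpu() gate above can be reproduced interactively; a small standalone sketch, assuming onnxruntime is installed:

import onnxruntime as ort

# Same provider list check_ort_gpu() walks in dwpose_node.py.
GPU_PROVIDERS = ["CUDAExecutionProvider", "DirectMLExecutionProvider", "OpenVINOExecutionProvider",
                 "ROCMExecutionProvider", "CoreMLExecutionProvider"]
print([p for p in GPU_PROVIDERS if p in ort.get_available_providers()])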
custom_nodes/ComfyUI-tbox/nodes/preprocessor/lineart_node.py
ADDED
@@ -0,0 +1,44 @@
from ..utils import common_annotator_call, create_node_input_types
import comfy.model_management as model_management
import nodes

class LineArt_Preprocessor:
    @classmethod
    def INPUT_TYPES(s):
        return create_node_input_types(
            coarse=(["disable", "enable"], {"default": "enable"}),
            resolution=("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 64})
        )

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"

    CATEGORY = "tbox/ControlNet Preprocessors"

    def execute(self, image, resolution=512, **kwargs):
        from lineart import LineartDetector

        model = LineartDetector.from_pretrained().to(model_management.get_torch_device())
        out = common_annotator_call(model, image, resolution=resolution, coarse = kwargs["coarse"] == "enable")
        del model
        return (out, )

class Lineart_Standard_Preprocessor:
    @classmethod
    def INPUT_TYPES(s):
        return create_node_input_types(
            guassian_sigma=("FLOAT", {"default": 6.0, "max": 100.0}),
            intensity_threshold=("INT", {"default": 8, "max": 16}),
            resolution=("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 64})
        )

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"

    CATEGORY = "tbox/ControlNet Preprocessors"

    def execute(self, image, guassian_sigma=6, intensity_threshold=8, resolution=512, **kwargs):
        from lineart import LineartStandardDetector
        return (common_annotator_call(LineartStandardDetector(), image, guassian_sigma=guassian_sigma, intensity_threshold=intensity_threshold, resolution=resolution), )
custom_nodes/ComfyUI-tbox/nodes/preprocessor/midas_node.py
ADDED
@@ -0,0 +1,25 @@
from ..utils import common_annotator_call, create_node_input_types
import comfy.model_management as model_management
import numpy as np

class MIDAS_Depth_Map_Preprocessor:
    @classmethod
    def INPUT_TYPES(s):
        return create_node_input_types(
            a = ("FLOAT", {"default": np.pi * 2.0, "min": 0.0, "max": np.pi * 5.0, "step": 0.05}),
            bg_threshold = ("FLOAT", {"default": 0.1, "min": 0, "max": 1, "step": 0.05})
        )

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"

    CATEGORY = "tbox/ControlNet Preprocessors"

    def execute(self, image, a, bg_threshold, resolution=512, **kwargs):
        from midas import MidasDetector

        # Ref: https://github.com/lllyasviel/ControlNet/blob/main/gradio_depth2image.py
        model = MidasDetector.from_pretrained().to(model_management.get_torch_device())
        out = common_annotator_call(model, image, resolution=resolution, a=a, bg_th=bg_threshold)
        del model
        return (out, )
custom_nodes/ComfyUI-tbox/nodes/utils.py
ADDED
@@ -0,0 +1,165 @@
import hashlib
import os
import torch
import nodes
import server
import folder_paths
import numpy as np
from typing import Iterable
from PIL import Image

BIGMIN = -(2**53-1)
BIGMAX = (2**53-1)

DIMMAX = 8192


def tensor_to_int(tensor, bits):
    #TODO: investigate benefit of rounding by adding 0.5 before clip/cast
    tensor = tensor.cpu().numpy() * (2**bits-1)
    return np.clip(tensor, 0, (2**bits-1))
def tensor_to_shorts(tensor):
    return tensor_to_int(tensor, 16).astype(np.uint16)
def tensor_to_bytes(tensor):
    return tensor_to_int(tensor, 8).astype(np.uint8)
def tensor2pil(x):
    return Image.fromarray(np.clip(255. * x.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
def pil2tensor(image: Image.Image) -> torch.Tensor:
    return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)


def is_url(url):
    return url.split("://")[0] in ["http", "https"]

def strip_path(path):
    #This leaves whitespace inside quotes and only a single "
    #thus ' ""test"' -> '"test'
    #consider path.strip(string.whitespace+"\"")
    #or weightier re.fullmatch("[\\s\"]*(.+?)[\\s\"]*", path).group(1)
    path = path.strip()
    if path.startswith("\""):
        path = path[1:]
    if path.endswith("\""):
        path = path[:-1]
    return path
def hash_path(path):
    if path is None:
        return "input"
    if is_url(path):
        return "url"
    return calculate_file_hash(strip_path(path))

# modified from https://stackoverflow.com/questions/22058048/hashing-a-file-in-python
def calculate_file_hash(filename: str, hash_every_n: int = 1):
    #Larger video files were taking >.5 seconds to hash even when cached,
    #so instead the modified time from the filesystem is used as a hash
    h = hashlib.sha256()
    h.update(filename.encode())
    h.update(str(os.path.getmtime(filename)).encode())
    return h.hexdigest()

def is_safe_path(path):
    if "VHS_STRICT_PATHS" not in os.environ:
        return True
    basedir = os.path.abspath('.')
    try:
        common_path = os.path.commonpath([basedir, path])
    except:
        #Different drive on windows
        return False
    return common_path == basedir

def validate_path(path, allow_none=False, allow_url=True):
    if path is None:
        return allow_none
    if is_url(path):
        #Probably not feasible to check if url resolves here
        if not allow_url:
            return "URLs are unsupported for this path"
        return is_safe_path(path)
    if not os.path.isfile(strip_path(path)):
        return "Invalid file path: {}".format(path)
    return is_safe_path(path)

def common_annotator_call(model, tensor_image, input_batch=False, show_pbar=False, **kwargs):
    if "detect_resolution" in kwargs:
        del kwargs["detect_resolution"] #Prevent weird case?

    if "resolution" in kwargs:
        detect_resolution = kwargs["resolution"] if type(kwargs["resolution"]) == int and kwargs["resolution"] >= 64 else 512
        del kwargs["resolution"]
    else:
        detect_resolution = 512

    if input_batch:
        np_images = np.asarray(tensor_image * 255., dtype=np.uint8)
        np_results = model(np_images, output_type="np", detect_resolution=detect_resolution, **kwargs)
        return torch.from_numpy(np_results.astype(np.float32) / 255.0)

    batch_size = tensor_image.shape[0]

    out_tensor = None
    for i, image in enumerate(tensor_image):
        np_image = np.asarray(image.cpu() * 255., dtype=np.uint8)
        np_result = model(np_image, output_type="np", detect_resolution=detect_resolution, **kwargs)
        out = torch.from_numpy(np_result.astype(np.float32) / 255.0)
        if out_tensor is None:
            out_tensor = torch.zeros(batch_size, *out.shape, dtype=torch.float32)
        out_tensor[i] = out

    return out_tensor

def create_node_input_types(**extra_kwargs):
    return {
        "required": {
            "image": ("IMAGE",)
        },
        "optional": {
            **extra_kwargs,
            "resolution": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 64})
        }
    }


prompt_queue = server.PromptServer.instance.prompt_queue
def requeue_workflow_unchecked():
    """Requeues the current workflow without checking for multiple requeues"""
    currently_running = prompt_queue.currently_running
    print(f'requeue_workflow_unchecked >>>>>> ')
    (_, _, prompt, extra_data, outputs_to_execute) = next(iter(currently_running.values()))

    #Ensure batch_managers are marked stale
    prompt = prompt.copy()
    for uid in prompt:
        if prompt[uid]['class_type'] == 'BatchManager':
            prompt[uid]['inputs']['requeue'] = prompt[uid]['inputs'].get('requeue', 0) + 1

    #execution.py has guards for concurrency, but server doesn't.
    #TODO: Check that this won't be an issue
    number = -server.PromptServer.instance.number
    server.PromptServer.instance.number += 1
    prompt_id = str(server.uuid.uuid4())
    prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
    print(f'requeue_workflow_unchecked <<<<<<<<<< prompt_id:{prompt_id}, number:{number}')

requeue_guard = [None, 0, 0, {}]
def requeue_workflow(requeue_required=(-1, True)):
    assert(len(prompt_queue.currently_running) == 1)
    global requeue_guard
    (run_number, _, prompt, _, _) = next(iter(prompt_queue.currently_running.values()))
    print(f'requeue_workflow >> run_number:{run_number}\n')
    if requeue_guard[0] != run_number:
        #Calculate a count of how many outputs are managed by a batch manager
        managed_outputs = 0
        for bm_uid in prompt:
            if prompt[bm_uid]['class_type'] == 'BatchManager':
                for output_uid in prompt:
                    if prompt[output_uid]['class_type'] in ["VideoSaver"]:
                        for inp in prompt[output_uid]['inputs'].values():
                            if inp == [bm_uid, 0]:
                                managed_outputs += 1
        requeue_guard = [run_number, 0, managed_outputs, {}]
    requeue_guard[1] = requeue_guard[1] + 1
    requeue_guard[3][requeue_required[0]] = requeue_required[1]
    if requeue_guard[1] == requeue_guard[2] and max(requeue_guard[3].values()):
        requeue_workflow_unchecked()
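For reference, create_node_input_types(**extra_kwargs) simply folds the keyword arguments into the "optional" group next to a shared "resolution" entry. A minimal sketch of the structure it returns for the Canny thresholds, with MAX_RESOLUTION standing in for nodes.MAX_RESOLUTION (an assumed value here):

MAX_RESOLUTION = 16384  # assumption; the real value comes from ComfyUI's nodes module

input_types = {
    "required": {
        "image": ("IMAGE",),
    },
    "optional": {
        "low_threshold": ("INT", {"default": 100, "min": 0, "max": 255}),
        "high_threshold": ("INT", {"default": 100, "min": 0, "max": 255}),
        "resolution": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
    },
}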
custom_nodes/ComfyUI-tbox/nodes/video/batch_node.py
ADDED
@@ -0,0 +1,69 @@
import hashlib
import os

class BatchManagerNode:
    def __init__(self, frames_per_batch=-1):
        print("BatchNode init")
        self.frames_per_batch = frames_per_batch
        self.inputs = {}
        self.outputs = {}
        self.unique_id = None
        self.has_closed_inputs = False
        self.total_frames = float('inf')

    def reset(self):
        print("BatchNode reset")
        self.close_inputs()
        for key in self.outputs:
            if getattr(self.outputs[key][-1], "gi_suspended", False):
                try:
                    self.outputs[key][-1].send(None)
                except StopIteration:
                    pass
        self.__init__(self.frames_per_batch)
    def has_open_inputs(self):
        return len(self.inputs) > 0
    def close_inputs(self):
        for key in self.inputs:
            if getattr(self.inputs[key][-1], "gi_suspended", False):
                try:
                    self.inputs[key][-1].send(1)
                except StopIteration:
                    pass
        self.inputs = {}

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {"frames_per_batch": ("INT", {"default": 16, "min": 1, "max": 128, "step": 1})},
            "hidden": {"prompt": "PROMPT", "unique_id": "UNIQUE_ID"},
        }

    RETURN_TYPES = ("BatchManager",)
    RETURN_NAMES = ("meta_batch",)
    CATEGORY = "tbox/Video"
    FUNCTION = "update_batch"

    def update_batch(self, frames_per_batch, prompt=None, unique_id=None):
        if unique_id is not None and prompt is not None:
            requeue = prompt[unique_id]['inputs'].get('requeue', 0)
        else:
            requeue = 0
        print(f'update_batch >> unique_id: {unique_id}; requeue: {requeue}')
        if requeue == 0:
            self.reset()
            self.frames_per_batch = frames_per_batch
            self.unique_id = unique_id
        else:
            num_batches = (self.total_frames + self.frames_per_batch - 1) // frames_per_batch
            print(f'Meta-Batch {requeue}/{num_batches}')
        #onExecuted seems to not be called unless some message is sent
        return (self,)

    @classmethod
    def IS_CHANGED(self, frames_per_batch, prompt=None, unique_id=None):
        random_bytes = os.urandom(32)
        result = hashlib.sha256(random_bytes).hexdigest()
        print(f"BatchManagerNode >>> IS_CHANGED : {result}")
        return result
custom_nodes/ComfyUI-tbox/nodes/video/ffmpeg.py
ADDED
@@ -0,0 +1,129 @@
import re
import os
import torch
import shutil
import subprocess
import folder_paths
from torch import Tensor
from collections.abc import Mapping

audio_extensions = ['mp3', 'mp4', 'wav', 'ogg']
video_extensions = ['webm', 'mp4', 'mov']

def ffmpeg_suitability(path):
    try:
        version = subprocess.run([path, "-version"], check=True,
                                 capture_output=True).stdout.decode("utf-8")
    except:
        return 0
    score = 0
    #rough layout of the importance of various features
    simple_criterion = [("libvpx", 20), ("264", 10), ("265", 3),
                        ("svtav1", 5), ("libopus", 1)]
    for criterion in simple_criterion:
        if version.find(criterion[0]) >= 0:
            score += criterion[1]
    #obtain rough compile year from copyright information
    copyright_index = version.find('2000-2')
    if copyright_index >= 0:
        copyright_year = version[copyright_index+6:copyright_index+9]
        if copyright_year.isnumeric():
            score += int(copyright_year)
    return score

folder_paths.folder_names_and_paths["VHS_video_formats"] = (
    [
        os.path.join(os.path.dirname(os.path.abspath(__file__)), "video_formats"),
    ],
    [".json"]
)


if "VHS_FORCE_FFMPEG_PATH" in os.environ:
    ffmpeg_path = os.environ.get("VHS_FORCE_FFMPEG_PATH")
else:
    ffmpeg_paths = []
    try:
        from imageio_ffmpeg import get_ffmpeg_exe
        imageio_ffmpeg_path = get_ffmpeg_exe()
        ffmpeg_paths.append(imageio_ffmpeg_path)
    except:
        if "VHS_USE_IMAGEIO_FFMPEG" in os.environ:
            raise
        print("Failed to import imageio_ffmpeg")
    if "VHS_USE_IMAGEIO_FFMPEG" in os.environ:
        ffmpeg_path = imageio_ffmpeg_path
    else:
        system_ffmpeg = shutil.which("ffmpeg")
        if system_ffmpeg is not None:
            ffmpeg_paths.append(system_ffmpeg)
        if os.path.isfile("ffmpeg"):
            ffmpeg_paths.append(os.path.abspath("ffmpeg"))
        if os.path.isfile("ffmpeg.exe"):
            ffmpeg_paths.append(os.path.abspath("ffmpeg.exe"))
        if len(ffmpeg_paths) == 0:
            print("No valid ffmpeg found.")
            ffmpeg_path = None
        elif len(ffmpeg_paths) == 1:
            #Evaluation of suitability isn't required, can take sole option
            #to reduce startup time
            ffmpeg_path = ffmpeg_paths[0]
        else:
            ffmpeg_path = max(ffmpeg_paths, key=ffmpeg_suitability)

gifski_path = os.environ.get("VHS_GIFSKI", None)
if gifski_path is None:
    gifski_path = os.environ.get("JOV_GIFSKI", None)
    if gifski_path is None:
        gifski_path = shutil.which("gifski")

ytdl_path = os.environ.get("VHS_YTDL", None) or shutil.which('yt-dlp') \
        or shutil.which('youtube-dl')

def get_audio(file, start_time=0, duration=0):
    args = [ffmpeg_path, "-i", file]
    if start_time > 0:
        args += ["-ss", str(start_time)]
    if duration > 0:
        args += ["-t", str(duration)]
    try:
        #TODO: scan for sample rate and maintain
        res = subprocess.run(args + ["-f", "f32le", "-"],
                             capture_output=True, check=True)
        audio = torch.frombuffer(bytearray(res.stdout), dtype=torch.float32)
        match = re.search(', (\\d+) Hz, (\\w+), ', res.stderr.decode('utf-8'))
    except subprocess.CalledProcessError as e:
        raise Exception(f"VHS failed to extract audio from {file}:\n" \
                + e.stderr.decode("utf-8"))
    if match:
        ar = int(match.group(1))
        #NOTE: Just throwing an error for other channel types right now
        #Will deal with issues if they come
        ac = {"mono": 1, "stereo": 2}[match.group(2)]
    else:
        ar = 44100
        ac = 2
    audio = audio.reshape((-1, ac)).transpose(0, 1).unsqueeze(0)
    return {'waveform': audio, 'sample_rate': ar}

class LazyAudioMap(Mapping):
    def __init__(self, file, start_time, duration):
|
111 |
+
self.file = file
|
112 |
+
self.start_time=start_time
|
113 |
+
self.duration=duration
|
114 |
+
self._dict=None
|
115 |
+
def __getitem__(self, key):
|
116 |
+
if self._dict is None:
|
117 |
+
self._dict = get_audio(self.file, self.start_time, self.duration)
|
118 |
+
return self._dict[key]
|
119 |
+
def __iter__(self):
|
120 |
+
if self._dict is None:
|
121 |
+
self._dict = get_audio(self.file, self.start_time, self.duration)
|
122 |
+
return iter(self._dict)
|
123 |
+
def __len__(self):
|
124 |
+
if self._dict is None:
|
125 |
+
self._dict = get_audio(self.file, self.start_time, self.duration)
|
126 |
+
return len(self._dict)
|
127 |
+
|
128 |
+
def lazy_get_audio(file, start_time=0, duration=0):
|
129 |
+
return LazyAudioMap(file, start_time, duration)
|
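ffmpeg_suitability above ranks candidate binaries by grepping the -version banner for encoder names and adding a rough compile year pulled from the copyright string, so a newer build with libvpx/x264 support wins the tie-break. A standalone illustration of that scoring on a canned banner; score_banner and the banner text are made up for the example, the weights are the ones used above:

SIMPLE_CRITERION = [("libvpx", 20), ("264", 10), ("265", 3), ("svtav1", 5), ("libopus", 1)]

def score_banner(version: str) -> int:
    # Mirrors the arithmetic in ffmpeg_suitability, applied to an in-memory banner string.
    score = 0
    for needle, weight in SIMPLE_CRITERION:
        if version.find(needle) >= 0:
            score += weight
    copyright_index = version.find('2000-2')
    if copyright_index >= 0:
        copyright_year = version[copyright_index + 6:copyright_index + 9]   # '023' out of '2000-2023'
        if copyright_year.isnumeric():
            score += int(copyright_year)
    return score

banner = ("ffmpeg version 6.0 Copyright (c) 2000-2023 the FFmpeg developers\n"
          "configuration: --enable-libvpx --enable-libx264 --enable-libopus")
print(score_banner(banner))   # 20 + 10 + 1 + 23 = 54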
custom_nodes/ComfyUI-tbox/nodes/video/info_node.py
ADDED
@@ -0,0 +1,39 @@
class VideoInfoNode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "video_info": ("VHS_VIDEOINFO",),
            }
        }

    CATEGORY = "tbox/Video"

    RETURN_TYPES = ("FLOAT", "INT", "FLOAT", "INT", "INT", "FLOAT", "INT", "FLOAT", "INT", "INT")
    RETURN_NAMES = (
        "source_fps",
        "source_frame_count",
        "source_duration",
        "source_width",
        "source_height",
        "loaded_fps",
        "loaded_frame_count",
        "loaded_duration",
        "loaded_width",
        "loaded_height",
    )
    FUNCTION = "get_video_info"

    def get_video_info(self, video_info):
        keys = ["fps", "frame_count", "duration", "width", "height"]

        source_info = []
        loaded_info = []

        for key in keys:
            source_info.append(video_info[f"source_{key}"])
            loaded_info.append(video_info[f"loaded_{key}"])

        return (*source_info, *loaded_info)
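The node assumes video_info is a flat dict carrying source_* and loaded_* variants of the same five keys, which is the shape produced by the video loader added below. A quick sketch of the unpacking with hand-made sample values (the numbers are illustrative only):

sample_info = {
    "source_fps": 30.0, "source_frame_count": 300, "source_duration": 10.0,
    "source_width": 1920, "source_height": 1080,
    "loaded_fps": 15.0, "loaded_frame_count": 150, "loaded_duration": 10.0,
    "loaded_width": 512, "loaded_height": 288,
}

keys = ["fps", "frame_count", "duration", "width", "height"]
source_info = [sample_info[f"source_{k}"] for k in keys]
loaded_info = [sample_info[f"loaded_{k}"] for k in keys]

# Same ordering as RETURN_NAMES: the five source_* outputs followed by the five loaded_* outputs.
print((*source_info, *loaded_info))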
custom_nodes/ComfyUI-tbox/nodes/video/load_node.py
ADDED
@@ -0,0 +1,261 @@
1 |
+
|
2 |
+
import os
|
3 |
+
import cv2
|
4 |
+
import torch
|
5 |
+
import requests
|
6 |
+
import itertools
|
7 |
+
import folder_paths
|
8 |
+
import psutil
|
9 |
+
import numpy as np
|
10 |
+
from comfy.utils import common_upscale
|
11 |
+
from io import BytesIO
|
12 |
+
from PIL import Image, ImageSequence, ImageOps
|
13 |
+
from .ffmpeg import lazy_get_audio, video_extensions
|
14 |
+
from ..utils import BIGMAX, DIMMAX, strip_path, validate_path
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
def is_gif(filename) -> bool:
|
20 |
+
file_parts = filename.split('.')
|
21 |
+
return len(file_parts) > 1 and file_parts[-1] == "gif"
|
22 |
+
|
23 |
+
def target_size(width, height, force_size, custom_width, custom_height, downscale_ratio=8) -> tuple[int, int]:
|
24 |
+
if force_size == "Disabled":
|
25 |
+
pass
|
26 |
+
elif force_size == "Custom Width" or force_size.endswith('x?'):
|
27 |
+
height *= custom_width/width
|
28 |
+
width = custom_width
|
29 |
+
elif force_size == "Custom Height" or force_size.startswith('?x'):
|
30 |
+
width *= custom_height/height
|
31 |
+
height = custom_height
|
32 |
+
else:
|
33 |
+
width = custom_width
|
34 |
+
height = custom_height
|
35 |
+
width = int(width/downscale_ratio + 0.5) * downscale_ratio
|
36 |
+
height = int(height/downscale_ratio + 0.5) * downscale_ratio
|
37 |
+
return (width, height)
|
38 |
+
|
39 |
+
def cv_frame_generator(path, force_rate, frame_load_cap, skip_first_frames,
|
40 |
+
select_every_nth, meta_batch=None, unique_id=None):
|
41 |
+
video_cap = cv2.VideoCapture(strip_path(path))
|
42 |
+
if not video_cap.isOpened():
|
43 |
+
raise ValueError(f"{path} could not be loaded with cv.")
|
44 |
+
|
45 |
+
# extract video metadata
|
46 |
+
fps = video_cap.get(cv2.CAP_PROP_FPS)
|
47 |
+
width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
48 |
+
height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
49 |
+
total_frames = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
50 |
+
duration = total_frames / fps
|
51 |
+
|
52 |
+
# set video_cap to look at start_index frame
|
53 |
+
total_frame_count = 0
|
54 |
+
total_frames_evaluated = -1
|
55 |
+
frames_added = 0
|
56 |
+
base_frame_time = 1 / fps
|
57 |
+
prev_frame = None
|
58 |
+
|
59 |
+
if force_rate == 0:
|
60 |
+
target_frame_time = base_frame_time
|
61 |
+
else:
|
62 |
+
target_frame_time = 1/force_rate
|
63 |
+
|
64 |
+
yield (width, height, fps, duration, total_frames, target_frame_time)
|
65 |
+
if total_frames > 0:
|
66 |
+
if force_rate != 0:
|
67 |
+
yieldable_frames = int(total_frames / fps * force_rate)
|
68 |
+
else:
|
69 |
+
yieldable_frames = total_frames
|
70 |
+
if frame_load_cap != 0:
|
71 |
+
yieldable_frames = min(frame_load_cap, yieldable_frames)
|
72 |
+
else:
|
73 |
+
yieldable_frames = 0
|
74 |
+
|
75 |
+
if meta_batch is not None:
|
76 |
+
yield yieldable_frames
|
77 |
+
|
78 |
+
time_offset=target_frame_time - base_frame_time
|
79 |
+
while video_cap.isOpened():
|
80 |
+
if time_offset < target_frame_time:
|
81 |
+
is_returned = video_cap.grab()
|
82 |
+
# if didn't return frame, video has ended
|
83 |
+
if not is_returned:
|
84 |
+
break
|
85 |
+
time_offset += base_frame_time
|
86 |
+
if time_offset < target_frame_time:
|
87 |
+
continue
|
88 |
+
time_offset -= target_frame_time
|
89 |
+
# if not at start_index, skip doing anything with frame
|
90 |
+
total_frame_count += 1
|
91 |
+
if total_frame_count <= skip_first_frames:
|
92 |
+
continue
|
93 |
+
else:
|
94 |
+
total_frames_evaluated += 1
|
95 |
+
|
96 |
+
# if should not be selected, skip doing anything with frame
|
97 |
+
if total_frames_evaluated%select_every_nth != 0:
|
98 |
+
continue
|
99 |
+
|
100 |
+
# opencv loads images in BGR format (yuck), so need to convert to RGB for ComfyUI use
|
101 |
+
# follow up: can videos ever have an alpha channel?
|
102 |
+
# To my testing: No. opencv has no support for alpha
|
103 |
+
unused, frame = video_cap.retrieve()
|
104 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
105 |
+
# convert frame to comfyui's expected format
|
106 |
+
# TODO: frame contains no exif information. Check if opencv2 has already applied
|
107 |
+
frame = np.array(frame, dtype=np.float32)
|
108 |
+
torch.from_numpy(frame).div_(255)  # in-place divide; the tensor shares memory with frame, so this normalizes the numpy array itself
|
109 |
+
if prev_frame is not None:
|
110 |
+
inp = yield prev_frame
|
111 |
+
if inp is not None:
|
112 |
+
#ensure the finally block is called
|
113 |
+
return
|
114 |
+
prev_frame = frame
|
115 |
+
frames_added += 1
|
116 |
+
|
117 |
+
# if cap exists and we've reached it, stop processing frames
|
118 |
+
if frame_load_cap > 0 and frames_added >= frame_load_cap:
|
119 |
+
break
|
120 |
+
|
121 |
+
if meta_batch is not None:
|
122 |
+
meta_batch.inputs.pop(unique_id)
|
123 |
+
meta_batch.has_closed_inputs = True
|
124 |
+
if prev_frame is not None:
|
125 |
+
yield prev_frame
|
126 |
+
|
127 |
+
def batched(it, n):
|
128 |
+
while batch := tuple(itertools.islice(it, n)):
|
129 |
+
yield batch
|
130 |
+
|
131 |
+
def load_video_cv(path: str, force_rate: int, force_size: str,
|
132 |
+
custom_width: int,custom_height: int, frame_load_cap: int,
|
133 |
+
skip_first_frames: int, select_every_nth: int,
|
134 |
+
meta_batch=None, unique_id=None,
|
135 |
+
memory_limit_mb=None):
|
136 |
+
|
137 |
+
if meta_batch is None or unique_id not in meta_batch.inputs:
|
138 |
+
gen = cv_frame_generator(path, force_rate, frame_load_cap, skip_first_frames,
|
139 |
+
select_every_nth, meta_batch, unique_id)
|
140 |
+
(width, height, fps, duration, total_frames, target_frame_time) = next(gen)
|
141 |
+
|
142 |
+
if meta_batch is not None:
|
143 |
+
meta_batch.inputs[unique_id] = (gen, width, height, fps, duration, total_frames, target_frame_time)
|
144 |
+
yieldable_frames = next(gen)
|
145 |
+
if yieldable_frames:
|
146 |
+
meta_batch.total_frames = min(meta_batch.total_frames, yieldable_frames)
|
147 |
+
else:
|
148 |
+
(gen, width, height, fps, duration, total_frames, target_frame_time) = meta_batch.inputs[unique_id]
|
149 |
+
|
150 |
+
print(f'[{width}x{height}]@{fps} - duration:{duration}, total_frames: {total_frames}')
|
151 |
+
|
152 |
+
memory_limit = memory_limit_mb
|
153 |
+
if memory_limit_mb is not None:
|
154 |
+
memory_limit *= 2 ** 20
|
155 |
+
else:
|
156 |
+
#TODO: verify if garbage collection should be performed here.
|
157 |
+
#leaves ~128 MB unreserved for safety
|
158 |
+
try:
|
159 |
+
memory_limit = (psutil.virtual_memory().available + psutil.swap_memory().free) - 2 ** 27
|
160 |
+
except:
|
161 |
+
print("Failed to calculate available memory. Memory load limit has been disabled")
|
162 |
+
|
163 |
+
if memory_limit is not None:
|
164 |
+
#TODO: use better estimate for when vae is not None
|
165 |
+
#Consider completely ignoring for load_latent case?
|
166 |
+
max_loadable_frames = int(memory_limit//(width*height*3*(.1)))
|
167 |
+
|
168 |
+
if meta_batch is not None:
|
169 |
+
if meta_batch.frames_per_batch > max_loadable_frames:
|
170 |
+
raise RuntimeError(f"Meta Batch set to {meta_batch.frames_per_batch} frames but only {max_loadable_frames} can fit in memory")
|
171 |
+
gen = itertools.islice(gen, meta_batch.frames_per_batch)
|
172 |
+
else:
|
173 |
+
original_gen = gen
|
174 |
+
gen = itertools.islice(gen, max_loadable_frames)
|
175 |
+
|
176 |
+
downscale_ratio = 8
|
177 |
+
frames_per_batch = (1920 * 1080 * 16) // (width * height) or 1
|
178 |
+
if force_size != "Disabled":
|
179 |
+
new_size = target_size(width, height, force_size, custom_width, custom_height, downscale_ratio)
|
180 |
+
if new_size[0] != width or new_size[1] != height:
|
181 |
+
def rescale(frame):
|
182 |
+
s = torch.from_numpy(np.fromiter(frame, np.dtype((np.float32, (height, width, 3)))))
|
183 |
+
s = s.movedim(-1,1)
|
184 |
+
s = common_upscale(s, new_size[0], new_size[1], "lanczos", "center")
|
185 |
+
return s.movedim(1,-1).numpy()
|
186 |
+
gen = itertools.chain.from_iterable(map(rescale, batched(gen, frames_per_batch)))
|
187 |
+
else:
|
188 |
+
new_size = width, height
|
189 |
+
|
190 |
+
#Some minor wizardry to eliminate a copy and reduce max memory by a factor of ~2
|
191 |
+
images = torch.from_numpy(np.fromiter(gen, np.dtype((np.float32, (new_size[1], new_size[0], 3)))))
|
192 |
+
if meta_batch is None and memory_limit is not None:
|
193 |
+
try:
|
194 |
+
next(original_gen)
|
195 |
+
raise RuntimeError(f"Memory limit hit after loading {len(images)} frames. Stopping execution.")
|
196 |
+
except StopIteration:
|
197 |
+
pass
|
198 |
+
if len(images) == 0:
|
199 |
+
raise RuntimeError("No frames generated")
|
200 |
+
|
201 |
+
#Setup lambda for lazy audio capture
|
202 |
+
audio = lazy_get_audio(path, skip_first_frames * target_frame_time,
|
203 |
+
frame_load_cap*target_frame_time*select_every_nth)
|
204 |
+
#Adjust target_frame_time for select_every_nth
|
205 |
+
target_frame_time *= select_every_nth
|
206 |
+
video_info = {
|
207 |
+
"source_fps": fps,
|
208 |
+
"source_frame_count": total_frames,
|
209 |
+
"source_duration": duration,
|
210 |
+
"source_width": width,
|
211 |
+
"source_height": height,
|
212 |
+
"loaded_fps": 1/target_frame_time,
|
213 |
+
"loaded_frame_count": len(images),
|
214 |
+
"loaded_duration": len(images) * target_frame_time,
|
215 |
+
"loaded_width": new_size[0],
|
216 |
+
"loaded_height": new_size[1],
|
217 |
+
}
|
218 |
+
|
219 |
+
return (images, len(images), audio, video_info)
|
220 |
+
|
221 |
+
|
222 |
+
|
223 |
+
class LoadVideoNode:
|
224 |
+
@classmethod
|
225 |
+
def INPUT_TYPES(s):
|
226 |
+
return {
|
227 |
+
"required": {
|
228 |
+
"path": ("STRING", {"default": "/Users/wadahana/Desktop/live-motion2.mp4", "multiline": True, "vhs_path_extensions": video_extensions}),
|
229 |
+
"force_rate": ("INT", {"default": 0, "min": 0, "max": 60, "step": 1}),
|
230 |
+
"force_size": (["Disabled", "Custom Height", "Custom Width", "Custom", "256x?", "?x256", "256x256", "512x?", "?x512", "512x512"],),
|
231 |
+
"custom_width": ("INT", {"default": 512, "min": 0, "max": DIMMAX, "step": 8}),
|
232 |
+
"custom_height": ("INT", {"default": 512, "min": 0, "max": DIMMAX, "step": 8}),
|
233 |
+
"frame_load_cap": ("INT", {"default": 0, "min": 0, "max": BIGMAX, "step": 1}),
|
234 |
+
"skip_first_frames": ("INT", {"default": 0, "min": 0, "max": BIGMAX, "step": 1}),
|
235 |
+
"select_every_nth": ("INT", {"default": 1, "min": 1, "max": BIGMAX, "step": 1}),
|
236 |
+
},
|
237 |
+
"optional": {
|
238 |
+
"meta_batch": ("BatchManager",),
|
239 |
+
},
|
240 |
+
"hidden": {
|
241 |
+
"unique_id": "UNIQUE_ID"
|
242 |
+
},
|
243 |
+
}
|
244 |
+
|
245 |
+
CATEGORY = "tbox/Video"
|
246 |
+
|
247 |
+
RETURN_TYPES = ("IMAGE", "INT", "AUDIO", "VHS_VIDEOINFO")
|
248 |
+
RETURN_NAMES = ("IMAGE", "frame_count", "audio", "video_info")
|
249 |
+
|
250 |
+
FUNCTION = "load_video"
|
251 |
+
|
252 |
+
def load_video(self, **kwargs):
|
253 |
+
if kwargs['path'] is None :
|
254 |
+
raise Exception("video is not a valid path: " + kwargs['path'])
|
255 |
+
|
256 |
+
kwargs['path'] = kwargs['path'].split('\n')[0]
|
257 |
+
if validate_path(kwargs['path']) != True:
|
258 |
+
raise Exception("video is not a valid path: " + kwargs['path'])
|
259 |
+
# if is_url(kwargs['video']):
|
260 |
+
# kwargs['video'] = try_download_video(kwargs['video']) or kwargs['video']
|
261 |
+
return load_video_cv(**kwargs)
|
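One subtlety in target_size above: whichever dimension gets rescaled is snapped to the nearest multiple of downscale_ratio (8 by default) so the loaded frames stay aligned with latent-space downscaling. The rounding, worked standalone; the snap helper is just for illustration:

def snap(value, downscale_ratio=8):
    # same rounding as target_size: nearest multiple of downscale_ratio
    return int(value / downscale_ratio + 0.5) * downscale_ratio

# "?x512" on a 1280x720 source: height is forced to 512, width follows the aspect ratio
width, height = 1280, 720
custom_height = 512
width = width * custom_height / height    # 910.22...
height = custom_height
print(snap(width), snap(height))          # 912 512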
custom_nodes/ComfyUI-tbox/nodes/video/save_node.py
ADDED
@@ -0,0 +1,415 @@
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import cv2
|
4 |
+
import sys
|
5 |
+
import json
|
6 |
+
import torch
|
7 |
+
import datetime
|
8 |
+
import itertools
|
9 |
+
import subprocess
|
10 |
+
import folder_paths
|
11 |
+
import numpy as np
|
12 |
+
from string import Template
|
13 |
+
from pathlib import Path
|
14 |
+
from PIL import Image, ExifTags
|
15 |
+
from PIL.PngImagePlugin import PngInfo
|
16 |
+
from .ffmpeg import ffmpeg_path, gifski_path
|
17 |
+
from ..utils import tensor_to_bytes, tensor_to_shorts, requeue_workflow
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
def gen_format_widgets(video_format):
|
22 |
+
for k in video_format:
|
23 |
+
if k.endswith("_pass"):
|
24 |
+
for i in range(len(video_format[k])):
|
25 |
+
if isinstance(video_format[k][i], list):
|
26 |
+
item = [video_format[k][i]]
|
27 |
+
yield item
|
28 |
+
video_format[k][i] = item[0]
|
29 |
+
else:
|
30 |
+
if isinstance(video_format[k], list):
|
31 |
+
item = [video_format[k]]
|
32 |
+
yield item
|
33 |
+
video_format[k] = item[0]
|
34 |
+
|
35 |
+
def get_format_widget_defaults(format_name):
|
36 |
+
video_format_path = folder_paths.get_full_path("VHS_video_formats", format_name + ".json")
|
37 |
+
with open(video_format_path, 'r') as stream:
|
38 |
+
video_format = json.load(stream)
|
39 |
+
results = {}
|
40 |
+
for w in gen_format_widgets(video_format):
|
41 |
+
if len(w[0]) > 2 and 'default' in w[0][2]:
|
42 |
+
default = w[0][2]['default']
|
43 |
+
else:
|
44 |
+
if type(w[0][1]) is list:
|
45 |
+
default = w[0][1][0]
|
46 |
+
else:
|
47 |
+
#NOTE: This doesn't respect max/min, but should be good enough as a fallback to a fallback to a fallback
|
48 |
+
default = {"BOOLEAN": False, "INT": 0, "FLOAT": 0, "STRING": ""}[w[0][1]]
|
49 |
+
results[w[0][0]] = default
|
50 |
+
return results
|
51 |
+
|
52 |
+
def get_video_formats():
|
53 |
+
formats = []
|
54 |
+
for format_name in folder_paths.get_filename_list("VHS_video_formats"):
|
55 |
+
format_name = format_name[:-5]
|
56 |
+
formats.append("video/" + format_name)
|
57 |
+
return formats
|
58 |
+
|
59 |
+
def gifski_process(args, video_format, file_path, env):
|
60 |
+
frame_data = yield
|
61 |
+
with subprocess.Popen(args + video_format['main_pass'] + ['-f', 'yuv4mpegpipe', '-'],
|
62 |
+
stderr=subprocess.PIPE, stdin=subprocess.PIPE,
|
63 |
+
stdout=subprocess.PIPE, env=env) as procff:
|
64 |
+
with subprocess.Popen([gifski_path] + video_format['gifski_pass']
|
65 |
+
+ ['-q', '-o', file_path, '-'], stderr=subprocess.PIPE,
|
66 |
+
stdin=procff.stdout, stdout=subprocess.PIPE,
|
67 |
+
env=env) as procgs:
|
68 |
+
try:
|
69 |
+
while frame_data is not None:
|
70 |
+
procff.stdin.write(frame_data)
|
71 |
+
frame_data = yield
|
72 |
+
procff.stdin.flush()
|
73 |
+
procff.stdin.close()
|
74 |
+
resff = procff.stderr.read()
|
75 |
+
resgs = procgs.stderr.read()
|
76 |
+
outgs = procgs.stdout.read()
|
77 |
+
except BrokenPipeError as e:
|
78 |
+
procff.stdin.close()
|
79 |
+
resff = procff.stderr.read()
|
80 |
+
resgs = procgs.stderr.read()
|
81 |
+
raise Exception("An error occurred while creating gifski output\n" \
|
82 |
+
+ "Make sure you are using gifski --version >=1.32.0\nffmpeg: " \
|
83 |
+
+ resff.decode("utf-8") + '\ngifski: ' + resgs.decode("utf-8"))
|
84 |
+
if len(resff) > 0:
|
85 |
+
print(resff.decode("utf-8"), end="", file=sys.stderr)
|
86 |
+
if len(resgs) > 0:
|
87 |
+
print(resgs.decode("utf-8"), end="", file=sys.stderr)
|
88 |
+
#should always be empty as the quiet flag is passed
|
89 |
+
if len(outgs) > 0:
|
90 |
+
print(outgs.decode("utf-8"))
|
91 |
+
|
92 |
+
def ffmpeg_process(args, video_format, video_metadata, file_path, env):
|
93 |
+
|
94 |
+
res = None
|
95 |
+
frame_data = yield
|
96 |
+
total_frames_output = 0
|
97 |
+
if video_format.get('save_metadata', 'False') != 'False':
|
98 |
+
os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
|
99 |
+
metadata = json.dumps(video_metadata)
|
100 |
+
metadata_path = os.path.join(folder_paths.get_temp_directory(), "metadata.txt")
|
101 |
+
#metadata from file should escape = ; # \ and newline
|
102 |
+
metadata = metadata.replace("\\","\\\\")
|
103 |
+
metadata = metadata.replace(";","\\;")
|
104 |
+
metadata = metadata.replace("#","\\#")
|
105 |
+
metadata = metadata.replace("=","\\=")
|
106 |
+
metadata = metadata.replace("\n","\\\n")
|
107 |
+
metadata = "comment=" + metadata
|
108 |
+
with open(metadata_path, "w") as f:
|
109 |
+
f.write(";FFMETADATA1\n")
|
110 |
+
f.write(metadata)
|
111 |
+
m_args = args[:1] + ["-i", metadata_path] + args[1:] + ["-metadata", "creation_time=now"]
|
112 |
+
print(f'ffmpeg: {m_args}')
|
113 |
+
with subprocess.Popen(m_args + [file_path], stderr=subprocess.PIPE,
|
114 |
+
stdin=subprocess.PIPE, env=env) as proc:
|
115 |
+
try:
|
116 |
+
while frame_data is not None:
|
117 |
+
proc.stdin.write(frame_data)
|
118 |
+
#TODO: skip flush for increased speed
|
119 |
+
frame_data = yield
|
120 |
+
total_frames_output+=1
|
121 |
+
proc.stdin.flush()
|
122 |
+
proc.stdin.close()
|
123 |
+
res = proc.stderr.read()
|
124 |
+
except BrokenPipeError as e:
|
125 |
+
err = proc.stderr.read()
|
126 |
+
#Check if output file exists. If it does, the re-execution
|
127 |
+
#will also fail. This obscures the cause of the error
|
128 |
+
#and seems to never occur concurrent to the metadata issue
|
129 |
+
if os.path.exists(file_path):
|
130 |
+
raise Exception("An error occurred in the ffmpeg subprocess:\n" \
|
131 |
+
+ err.decode("utf-8"))
|
132 |
+
#Res was not set
|
133 |
+
print(err.decode("utf-8"), end="", file=sys.stderr)
|
134 |
+
print("An error occurred when saving with metadata")
|
135 |
+
if res != b'':
|
136 |
+
with subprocess.Popen(args + [file_path], stderr=subprocess.PIPE,
|
137 |
+
stdin=subprocess.PIPE, env=env) as proc:
|
138 |
+
try:
|
139 |
+
while frame_data is not None:
|
140 |
+
proc.stdin.write(frame_data)
|
141 |
+
frame_data = yield
|
142 |
+
total_frames_output+=1
|
143 |
+
proc.stdin.flush()
|
144 |
+
proc.stdin.close()
|
145 |
+
res = proc.stderr.read()
|
146 |
+
except BrokenPipeError as e:
|
147 |
+
res = proc.stderr.read()
|
148 |
+
raise Exception("An error occurred in the ffmpeg subprocess:\n" \
|
149 |
+
+ res.decode("utf-8"))
|
150 |
+
yield total_frames_output
|
151 |
+
if len(res) > 0:
|
152 |
+
print(res.decode("utf-8"), end="", file=sys.stderr)
|
153 |
+
|
154 |
+
def to_pingpong(inp):
|
155 |
+
if not hasattr(inp, "__getitem__"):
|
156 |
+
inp = list(inp)
|
157 |
+
yield from inp
|
158 |
+
for i in range(len(inp)-2,0,-1):
|
159 |
+
yield inp[i]
|
160 |
+
|
161 |
+
def apply_format_widgets(format_name, kwargs):
|
162 |
+
video_format_path = folder_paths.get_full_path("VHS_video_formats", format_name + ".json")
|
163 |
+
with open(video_format_path, 'r') as stream:
|
164 |
+
video_format = json.load(stream)
|
165 |
+
|
166 |
+
for w in gen_format_widgets(video_format):
|
167 |
+
assert(w[0][0] in kwargs)
|
168 |
+
if len(w[0]) > 3:
|
169 |
+
w[0] = Template(w[0][3]).substitute(val=kwargs[w[0][0]])
|
170 |
+
else:
|
171 |
+
w[0] = str(kwargs[w[0][0]])
|
172 |
+
return video_format
|
173 |
+
|
174 |
+
class SaveVideoNode:
|
175 |
+
@classmethod
|
176 |
+
def INPUT_TYPES(s):
|
177 |
+
ffmpeg_formats = get_video_formats()
|
178 |
+
return {
|
179 |
+
"required": {
|
180 |
+
"path": ("STRING", {"multiline": True, "dynamicPrompts": False}),
|
181 |
+
"format": (ffmpeg_formats,),
|
182 |
+
"quality": ([100, 95, 90, 85, 80, 75, 70, 60, 50], {"default": 100}),
|
183 |
+
"pingpong": ("BOOLEAN", {"default": False}),
|
184 |
+
},
|
185 |
+
"optional": {
|
186 |
+
"images": ("IMAGE",),
|
187 |
+
"audio": ("AUDIO",),
|
188 |
+
"frame_rate": ("INT,FLOAT", { "default": 25.0, "step": 1.0, "min": 1.0, "max": 60.0 }),
|
189 |
+
"meta_batch": ("BatchManager",),
|
190 |
+
},
|
191 |
+
"hidden": {
|
192 |
+
"prompt": "PROMPT",
|
193 |
+
"unique_id": "UNIQUE_ID"
|
194 |
+
},
|
195 |
+
}
|
196 |
+
|
197 |
+
RETURN_TYPES = ()
|
198 |
+
CATEGORY = "tbox/Video"
|
199 |
+
FUNCTION = "save_video"
|
200 |
+
OUTPUT_NODE = True
|
201 |
+
|
202 |
+
def save_video(
|
203 |
+
self,
|
204 |
+
path,
|
205 |
+
frame_rate=25,
|
206 |
+
images=None,
|
207 |
+
format="video/h264-mp4",
|
208 |
+
quality=85,
|
209 |
+
pingpong=False,
|
210 |
+
audio=None,
|
211 |
+
prompt=None,
|
212 |
+
meta_batch=None,
|
213 |
+
unique_id=None,
|
214 |
+
manual_format_widgets=None,
|
215 |
+
):
|
216 |
+
if images is None:
|
217 |
+
return {}
|
218 |
+
if isinstance(images, torch.Tensor) and images.size(0) == 0:
|
219 |
+
return {}
|
220 |
+
|
221 |
+
if frame_rate < 1:
|
222 |
+
frame_rate = 1
|
223 |
+
elif frame_rate > 120:
|
224 |
+
frame_rate = 120
|
225 |
+
|
226 |
+
num_frames = len(images)
|
227 |
+
|
228 |
+
first_image = images[0]
|
229 |
+
images = iter(images)
|
230 |
+
|
231 |
+
file_path = os.path.abspath(path.split('\n')[0])
|
232 |
+
output_dir = os.path.dirname(file_path)
|
233 |
+
filename = os.path.basename(file_path)
|
234 |
+
name, extension = os.path.splitext(filename)
|
235 |
+
|
236 |
+
output_process = None
|
237 |
+
|
238 |
+
video_metadata = {}
|
239 |
+
if prompt is not None:
|
240 |
+
video_metadata["prompt"] = prompt
|
241 |
+
|
242 |
+
if meta_batch is not None and unique_id in meta_batch.outputs:
|
243 |
+
(counter, output_process) = meta_batch.outputs[unique_id]
|
244 |
+
else:
|
245 |
+
counter = 0
|
246 |
+
output_process = None
|
247 |
+
|
248 |
+
format_type, format_ext = format.split("/")
|
249 |
+
|
250 |
+
# Use ffmpeg to save a video
|
251 |
+
if ffmpeg_path is None:
|
252 |
+
raise ProcessLookupError(f"ffmpeg is required for video outputs and could not be found.\nIn order to use video outputs, you must either:\n- Install imageio-ffmpeg with pip,\n- Place a ffmpeg executable in {os.path.abspath('')}, or\n- Install ffmpeg and add it to the system path.")
|
253 |
+
|
254 |
+
#Acquire additional format_widget values
|
255 |
+
kwargs = None
|
256 |
+
if manual_format_widgets is None:
|
257 |
+
if prompt is not None:
|
258 |
+
kwargs = prompt[unique_id]['inputs']
|
259 |
+
else:
|
260 |
+
manual_format_widgets = {}
|
261 |
+
|
262 |
+
if kwargs is None:
|
263 |
+
kwargs = get_format_widget_defaults(format_ext)
|
264 |
+
missing = {}
|
265 |
+
for k in kwargs.keys():
|
266 |
+
if k in manual_format_widgets:
|
267 |
+
kwargs[k] = manual_format_widgets[k]
|
268 |
+
else:
|
269 |
+
missing[k] = kwargs[k]
|
270 |
+
if len(missing) > 0:
|
271 |
+
print("Extra format values were not provided, the following defaults will be used: " + str(kwargs) + "\nThis is likely due to usage of ComfyUI-to-python. These values can be manually set by supplying a manual_format_widgets argument")
|
272 |
+
|
273 |
+
video_format = apply_format_widgets(format_ext, kwargs)
|
274 |
+
has_alpha = first_image.shape[-1] == 4
|
275 |
+
dim_alignment = video_format.get("dim_alignment", 8)
|
276 |
+
if (first_image.shape[1] % dim_alignment) or (first_image.shape[0] % dim_alignment):
|
277 |
+
#output frames must be padded
|
278 |
+
to_pad = (-first_image.shape[1] % dim_alignment,
|
279 |
+
-first_image.shape[0] % dim_alignment)
|
280 |
+
padding = (to_pad[0]//2, to_pad[0] - to_pad[0]//2,
|
281 |
+
to_pad[1]//2, to_pad[1] - to_pad[1]//2)
|
282 |
+
padfunc = torch.nn.ReplicationPad2d(padding)
|
283 |
+
def pad(image):
|
284 |
+
image = image.permute((2,0,1))#HWC to CHW
|
285 |
+
padded = padfunc(image.to(dtype=torch.float32))
|
286 |
+
return padded.permute((1,2,0))
|
287 |
+
images = map(pad, images)
|
288 |
+
new_dims = (-first_image.shape[1] % dim_alignment + first_image.shape[1],
|
289 |
+
-first_image.shape[0] % dim_alignment + first_image.shape[0])
|
290 |
+
dimensions = f"{new_dims[0]}x{new_dims[1]}"
|
291 |
+
print(f"Output images were not of valid resolution and have had padding applied: {dimensions}")
|
292 |
+
else:
|
293 |
+
dimensions = f"{first_image.shape[1]}x{first_image.shape[0]}"
|
294 |
+
|
295 |
+
if pingpong:
|
296 |
+
if meta_batch is not None:
|
297 |
+
print("pingpong is incompatible with batched output")
|
298 |
+
images = to_pingpong(images)
|
299 |
+
|
300 |
+
images = map(tensor_to_bytes, images)
|
301 |
+
if has_alpha:
|
302 |
+
i_pix_fmt = 'rgba'
|
303 |
+
else:
|
304 |
+
i_pix_fmt = 'rgb24'
|
305 |
+
|
306 |
+
args = [ffmpeg_path, "-v", "error", "-f", "rawvideo", "-pix_fmt", i_pix_fmt,
|
307 |
+
"-s", dimensions, "-r", str(frame_rate), "-i", "-"]
|
308 |
+
|
309 |
+
images = map(lambda x: x.tobytes(), images)
|
310 |
+
env=os.environ.copy()
|
311 |
+
if "environment" in video_format:
|
312 |
+
env.update(video_format["environment"])
|
313 |
+
|
314 |
+
if "pre_pass" in video_format:
|
315 |
+
images = [b''.join(images)]
|
316 |
+
os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
|
317 |
+
pre_pass_args = args[:13] + video_format['pre_pass']
|
318 |
+
try:
|
319 |
+
subprocess.run(pre_pass_args, input=images[0], env=env,
|
320 |
+
capture_output=True, check=True)
|
321 |
+
except subprocess.CalledProcessError as e:
|
322 |
+
raise Exception("An error occurred in the ffmpeg prepass:\n" \
|
323 |
+
+ e.stderr.decode("utf-8"))
|
324 |
+
if "inputs_main_pass" in video_format:
|
325 |
+
args = args[:13] + video_format['inputs_main_pass'] + args[13:]
|
326 |
+
|
327 |
+
if output_process is None:
|
328 |
+
args += video_format['main_pass']
|
329 |
+
output_process = ffmpeg_process(args, video_format, video_metadata, file_path, env)
|
330 |
+
#Proceed to first yield
|
331 |
+
output_process.send(None)
|
332 |
+
if meta_batch is not None:
|
333 |
+
meta_batch.outputs[unique_id] = (0, output_process)
|
334 |
+
|
335 |
+
for image in images:
|
336 |
+
output_process.send(image)
|
337 |
+
if meta_batch is not None:
|
338 |
+
requeue_workflow((meta_batch.unique_id, not meta_batch.has_closed_inputs))
|
339 |
+
if meta_batch is None or meta_batch.has_closed_inputs:
|
340 |
+
#Close pipe and wait for termination.
|
341 |
+
try:
|
342 |
+
total_frames_output = output_process.send(None)
|
343 |
+
output_process.send(None)
|
344 |
+
except StopIteration:
|
345 |
+
pass
|
346 |
+
if meta_batch is not None:
|
347 |
+
meta_batch.outputs.pop(unique_id)
|
348 |
+
#if len(meta_batch.outputs) == 0:
|
349 |
+
# meta_batch.reset()
|
350 |
+
else:
|
351 |
+
return {}
|
352 |
+
|
353 |
+
a_waveform = None
|
354 |
+
if audio is not None:
|
355 |
+
try:
|
356 |
+
#safely check if audio produced by VHS_LoadVideo actually exists
|
357 |
+
a_waveform = audio['waveform']
|
358 |
+
except:
|
359 |
+
print(f'save audio >> not waveform')
|
360 |
+
pass
|
361 |
+
if a_waveform is not None:
|
362 |
+
# Create audio file if input was provided
|
363 |
+
output_file_with_audio = f"{name}-audio{extension}"
|
364 |
+
output_file_with_audio_path = os.path.join(output_dir, output_file_with_audio)
|
365 |
+
if "audio_pass" not in video_format:
|
366 |
+
print("Selected video format does not have explicit audio support")
|
367 |
+
video_format["audio_pass"] = ["-c:a", "libopus"]
|
368 |
+
|
369 |
+
|
370 |
+
# FFmpeg command with audio re-encoding
|
371 |
+
#TODO: expose audio quality options if format widgets makes it in
|
372 |
+
#Reconsider forcing apad/shortest
|
373 |
+
channels = audio['waveform'].size(1)
|
374 |
+
min_audio_dur = total_frames_output / frame_rate + 1
|
375 |
+
mux_args = [ffmpeg_path, "-v", "error", "-i", file_path,
|
376 |
+
"-ar", str(audio['sample_rate']), "-ac", str(channels),
|
377 |
+
"-y","-f", "f32le", "-i", "-", "-c:v", "copy"] \
|
378 |
+
+ video_format["audio_pass"] \
|
379 |
+
+ ["-af", "apad=whole_dur="+str(min_audio_dur),
|
380 |
+
"-shortest", output_file_with_audio_path]
|
381 |
+
|
382 |
+
audio_data = audio['waveform'].squeeze(0).transpose(0,1) \
|
383 |
+
.numpy().tobytes()
|
384 |
+
try:
|
385 |
+
res = subprocess.run(mux_args, input=audio_data,
|
386 |
+
env=env, capture_output=True, check=True)
|
387 |
+
if res.returncode == 0:
|
388 |
+
self.replace_file(output_file_with_audio_path, file_path)
|
389 |
+
except subprocess.CalledProcessError as e:
|
390 |
+
raise Exception("An error occured in the ffmpeg subprocess:\n" \
|
391 |
+
+ e.stderr.decode("utf-8"))
|
392 |
+
if res.stderr:
|
393 |
+
print(res.stderr.decode("utf-8"), end="", file=sys.stderr)
|
394 |
+
|
395 |
+
|
396 |
+
return {}
|
397 |
+
|
398 |
+
@classmethod
|
399 |
+
def VALIDATE_INPUTS(self, format, **kwargs):
|
400 |
+
return True
|
401 |
+
|
402 |
+
def replace_file(self, audio_path, file_path):
|
403 |
+
try:
|
404 |
+
# Delete file_path if it already exists
|
405 |
+
if os.path.exists(file_path):
|
406 |
+
os.remove(file_path)
|
407 |
+
print(f"Deleted file: {file_path}")
|
408 |
+
else:
|
409 |
+
print(f"File not found, skipping deletion: {file_path}")
|
410 |
+
|
411 |
+
# Rename output_file_with_audio_path to file_path
|
412 |
+
os.rename(audio_path, file_path)
|
413 |
+
print(f"Renamed {audio_path} to {file_path}")
|
414 |
+
except Exception as e:
|
415 |
+
print(f"An error occurred: {e}")
|
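ffmpeg_process above is a generator used as a push-style sink: save_video primes it with send(None), streams raw frame bytes with send(frame), and sends None once more to flush and read back the frame count before StopIteration. A tiny sink with the same protocol but no subprocess, just to show the handshake; byte_sink is an illustrative stand-in, not part of the node pack:

def byte_sink():
    # Counts the chunks pushed into it; same prime/stream/finish protocol as ffmpeg_process.
    total = 0
    chunk = yield                  # primed by the first send(None)
    while chunk is not None:
        total += 1                 # a real sink would write chunk to proc.stdin here
        chunk = yield
    yield total                    # reported back on the final send(None)

sink = byte_sink()
sink.send(None)                    # prime, as save_video does before streaming frames
for frame in (b"\x00" * 16, b"\xff" * 16):
    sink.send(frame)
print(sink.send(None))             # 2 frames accounted for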
custom_nodes/ComfyUI-tbox/nodes/video/video_formats/h264-mp4.json
ADDED
@@ -0,0 +1,10 @@
{
    "main_pass":
    [
        "-y", "-c:v", "libx264",
        "-pix_fmt", "yuv420p",
        "-crf", "20"
    ],
    "audio_pass": ["-c:a", "aac"],
    "extension": "mp4"
}
custom_nodes/ComfyUI-tbox/nodes/video/video_formats/h265-mp4.json
ADDED
@@ -0,0 +1,13 @@
{
    "main_pass":
    [
        "-y", "-c:v", "libx265",
        "-vtag", "hvc1",
        "-pix_fmt", "yuv420p10le",
        "-crf", "22",
        "-preset", "medium",
        "-x265-params", "log-level=quiet"
    ],
    "audio_pass": ["-c:a", "aac"],
    "extension": "mp4"
}
custom_nodes/ComfyUI-tbox/nodes/video/video_formats/nvenc_h264-mp4.json
ADDED
@@ -0,0 +1,12 @@
{
    "main_pass":
    [
        "-y", "-c:v", "h264_nvenc",
        "-pix_fmt", "yuv420p",
        "-qp", "20"
    ],
    "audio_pass": ["-c:a", "aac"],
    "bitrate": ["bitrate", "INT", {"default": 10, "min": 1, "max": 999, "step": 1 }],
    "megabit": ["megabit", "BOOLEAN", {"default": true}],
    "extension": "mp4"
}
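The bitrate and megabit entries above are not ffmpeg arguments: any list-valued entry in a format JSON is treated by gen_format_widgets/apply_format_widgets in save_node.py as an extra widget spec of the form [name, type, options, optional template] and is replaced at save time with the user-supplied value. A simplified, hypothetical sketch of that substitution step (apply_widget and the "$valM" template are illustrative):

from string import Template

def apply_widget(entry, user_values):
    # entry is [name, type, options, optional template]; returns the substituted string
    name = entry[0]
    if len(entry) > 3:                                  # template form, e.g. "$valM" for megabit rates
        return Template(entry[3]).substitute(val=user_values[name])
    return str(user_values[name])

bitrate_entry = ["bitrate", "INT", {"default": 10, "min": 1, "max": 999, "step": 1}]
print(apply_widget(bitrate_entry, {"bitrate": 12}))     # '12'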
custom_nodes/ComfyUI-tbox/nodes/video/video_formats/nvenc_h265-mp4.json
ADDED
@@ -0,0 +1,10 @@
{
    "main_pass":
    [
        "-y", "-c:v", "hevc_nvenc",
        "-vtag", "hvc1",
        "-qp", "22"
    ],
    "audio_pass": ["-c:a", "aac"],
    "extension": "mp4"
}
custom_nodes/ComfyUI-tbox/nodes/video/video_formats/webm.json
ADDED
@@ -0,0 +1,11 @@
{
    "main_pass":
    [
        "-y",
        "-crf", "20",
        "-pix_fmt", "yuv420p",
        "-b:v", "0"
    ],
    "audio_pass": ["-c:a", "libvorbis"],
    "extension": "webm"
}
custom_nodes/ComfyUI-tbox/requirements.txt
ADDED
@@ -0,0 +1,11 @@
torch
torchvision
opencv-python-headless>=4.7.0.72
scikit-learn
scikit-image
insightface
ultralytics
onnxruntime-gpu==1.18.0
onnxruntime==1.18.0
imageio_ffmpeg
pykalman
custom_nodes/ComfyUI-tbox/src/canny/__init__.py
ADDED
@@ -0,0 +1,17 @@
import warnings
import cv2
import numpy as np
from PIL import Image
from common import resize_image_with_pad, common_input_validate, HWC3

class CannyDetector:
    def __call__(self, input_image=None, low_threshold=100, high_threshold=200, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
        detected_map = cv2.Canny(detected_map, low_threshold, high_threshold)
        detected_map = HWC3(remove_pad(detected_map))

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map
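Assuming src/ is on sys.path (which is how the from common import ... line resolves), the detector can be exercised directly on a NumPy image; no model weights are involved for Canny. A minimal usage sketch with a synthetic image:

import numpy as np
# from canny import CannyDetector   # import path depends on how src/ is exposed on sys.path

# synthetic 480x640 RGB image with a bright rectangle, uint8 as the detector expects
img = np.zeros((480, 640, 3), dtype=np.uint8)
img[160:320, 220:420] = 255

detector = CannyDetector()
edges = detector(img, low_threshold=100, high_threshold=200,
                 detect_resolution=512, output_type="np")
print(edges.shape, edges.dtype)       # 3-channel uint8 edge map at the detect resolution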
custom_nodes/ComfyUI-tbox/src/common.py
ADDED
@@ -0,0 +1,186 @@
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import os
|
5 |
+
import tempfile
|
6 |
+
import warnings
|
7 |
+
from contextlib import suppress
|
8 |
+
from pathlib import Path
|
9 |
+
from huggingface_hub import constants, hf_hub_download
|
10 |
+
from ast import literal_eval
|
11 |
+
|
12 |
+
TEMP_DIR = tempfile.gettempdir()
|
13 |
+
ANNOTATOR_CKPTS_PATH = os.path.join(Path(__file__).parents[2], 'ckpts')
|
14 |
+
USE_SYMLINKS = False
|
15 |
+
|
16 |
+
|
17 |
+
BIGMIN = -(2**53-1)
|
18 |
+
BIGMAX = (2**53-1)
|
19 |
+
|
20 |
+
DIMMAX = 8192
|
21 |
+
|
22 |
+
try:
|
23 |
+
ANNOTATOR_CKPTS_PATH = os.environ['AUX_ANNOTATOR_CKPTS_PATH']
|
24 |
+
except:
|
25 |
+
warnings.warn("Custom pressesor model path not set successfully.")
|
26 |
+
pass
|
27 |
+
|
28 |
+
try:
|
29 |
+
USE_SYMLINKS = literal_eval(os.environ['AUX_USE_SYMLINKS'])
|
30 |
+
except:
|
31 |
+
warnings.warn("USE_SYMLINKS not set successfully. Using default value: False to download models.")
|
32 |
+
pass
|
33 |
+
|
34 |
+
try:
|
35 |
+
TEMP_DIR = os.environ['AUX_TEMP_DIR']
|
36 |
+
if len(TEMP_DIR) >= 60:
|
37 |
+
warnings.warn(f"custom temp dir is too long. Using default")
|
38 |
+
TEMP_DIR = tempfile.gettempdir()
|
39 |
+
except:
|
40 |
+
warnings.warn(f"custom temp dir not set successfully")
|
41 |
+
pass
|
42 |
+
|
43 |
+
here = Path(__file__).parent.resolve()
|
44 |
+
|
45 |
+
def safer_memory(x):
|
46 |
+
# Fix many MAC/AMD problems
|
47 |
+
return np.ascontiguousarray(x.copy()).copy()
|
48 |
+
|
49 |
+
UPSCALE_METHODS = ["INTER_NEAREST", "INTER_LINEAR", "INTER_AREA", "INTER_CUBIC", "INTER_LANCZOS4"]
|
50 |
+
def get_upscale_method(method_str):
|
51 |
+
assert method_str in UPSCALE_METHODS, f"Method {method_str} not found in {UPSCALE_METHODS}"
|
52 |
+
return getattr(cv2, method_str)
|
53 |
+
|
54 |
+
def pad64(x):
|
55 |
+
return int(np.ceil(float(x) / 64.0) * 64 - x)
|
56 |
+
|
57 |
+
def resize_image_with_pad(input_image, resolution, upscale_method = "", skip_hwc3=False, mode='edge'):
|
58 |
+
if skip_hwc3:
|
59 |
+
img = input_image
|
60 |
+
else:
|
61 |
+
img = HWC3(input_image)
|
62 |
+
H_raw, W_raw, _ = img.shape
|
63 |
+
if resolution == 0:
|
64 |
+
return img, lambda x: x
|
65 |
+
k = float(resolution) / float(min(H_raw, W_raw))
|
66 |
+
H_target = int(np.round(float(H_raw) * k))
|
67 |
+
W_target = int(np.round(float(W_raw) * k))
|
68 |
+
img = cv2.resize(img, (W_target, H_target), interpolation=get_upscale_method(upscale_method) if k > 1 else cv2.INTER_AREA)
|
69 |
+
H_pad, W_pad = pad64(H_target), pad64(W_target)
|
70 |
+
img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode)
|
71 |
+
|
72 |
+
def remove_pad(x):
|
73 |
+
return safer_memory(x[:H_target, :W_target, ...])
|
74 |
+
|
75 |
+
return safer_memory(img_padded), remove_pad
|
76 |
+
|
77 |
+
|
78 |
+
def common_input_validate(input_image, output_type, **kwargs):
|
79 |
+
if "img" in kwargs:
|
80 |
+
warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning)
|
81 |
+
input_image = kwargs.pop("img")
|
82 |
+
|
83 |
+
if "return_pil" in kwargs:
|
84 |
+
warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning)
|
85 |
+
output_type = "pil" if kwargs["return_pil"] else "np"
|
86 |
+
|
87 |
+
if type(output_type) is bool:
|
88 |
+
warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions")
|
89 |
+
if output_type:
|
90 |
+
output_type = "pil"
|
91 |
+
|
92 |
+
if input_image is None:
|
93 |
+
raise ValueError("input_image must be defined.")
|
94 |
+
|
95 |
+
if not isinstance(input_image, np.ndarray):
|
96 |
+
input_image = np.array(input_image, dtype=np.uint8)
|
97 |
+
output_type = output_type or "pil"
|
98 |
+
else:
|
99 |
+
output_type = output_type or "np"
|
100 |
+
|
101 |
+
return (input_image, output_type)
|
102 |
+
|
103 |
+
def custom_hf_download(pretrained_model_or_path, filename, cache_dir=TEMP_DIR, ckpts_dir=ANNOTATOR_CKPTS_PATH, subfolder=str(""), use_symlinks=USE_SYMLINKS, repo_type="model"):
|
104 |
+
|
105 |
+
print(f'cache_dir: {cache_dir}')
|
106 |
+
print(f'ckpts_dir: {ckpts_dir}')
|
107 |
+
print(f'use_symlinks: {use_symlinks}')
|
108 |
+
local_dir = os.path.join(ckpts_dir, pretrained_model_or_path)
|
109 |
+
model_path = os.path.join(local_dir, *subfolder.split('/'), filename)
|
110 |
+
|
111 |
+
if len(str(model_path)) >= 255:
|
112 |
+
warnings.warn(f"Path {model_path} is too long, \n please change annotator_ckpts_path in config.yaml")
|
113 |
+
|
114 |
+
if not os.path.exists(model_path):
|
115 |
+
print(f"Failed to find {model_path}.\n Downloading from huggingface.co")
|
116 |
+
print(f"cacher folder is {cache_dir}, you can change it by custom_tmp_path in config.yaml")
|
117 |
+
if use_symlinks:
|
118 |
+
cache_dir_d = constants.HF_HUB_CACHE # use huggingface newer env variables `HF_HUB_CACHE`
|
119 |
+
if cache_dir_d is None:
|
120 |
+
import platform
|
121 |
+
if platform.system() == "Windows":
|
122 |
+
cache_dir_d = os.path.join(os.getenv("USERPROFILE"), ".cache", "huggingface", "hub")
|
123 |
+
else:
|
124 |
+
cache_dir_d = os.path.join(os.getenv("HOME"), ".cache", "huggingface", "hub")
|
125 |
+
try:
|
126 |
+
# test_link
|
127 |
+
Path(cache_dir_d).mkdir(parents=True, exist_ok=True)
|
128 |
+
Path(ckpts_dir).mkdir(parents=True, exist_ok=True)
|
129 |
+
(Path(cache_dir_d) / f"linktest_{filename}.txt").touch()
|
130 |
+
# symlink instead of link avoid `invalid cross-device link` error.
|
131 |
+
os.symlink(os.path.join(cache_dir_d, f"linktest_{filename}.txt"), os.path.join(ckpts_dir, f"linktest_{filename}.txt"))
|
132 |
+
print("Using symlinks to download models. \n",\
|
133 |
+
"Make sure you have enough space on your cache folder. \n",\
|
134 |
+
"And do not purge the cache folder after downloading.\n",\
|
135 |
+
"Otherwise, you will have to re-download the models every time you run the script.\n",\
|
136 |
+
"You can use USE_SYMLINKS: False in config.yaml to avoid this behavior.")
|
137 |
+
except:
|
138 |
+
print("Maybe not able to create symlink. Disable using symlinks.")
|
139 |
+
use_symlinks = False
|
140 |
+
cache_dir_d = os.path.join(cache_dir, "ckpts", pretrained_model_or_path)
|
141 |
+
finally: # always remove test link files
|
142 |
+
with suppress(FileNotFoundError):
|
143 |
+
os.remove(os.path.join(ckpts_dir, f"linktest_{filename}.txt"))
|
144 |
+
os.remove(os.path.join(cache_dir_d, f"linktest_{filename}.txt"))
|
145 |
+
else:
|
146 |
+
cache_dir_d = os.path.join(cache_dir, "ckpts", pretrained_model_or_path)
|
147 |
+
|
148 |
+
model_path = hf_hub_download(repo_id=pretrained_model_or_path,
|
149 |
+
cache_dir=cache_dir_d,
|
150 |
+
local_dir=local_dir,
|
151 |
+
subfolder=subfolder,
|
152 |
+
filename=filename,
|
153 |
+
local_dir_use_symlinks=use_symlinks,
|
154 |
+
resume_download=True,
|
155 |
+
etag_timeout=100,
|
156 |
+
repo_type=repo_type
|
157 |
+
)
|
158 |
+
if not use_symlinks:
|
159 |
+
try:
|
160 |
+
import shutil
|
161 |
+
shutil.rmtree(os.path.join(cache_dir, "ckpts"))
|
162 |
+
except Exception as e :
|
163 |
+
print(e)
|
164 |
+
|
165 |
+
print(f"model_path is {model_path}")
|
166 |
+
|
167 |
+
return model_path
|
168 |
+
|
169 |
+
|
170 |
+
def HWC3(x):
|
171 |
+
assert x.dtype == np.uint8
|
172 |
+
if x.ndim == 2:
|
173 |
+
x = x[:, :, None]
|
174 |
+
assert x.ndim == 3
|
175 |
+
H, W, C = x.shape
|
176 |
+
assert C == 1 or C == 3 or C == 4
|
177 |
+
if C == 3:
|
178 |
+
return x
|
179 |
+
if C == 1:
|
180 |
+
return np.concatenate([x, x, x], axis=2)
|
181 |
+
if C == 4:
|
182 |
+
color = x[:, :, 0:3].astype(np.float32)
|
183 |
+
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
|
184 |
+
y = color * alpha + 255.0 * (1.0 - alpha)
|
185 |
+
y = y.clip(0, 255).astype(np.uint8)
|
186 |
+
return y
|
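resize_image_with_pad scales the short side up to resolution, pads the result out to multiples of 64 (what the detector backbones expect), and hands back a remove_pad closure that crops the output back down. The size arithmetic worked standalone for a 640x400 input at resolution 512 (pure calculation, no image data):

import numpy as np

def pad64(x):
    return int(np.ceil(float(x) / 64.0) * 64 - x)

H_raw, W_raw, resolution = 400, 640, 512
k = float(resolution) / float(min(H_raw, W_raw))                  # 1.28: short side scaled to 512
H_target = int(np.round(H_raw * k))                               # 512
W_target = int(np.round(W_raw * k))                               # 819
print(H_target + pad64(H_target), W_target + pad64(W_target))     # 512 832 padded canvas, cropped back later by remove_pad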
custom_nodes/ComfyUI-tbox/src/densepose/__init__.py
ADDED
@@ -0,0 +1,67 @@
import torchvision # Fix issue Unknown builtin op: torchvision::nms
import cv2
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from PIL import Image

from common import HWC3, resize_image_with_pad, common_input_validate, custom_hf_download
from .densepose import DensePoseMaskedColormapResultsVisualizer, _extract_i_from_iuvarr, densepose_chart_predictor_output_to_result_with_confidences

N_PART_LABELS = 24
DENSEPOSE_MODEL_NAME = "LayerNorm/DensePose-TorchScript-with-hint-image"

class DenseposeDetector:
    def __init__(self, model):
        self.dense_pose_estimation = model
        self.device = "cpu"
        self.result_visualizer = DensePoseMaskedColormapResultsVisualizer(
            alpha=1,
            data_extractor=_extract_i_from_iuvarr,
            segm_extractor=_extract_i_from_iuvarr,
            val_scale = 255.0 / N_PART_LABELS
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path=DENSEPOSE_MODEL_NAME, filename="densepose_r50_fpn_dl.torchscript"):
        torchscript_model_path = custom_hf_download(pretrained_model_or_path, filename)
        densepose = torch.jit.load(torchscript_model_path, map_location="cpu")
        return cls(densepose)

    def to(self, device):
        self.dense_pose_estimation.to(device)
        self.device = device
        return self

    def __call__(self, input_image, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", cmap="viridis", **kwargs):
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
        H, W = input_image.shape[:2]

        hint_image_canvas = np.zeros([H, W], dtype=np.uint8)
        hint_image_canvas = np.tile(hint_image_canvas[:, :, np.newaxis], [1, 1, 3])

        input_image = rearrange(torch.from_numpy(input_image).to(self.device), 'h w c -> c h w')

        pred_boxes, corase_segm, fine_segm, u, v = self.dense_pose_estimation(input_image)

        extractor = densepose_chart_predictor_output_to_result_with_confidences
        densepose_results = [extractor(pred_boxes[i:i+1], corase_segm[i:i+1], fine_segm[i:i+1], u[i:i+1], v[i:i+1]) for i in range(len(pred_boxes))]

        if cmap=="viridis":
            self.result_visualizer.mask_visualizer.cmap = cv2.COLORMAP_VIRIDIS
            hint_image = self.result_visualizer.visualize(hint_image_canvas, densepose_results)
            hint_image = cv2.cvtColor(hint_image, cv2.COLOR_BGR2RGB)
            hint_image[:, :, 0][hint_image[:, :, 0] == 0] = 68
            hint_image[:, :, 1][hint_image[:, :, 1] == 0] = 1
            hint_image[:, :, 2][hint_image[:, :, 2] == 0] = 84
        else:
            self.result_visualizer.mask_visualizer.cmap = cv2.COLORMAP_PARULA
            hint_image = self.result_visualizer.visualize(hint_image_canvas, densepose_results)
            hint_image = cv2.cvtColor(hint_image, cv2.COLOR_BGR2RGB)

        detected_map = remove_pad(HWC3(hint_image))
        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)
        return detected_map
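A usage sketch for the detector above: from_pretrained() fetches the TorchScript checkpoint through custom_hf_download on first use (so it needs network access and disk space for the weights), and the callable accepts any RGB NumPy image. The import path and the placeholder frame are assumptions for the example:

import numpy as np
# from densepose import DenseposeDetector   # resolves when src/ is on sys.path

detector = DenseposeDetector.from_pretrained()     # downloads the TorchScript model on first run
detector = detector.to("cpu")                      # or "cuda" if available

frame = np.zeros((512, 512, 3), dtype=np.uint8)    # placeholder RGB frame
pose_map = detector(frame, detect_resolution=512, output_type="np", cmap="viridis")
print(pose_map.shape)                              # H x W x 3 DensePose hint image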
custom_nodes/ComfyUI-tbox/src/densepose/densepose.py
ADDED
@@ -0,0 +1,347 @@
from typing import Tuple
import math
import numpy as np
from enum import IntEnum
from typing import List, Tuple, Union
import torch
from torch.nn import functional as F
import logging
import cv2

Image = np.ndarray
Boxes = torch.Tensor
ImageSizeType = Tuple[int, int]
_RawBoxType = Union[List[float], Tuple[float, ...], torch.Tensor, np.ndarray]
IntTupleBox = Tuple[int, int, int, int]

class BoxMode(IntEnum):
    """
    Enum of different ways to represent a box.
    """

    XYXY_ABS = 0
    """
    (x0, y0, x1, y1) in absolute floating points coordinates.
    The coordinates in range [0, width or height].
    """
    XYWH_ABS = 1
    """
    (x0, y0, w, h) in absolute floating points coordinates.
    """
    XYXY_REL = 2
    """
    Not yet supported!
    (x0, y0, x1, y1) in range [0, 1]. They are relative to the size of the image.
    """
    XYWH_REL = 3
    """
    Not yet supported!
    (x0, y0, w, h) in range [0, 1]. They are relative to the size of the image.
    """
    XYWHA_ABS = 4
    """
    (xc, yc, w, h, a) in absolute floating points coordinates.
    (xc, yc) is the center of the rotated box, and the angle a is in degrees ccw.
    """

    @staticmethod
    def convert(box: _RawBoxType, from_mode: "BoxMode", to_mode: "BoxMode") -> _RawBoxType:
        """
        Args:
            box: can be a k-tuple, k-list or an Nxk array/tensor, where k = 4 or 5
            from_mode, to_mode (BoxMode)

        Returns:
            The converted box of the same type.
        """
        if from_mode == to_mode:
            return box

        original_type = type(box)
        is_numpy = isinstance(box, np.ndarray)
        single_box = isinstance(box, (list, tuple))
        if single_box:
            assert len(box) == 4 or len(box) == 5, (
                "BoxMode.convert takes either a k-tuple/list or an Nxk array/tensor,"
                " where k == 4 or 5"
            )
            arr = torch.tensor(box)[None, :]
        else:
            # avoid modifying the input box
            if is_numpy:
                arr = torch.from_numpy(np.asarray(box)).clone()
            else:
                arr = box.clone()

        assert to_mode not in [BoxMode.XYXY_REL, BoxMode.XYWH_REL] and from_mode not in [
            BoxMode.XYXY_REL,
            BoxMode.XYWH_REL,
        ], "Relative mode not yet supported!"

        if from_mode == BoxMode.XYWHA_ABS and to_mode == BoxMode.XYXY_ABS:
            assert (
                arr.shape[-1] == 5
            ), "The last dimension of input shape must be 5 for XYWHA format"
            original_dtype = arr.dtype
            arr = arr.double()

            w = arr[:, 2]
            h = arr[:, 3]
            a = arr[:, 4]
            c = torch.abs(torch.cos(a * math.pi / 180.0))
            s = torch.abs(torch.sin(a * math.pi / 180.0))
            # This basically computes the horizontal bounding rectangle of the rotated box
            new_w = c * w + s * h
            new_h = c * h + s * w

            # convert center to top-left corner
            arr[:, 0] -= new_w / 2.0
            arr[:, 1] -= new_h / 2.0
            # bottom-right corner
            arr[:, 2] = arr[:, 0] + new_w
            arr[:, 3] = arr[:, 1] + new_h

            arr = arr[:, :4].to(dtype=original_dtype)
        elif from_mode == BoxMode.XYWH_ABS and to_mode == BoxMode.XYWHA_ABS:
            original_dtype = arr.dtype
            arr = arr.double()
            arr[:, 0] += arr[:, 2] / 2.0
            arr[:, 1] += arr[:, 3] / 2.0
            angles = torch.zeros((arr.shape[0], 1), dtype=arr.dtype)
            arr = torch.cat((arr, angles), axis=1).to(dtype=original_dtype)
        else:
            if to_mode == BoxMode.XYXY_ABS and from_mode == BoxMode.XYWH_ABS:
                arr[:, 2] += arr[:, 0]
                arr[:, 3] += arr[:, 1]
            elif from_mode == BoxMode.XYXY_ABS and to_mode == BoxMode.XYWH_ABS:
                arr[:, 2] -= arr[:, 0]
                arr[:, 3] -= arr[:, 1]
            else:
                raise NotImplementedError(
                    "Conversion from BoxMode {} to {} is not supported yet".format(
                        from_mode, to_mode
                    )
                )

        if single_box:
            return original_type(arr.flatten().tolist())
        if is_numpy:
            return arr.numpy()
        else:
            return arr

class MatrixVisualizer:
    """
    Base visualizer for matrix data
    """

    def __init__(
        self,
        inplace=True,
        cmap=cv2.COLORMAP_PARULA,
        val_scale=1.0,
        alpha=0.7,
        interp_method_matrix=cv2.INTER_LINEAR,
        interp_method_mask=cv2.INTER_NEAREST,
    ):
        self.inplace = inplace
        self.cmap = cmap
        self.val_scale = val_scale
        self.alpha = alpha
        self.interp_method_matrix = interp_method_matrix
        self.interp_method_mask = interp_method_mask

    def visualize(self, image_bgr, mask, matrix, bbox_xywh):
        self._check_image(image_bgr)
        self._check_mask_matrix(mask, matrix)
        if self.inplace:
            image_target_bgr = image_bgr
        else:
            image_target_bgr = image_bgr * 0
        x, y, w, h = [int(v) for v in bbox_xywh]
        if w <= 0 or h <= 0:
            return image_bgr
        mask, matrix = self._resize(mask, matrix, w, h)
        mask_bg = np.tile((mask == 0)[:, :, np.newaxis], [1, 1, 3])
        matrix_scaled = matrix.astype(np.float32) * self.val_scale
        _EPSILON = 1e-6
        if np.any(matrix_scaled > 255 + _EPSILON):
            logger = logging.getLogger(__name__)
            logger.warning(
                f"Matrix has values > {255 + _EPSILON} after " f"scaling, clipping to [0..255]"
            )
        matrix_scaled_8u = matrix_scaled.clip(0, 255).astype(np.uint8)
        matrix_vis = cv2.applyColorMap(matrix_scaled_8u, self.cmap)
        matrix_vis[mask_bg] = image_target_bgr[y : y + h, x : x + w, :][mask_bg]
        image_target_bgr[y : y + h, x : x + w, :] = (
            image_target_bgr[y : y + h, x : x + w, :] * (1.0 - self.alpha) + matrix_vis * self.alpha
        )
        return image_target_bgr.astype(np.uint8)

    def _resize(self, mask, matrix, w, h):
        if (w != mask.shape[1]) or (h != mask.shape[0]):
            mask = cv2.resize(mask, (w, h), self.interp_method_mask)
        if (w != matrix.shape[1]) or (h != matrix.shape[0]):
            matrix = cv2.resize(matrix, (w, h), self.interp_method_matrix)
        return mask, matrix

    def _check_image(self, image_rgb):
        assert len(image_rgb.shape) == 3
        assert image_rgb.shape[2] == 3
        assert image_rgb.dtype == np.uint8

    def _check_mask_matrix(self, mask, matrix):
        assert len(matrix.shape) == 2
        assert len(mask.shape) == 2
        assert mask.dtype == np.uint8

class DensePoseResultsVisualizer:
    def visualize(
        self,
        image_bgr: Image,
        results,
    ) -> Image:
        context = self.create_visualization_context(image_bgr)
        for i, result in enumerate(results):
            boxes_xywh, labels, uv = result
            iuv_array = torch.cat(
                (labels[None].type(torch.float32), uv * 255.0)
            ).type(torch.uint8)
            self.visualize_iuv_arr(context, iuv_array.cpu().numpy(), boxes_xywh)
        image_bgr = self.context_to_image_bgr(context)
        return image_bgr

    def create_visualization_context(self, image_bgr: Image):
        return image_bgr

    def visualize_iuv_arr(self, context, iuv_arr: np.ndarray, bbox_xywh) -> None:
        pass

    def context_to_image_bgr(self, context):
        return context

    def get_image_bgr_from_context(self, context):
        return context

class DensePoseMaskedColormapResultsVisualizer(DensePoseResultsVisualizer):
    def __init__(
        self,
        data_extractor,
        segm_extractor,
        inplace=True,
        cmap=cv2.COLORMAP_PARULA,
        alpha=0.7,
        val_scale=1.0,
        **kwargs,
    ):
        self.mask_visualizer = MatrixVisualizer(
            inplace=inplace, cmap=cmap, val_scale=val_scale, alpha=alpha
        )
        self.data_extractor = data_extractor
        self.segm_extractor = segm_extractor

    def context_to_image_bgr(self, context):
        return context

    def visualize_iuv_arr(self, context, iuv_arr: np.ndarray, bbox_xywh) -> None:
        image_bgr = self.get_image_bgr_from_context(context)
        matrix = self.data_extractor(iuv_arr)
        segm = self.segm_extractor(iuv_arr)
        mask = np.zeros(matrix.shape, dtype=np.uint8)
        mask[segm > 0] = 1
        image_bgr = self.mask_visualizer.visualize(image_bgr, mask, matrix, bbox_xywh)


def _extract_i_from_iuvarr(iuv_arr):
    return iuv_arr[0, :, :]


def _extract_u_from_iuvarr(iuv_arr):
    return iuv_arr[1, :, :]


def _extract_v_from_iuvarr(iuv_arr):
    return iuv_arr[2, :, :]

def make_int_box(box: torch.Tensor) -> IntTupleBox:
    int_box = [0, 0, 0, 0]
    int_box[0], int_box[1], int_box[2], int_box[3] = tuple(box.long().tolist())
    return int_box[0], int_box[1], int_box[2], int_box[3]

def densepose_chart_predictor_output_to_result_with_confidences(
    boxes: Boxes,
    coarse_segm,
    fine_segm,
    u, v
):
    boxes_xyxy_abs = boxes.clone()
    boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
    box_xywh = make_int_box(boxes_xywh_abs[0])

    labels = resample_fine_and_coarse_segm_tensors_to_bbox(fine_segm, coarse_segm, box_xywh).squeeze(0)
    uv = resample_uv_tensors_to_bbox(u, v, labels, box_xywh)
    confidences = []
    return box_xywh, labels, uv

def resample_fine_and_coarse_segm_tensors_to_bbox(
    fine_segm: torch.Tensor, coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox
):
    """
    Resample fine and coarse segmentation tensors to the given
    bounding box and derive labels for each pixel of the bounding box

    Args:
        fine_segm: float tensor of shape [1, C, Hout, Wout]
        coarse_segm: float tensor of shape [1, K, Hout, Wout]
        box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
            corner coordinates, width (W) and height (H)
    Return:
        Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
    """
    x, y, w, h = box_xywh_abs
    w = max(int(w), 1)
    h = max(int(h), 1)
    # coarse segmentation
    coarse_segm_bbox = F.interpolate(
        coarse_segm,
        (h, w),
        mode="bilinear",
        align_corners=False,
    ).argmax(dim=1)
    # combined coarse and fine segmentation
    labels = (
        F.interpolate(fine_segm, (h, w), mode="bilinear", align_corners=False).argmax(dim=1)
        * (coarse_segm_bbox > 0).long()
    )
    return labels

def resample_uv_tensors_to_bbox(
    u: torch.Tensor,
    v: torch.Tensor,
    labels: torch.Tensor,
    box_xywh_abs: IntTupleBox,
) -> torch.Tensor:
    """
    Resamples U and V coordinate estimates for the given bounding box

    Args:
        u (tensor [1, C, H, W] of float): U coordinates
        v (tensor [1, C, H, W] of float): V coordinates
        labels (tensor [H, W] of long): labels obtained by resampling segmentation
            outputs for the given bounding box
        box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
    Return:
        Resampled U and V coordinates - a tensor [2, H, W] of float
    """
    x, y, w, h = box_xywh_abs
    w = max(int(w), 1)
    h = max(int(h), 1)
    u_bbox = F.interpolate(u, (h, w), mode="bilinear", align_corners=False)
    v_bbox = F.interpolate(v, (h, w), mode="bilinear", align_corners=False)
    uv = torch.zeros([2, h, w], dtype=torch.float32, device=u.device)
    for part_id in range(1, u_bbox.size(1)):
        uv[0][labels == part_id] = u_bbox[0, part_id][labels == part_id]
        uv[1][labels == part_id] = v_bbox[0, part_id][labels == part_id]
    return uv

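Note: BoxMode.convert above only needs numpy/torch, so it can be sanity-checked in isolation. The following is a minimal sketch, not part of the committed file; the import path is an assumption and depends on how ComfyUI loads this custom node.

import numpy as np
from densepose import BoxMode  # hypothetical import path; adjust to your setup

# One box in (x0, y0, x1, y1) absolute coordinates.
xyxy = np.array([[10.0, 20.0, 110.0, 220.0]])

# Convert to (x0, y0, w, h); the input container type (numpy array) is preserved.
xywh = BoxMode.convert(xyxy, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
print(xywh)  # expected: [[ 10.  20. 100. 200.]]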
custom_nodes/ComfyUI-tbox/src/dwpose/LICENSE
ADDED
@@ -0,0 +1,108 @@
OPENPOSE: MULTIPERSON KEYPOINT DETECTION
SOFTWARE LICENSE AGREEMENT
ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH USE ONLY

BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO THE TERMS OF THIS LICENSE AGREEMENT. IF YOU DO NOT AGREE WITH THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE.

This is a license agreement ("Agreement") between your academic institution or non-profit organization or self (called "Licensee" or "You" in this Agreement) and Carnegie Mellon University (called "Licensor" in this Agreement). All rights not specifically granted to you in this Agreement are reserved for Licensor.

RESERVATION OF OWNERSHIP AND GRANT OF LICENSE:
Licensor retains exclusive ownership of any copy of the Software (as defined below) licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive,
non-transferable license to use the Software for noncommercial research purposes, without the right to sublicense, pursuant to the terms and conditions of this Agreement. As used in this Agreement, the term "Software" means (i) the actual copy of all or any portion of code for program routines made accessible to Licensee by Licensor pursuant to this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder or subsequently supplied by Licensor, including all or any file structures, programming instructions, user interfaces and screen formats and sequences as well as any and all documentation and instructions related to it, and (ii) all or any derivatives and/or modifications created or made by You to any of the items specified in (i).

CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to Licensor, and as such, Licensee agrees to receive all such materials in confidence and use the Software only in accordance with the terms of this Agreement. Licensee agrees to use reasonable effort to protect the Software from unauthorized use, reproduction, distribution, or publication.

COPYRIGHT: The Software is owned by Licensor and is protected by United
States copyright laws and applicable international treaties and/or conventions.

PERMITTED USES: The Software may be used for your own noncommercial internal research purposes. You understand and agree that Licensor is not obligated to implement any suggestions and/or feedback you might provide regarding the Software, but to the extent Licensor does so, you are not entitled to any compensation related thereto.

DERIVATIVES: You may create derivatives of or make modifications to the Software, however, You agree that all and any such derivatives and modifications will be owned by Licensor and become a part of the Software licensed to You under this Agreement. You may only use such derivatives and modifications for your own noncommercial internal research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement.

BACKUPS: If Licensee is an organization, it may make that number of copies of the Software necessary for internal noncommercial use at a single site within its organization provided that all information appearing in or on the original labels, including the copyright and trademark notices are copied onto the labels of the copies.

USES NOT PERMITTED: You may not distribute, copy or use the Software except as explicitly permitted herein. Licensee has not been granted any trademark license as part of this Agreement and may not use the name or mark “OpenPose", "Carnegie Mellon" or any renditions thereof without the prior written permission of Licensor.

You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part, or provide third parties access to prior or present versions (or any parts thereof) of the Software.

ASSIGNMENT: You may not assign this Agreement or your rights hereunder without the prior written consent of Licensor. Any attempted assignment without such consent shall be null and void.

TERM: The term of the license granted by this Agreement is from Licensee's acceptance of this Agreement by downloading the Software or by using the Software until terminated as provided below.

The Agreement automatically terminates without notice if you fail to comply with any provision of this Agreement. Licensee may terminate this Agreement by ceasing using the Software. Upon any termination of this Agreement, Licensee will delete any and all copies of the Software. You agree that all provisions which operate to protect the proprietary rights of Licensor shall remain in force should breach occur and that the obligation of confidentiality described in this Agreement is binding in perpetuity and, as such, survives the term of the Agreement.

FEE: Provided Licensee abides completely by the terms and conditions of this Agreement, there is no fee due to Licensor for Licensee's use of the Software in accordance with this Agreement.

DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND RELATED MATERIALS.

SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is provided as part of this Agreement.

EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent permitted under applicable law, Licensor shall not be liable for direct, indirect, special, incidental, or consequential damages or lost profits related to Licensee's use of and/or inability to use the Software, even if Licensor is advised of the possibility of such damage.

EXPORT REGULATION: Licensee agrees to comply with any and all applicable
U.S. export control laws, regulations, and/or other laws related to embargoes and sanction programs administered by the Office of Foreign Assets Control.

SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid, illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the validity, legality and enforceability of the remaining provisions shall not in any way be affected or impaired thereby.

NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or remedy under this Agreement shall be construed as a waiver of any future or other exercise of such right or remedy by Licensor.

GOVERNING LAW: This Agreement shall be construed and enforced in accordance with the laws of the Commonwealth of Pennsylvania without reference to conflict of laws principles. You consent to the personal jurisdiction of the courts of this County and waive their rights to venue outside of Allegheny County, Pennsylvania.

ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole and entire agreement between Licensee and Licensor as to the matter set forth herein and supersedes any previous agreements, understandings, and arrangements between the parties relating hereto.



************************************************************************

THIRD-PARTY SOFTWARE NOTICES AND INFORMATION

This project incorporates material from the project(s) listed below (collectively, "Third Party Code"). This Third Party Code is licensed to you under their original license terms set forth below. We reserves all other rights not expressly granted, whether by implication, estoppel or otherwise.

1. Caffe, version 1.0.0, (https://github.com/BVLC/caffe/)

COPYRIGHT

All contributions by the University of California:
Copyright (c) 2014-2017 The Regents of the University of California (Regents)
All rights reserved.

All other contributions:
Copyright (c) 2014-2017, the respective contributors
All rights reserved.

Caffe uses a shared copyright model: each contributor holds copyright over
their contributions to Caffe. The project versioning records all such
contribution and copyright details. If a contributor wants to further mark
their specific copyright on a particular contribution, they should indicate
their copyright solely in the commit message of the change when it is
committed.

LICENSE

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

CONTRIBUTION AGREEMENT

By contributing to the BVLC/caffe repository through pull-request, comment,
or otherwise, the contributor releases their content to the
license and copyright terms herein.

************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION**********
custom_nodes/ComfyUI-tbox/src/dwpose/__init__.py
ADDED
@@ -0,0 +1,328 @@
# Openpose
# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
# 2nd Edited by https://github.com/Hzzone/pytorch-openpose
# 3rd Edited by ControlNet
# 4th Edited by ControlNet (added face and correct hands)
# 5th Edited by ControlNet (Improved JSON serialization/deserialization, and lots of bug fixs)
# This preprocessor is licensed by CMU for non-commercial use only.

import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import json
import torch
import numpy as np
from . import util
from .body import Body, BodyResult, Keypoint
from .hand import Hand
from .face import Face
from .types import PoseResult, HandResult, FaceResult, AnimalPoseResult
#from huggingface_hub import hf_hub_download
from .wholebody import Wholebody
import warnings
from common import HWC3, resize_image_with_pad, common_input_validate, custom_hf_download
import cv2
from PIL import Image
from .animalpose import AnimalPoseImage

from typing import Tuple, List, Callable, Union, Optional


def draw_animalposes(animals: list[list[Keypoint]], H: int, W: int) -> np.ndarray:
    canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
    for animal_pose in animals:
        canvas = draw_animalpose(canvas, animal_pose)
    return canvas


def draw_animalpose(canvas: np.ndarray, keypoints: list[Keypoint]) -> np.ndarray:
    # order of the keypoints for AP10k and a standardized list of colors for limbs
    keypointPairsList = [
        (1, 2),
        (2, 3),
        (1, 3),
        (3, 4),
        (4, 9),
        (9, 10),
        (10, 11),
        (4, 6),
        (6, 7),
        (7, 8),
        (4, 5),
        (5, 15),
        (15, 16),
        (16, 17),
        (5, 12),
        (12, 13),
        (13, 14),
    ]
    colorsList = [
        (255, 255, 255),
        (100, 255, 100),
        (150, 255, 255),
        (100, 50, 255),
        (50, 150, 200),
        (0, 255, 255),
        (0, 150, 0),
        (0, 0, 255),
        (0, 0, 150),
        (255, 50, 255),
        (255, 0, 255),
        (255, 0, 0),
        (150, 0, 0),
        (255, 255, 100),
        (0, 150, 0),
        (255, 255, 0),
        (150, 150, 150),
    ]  # 16 colors needed

    for ind, (i, j) in enumerate(keypointPairsList):
        p1 = keypoints[i - 1]
        p2 = keypoints[j - 1]

        if p1 is not None and p2 is not None:
            cv2.line(
                canvas,
                (int(p1.x), int(p1.y)),
                (int(p2.x), int(p2.y)),
                colorsList[ind],
                5,
            )
    return canvas


def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, draw_face=True):
    """
    Draw the detected poses on an empty canvas.

    Args:
        poses (List[PoseResult]): A list of PoseResult objects containing the detected poses.
        H (int): The height of the canvas.
        W (int): The width of the canvas.
        draw_body (bool, optional): Whether to draw body keypoints. Defaults to True.
        draw_hand (bool, optional): Whether to draw hand keypoints. Defaults to True.
        draw_face (bool, optional): Whether to draw face keypoints. Defaults to True.

    Returns:
        numpy.ndarray: A 3D numpy array representing the canvas with the drawn poses.
    """
    canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)

    for pose in poses:
        if draw_body:
            canvas = util.draw_bodypose(canvas, pose.body.keypoints)

        if draw_hand:
            canvas = util.draw_handpose(canvas, pose.left_hand)
            canvas = util.draw_handpose(canvas, pose.right_hand)

        if draw_face:
            canvas = util.draw_facepose(canvas, pose.face)

    return canvas


def decode_json_as_poses(
    pose_json: dict,
) -> Tuple[List[PoseResult], List[AnimalPoseResult], int, int]:
    """Decode the json_string complying with the openpose JSON output format
    to poses that controlnet recognizes.
    https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/doc/02_output.md

    Args:
        json_string: The json string to decode.

    Returns:
        human_poses
        animal_poses
        canvas_height
        canvas_width
    """
    height = pose_json["canvas_height"]
    width = pose_json["canvas_width"]

    def chunks(lst, n):
        """Yield successive n-sized chunks from lst."""
        for i in range(0, len(lst), n):
            yield lst[i : i + n]

    def decompress_keypoints(
        numbers: Optional[List[float]],
    ) -> Optional[List[Optional[Keypoint]]]:
        if not numbers:
            return None

        assert len(numbers) % 3 == 0

        def create_keypoint(x, y, c):
            if c < 1.0:
                return None
            keypoint = Keypoint(x, y)
            return keypoint

        return [create_keypoint(x, y, c) for x, y, c in chunks(numbers, n=3)]

    return (
        [
            PoseResult(
                body=BodyResult(
                    keypoints=decompress_keypoints(pose.get("pose_keypoints_2d"))
                ),
                left_hand=decompress_keypoints(pose.get("hand_left_keypoints_2d")),
                right_hand=decompress_keypoints(pose.get("hand_right_keypoints_2d")),
                face=decompress_keypoints(pose.get("face_keypoints_2d")),
            )
            for pose in pose_json.get("people", [])
        ],
        [decompress_keypoints(pose) for pose in pose_json.get("animals", [])],
        height,
        width,
    )


def encode_poses_as_dict(poses: List[PoseResult], canvas_height: int, canvas_width: int) -> str:
    """ Encode the pose as a dict following openpose JSON output format:
    https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/doc/02_output.md
    """
    def compress_keypoints(keypoints: Union[List[Keypoint], None]) -> Union[List[float], None]:
        if not keypoints:
            return None

        return [
            value
            for keypoint in keypoints
            for value in (
                [float(keypoint.x), float(keypoint.y), 1.0]
                if keypoint is not None
                else [0.0, 0.0, 0.0]
            )
        ]

    return {
        'people': [
            {
                'pose_keypoints_2d': compress_keypoints(pose.body.keypoints),
                "face_keypoints_2d": compress_keypoints(pose.face),
                "hand_left_keypoints_2d": compress_keypoints(pose.left_hand),
                "hand_right_keypoints_2d": compress_keypoints(pose.right_hand),
            }
            for pose in poses
        ],
        'canvas_height': canvas_height,
        'canvas_width': canvas_width,
    }

global_cached_dwpose = Wholebody()

class DwposeDetector:
    """
    A class for detecting human poses in images using the Dwpose model.

    Attributes:
        model_dir (str): Path to the directory where the pose models are stored.
    """
    def __init__(self, dw_pose_estimation):
        self.dw_pose_estimation = dw_pose_estimation

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path, pretrained_det_model_or_path=None, det_filename=None, pose_filename=None, torchscript_device="cuda"):
        global global_cached_dwpose
        pretrained_det_model_or_path = pretrained_det_model_or_path or pretrained_model_or_path
        det_filename = det_filename or "yolox_l.onnx"
        pose_filename = pose_filename or "dw-ll_ucoco_384.onnx"
        det_model_path = custom_hf_download(pretrained_det_model_or_path, det_filename)
        pose_model_path = custom_hf_download(pretrained_model_or_path, pose_filename)

        print(f"\nDWPose: Using {det_filename} for bbox detection and {pose_filename} for pose estimation")
        if global_cached_dwpose.det is None or global_cached_dwpose.det_filename != det_filename:
            t = Wholebody(det_model_path, None, torchscript_device=torchscript_device)
            t.pose = global_cached_dwpose.pose
            t.pose_filename = global_cached_dwpose.pose
            global_cached_dwpose = t

        if global_cached_dwpose.pose is None or global_cached_dwpose.pose_filename != pose_filename:
            t = Wholebody(None, pose_model_path, torchscript_device=torchscript_device)
            t.det = global_cached_dwpose.det
            t.det_filename = global_cached_dwpose.det_filename
            global_cached_dwpose = t
        return cls(global_cached_dwpose)

    def detect_poses(self, oriImg) -> List[PoseResult]:
        with torch.no_grad():
            keypoints_info = self.dw_pose_estimation(oriImg.copy())
            return Wholebody.format_result(keypoints_info)

    def __call__(self, input_image, detect_resolution=512, include_body=True, include_hand=False, include_face=False, hand_and_face=None, output_type="pil", image_and_json=False, upscale_method="INTER_CUBIC", **kwargs):
        if hand_and_face is not None:
            warnings.warn("hand_and_face is deprecated. Use include_hand and include_face instead.", DeprecationWarning)
            include_hand = hand_and_face
            include_face = hand_and_face

        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        poses = self.detect_poses(input_image)

        canvas = draw_poses(poses, input_image.shape[0], input_image.shape[1], draw_body=include_body, draw_hand=include_hand, draw_face=include_face)
        canvas, remove_pad = resize_image_with_pad(canvas, detect_resolution, upscale_method)
        detected_map = HWC3(remove_pad(canvas))

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        if image_and_json:
            return (detected_map, encode_poses_as_dict(poses, input_image.shape[0], input_image.shape[1]))

        return detected_map

global_cached_animalpose = AnimalPoseImage()
class AnimalposeDetector:
    """
    A class for detecting animal poses in images using the RTMPose AP10k model.

    Attributes:
        model_dir (str): Path to the directory where the pose models are stored.
    """
    def __init__(self, animal_pose_estimation):
        self.animal_pose_estimation = animal_pose_estimation

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path, pretrained_det_model_or_path=None, det_filename="yolox_l.onnx", pose_filename="dw-ll_ucoco_384.onnx", torchscript_device="cuda"):
        global global_cached_animalpose
        det_model_path = custom_hf_download(pretrained_det_model_or_path, det_filename)
        pose_model_path = custom_hf_download(pretrained_model_or_path, pose_filename)

        print(f"\nAnimalPose: Using {det_filename} for bbox detection and {pose_filename} for pose estimation")
        if global_cached_animalpose.det is None or global_cached_animalpose.det_filename != det_filename:
            t = AnimalPoseImage(det_model_path, None, torchscript_device=torchscript_device)
            t.pose = global_cached_animalpose.pose
            t.pose_filename = global_cached_animalpose.pose
            global_cached_animalpose = t

        if global_cached_animalpose.pose is None or global_cached_animalpose.pose_filename != pose_filename:
            t = AnimalPoseImage(None, pose_model_path, torchscript_device=torchscript_device)
            t.det = global_cached_animalpose.det
            t.det_filename = global_cached_animalpose.det_filename
            global_cached_animalpose = t
        return cls(global_cached_animalpose)

    def __call__(self, input_image, detect_resolution=512, output_type="pil", image_and_json=False, upscale_method="INTER_CUBIC", **kwargs):
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
        result = self.animal_pose_estimation(input_image)
        if result is None:
            detected_map = np.zeros_like(input_image)
            openpose_dict = {
                'version': 'ap10k',
                'animals': [],
                'canvas_height': input_image.shape[0],
                'canvas_width': input_image.shape[1]
            }
        else:
            detected_map, openpose_dict = result
        detected_map = remove_pad(detected_map)
        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        if image_and_json:
            return (detected_map, openpose_dict)

        return detected_map
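Note: a minimal sketch of how DwposeDetector above is typically driven. This is not part of the committed file; the Hugging Face repo id and the input image path are assumptions, and the filenames simply mirror the defaults in from_pretrained.

import numpy as np
from PIL import Image

# DwposeDetector is defined in the __init__.py shown above.
detector = DwposeDetector.from_pretrained(
    "yzd-v/DWPose",                        # assumed model repo id
    det_filename="yolox_l.onnx",           # default bbox detector
    pose_filename="dw-ll_ucoco_384.onnx",  # default pose estimator
)

img = np.array(Image.open("person.jpg").convert("RGB"))  # hypothetical input image
pose_map, pose_dict = detector(
    img,
    include_hand=True,
    include_face=True,
    image_and_json=True,   # also return the OpenPose-style dict
    output_type="pil",
)
pose_map.save("pose.png")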
custom_nodes/ComfyUI-tbox/src/dwpose/animalpose.py
ADDED
@@ -0,0 +1,273 @@
import numpy as np
import cv2
import os
import cv2
from .dw_onnx.cv_ox_det import inference_detector as inference_onnx_yolox
from .dw_onnx.cv_ox_yolo_nas import inference_detector as inference_onnx_yolo_nas
from .dw_onnx.cv_ox_pose import inference_pose as inference_onnx_pose

from .dw_torchscript.jit_det import inference_detector as inference_jit_yolox
from .dw_torchscript.jit_pose import inference_pose as inference_jit_pose
from typing import List, Optional
from .types import PoseResult, BodyResult, Keypoint
from timeit import default_timer
from .util import guess_onnx_input_shape_dtype, get_ort_providers, get_model_type, is_model_torchscript
import json
import torch

def drawBetweenKeypoints(pose_img, keypoints, indexes, color, scaleFactor):
    ind0 = indexes[0] - 1
    ind1 = indexes[1] - 1

    point1 = (keypoints[ind0][0], keypoints[ind0][1])
    point2 = (keypoints[ind1][0], keypoints[ind1][1])

    thickness = int(5 // scaleFactor)

    cv2.line(pose_img, (int(point1[0]), int(point1[1])), (int(point2[0]), int(point2[1])), color, thickness)


def drawBetweenKeypointsList(pose_img, keypoints, keypointPairsList, colorsList, scaleFactor):
    for ind, keypointPair in enumerate(keypointPairsList):
        drawBetweenKeypoints(pose_img, keypoints, keypointPair, colorsList[ind], scaleFactor)

def drawBetweenSetofKeypointLists(pose_img, keypoints_set, keypointPairsList, colorsList, scaleFactor):
    for keypoints in keypoints_set:
        drawBetweenKeypointsList(pose_img, keypoints, keypointPairsList, colorsList, scaleFactor)


def padImg(img, size, blackBorder=True):
    left, right, top, bottom = 0, 0, 0, 0

    # pad x
    if img.shape[1] < size[1]:
        sidePadding = int((size[1] - img.shape[1]) // 2)
        left = sidePadding
        right = sidePadding

        # pad extra on right if padding needed is an odd number
        if img.shape[1] % 2 == 1:
            right += 1

    # pad y
    if img.shape[0] < size[0]:
        topBottomPadding = int((size[0] - img.shape[0]) // 2)
        top = topBottomPadding
        bottom = topBottomPadding

        # pad extra on bottom if padding needed is an odd number
        if img.shape[0] % 2 == 1:
            bottom += 1

    if blackBorder:
        paddedImg = cv2.copyMakeBorder(src=img, top=top, bottom=bottom, left=left, right=right, borderType=cv2.BORDER_CONSTANT, value=(0,0,0))
    else:
        paddedImg = cv2.copyMakeBorder(src=img, top=top, bottom=bottom, left=left, right=right, borderType=cv2.BORDER_REPLICATE)

    return paddedImg

def smartCrop(img, size, center):

    width = img.shape[1]
    height = img.shape[0]
    xSize = size[1]
    ySize = size[0]
    xCenter = center[0]
    yCenter = center[1]

    if img.shape[0] > size[0] or img.shape[1] > size[1]:

        leftMargin = xCenter - xSize//2
        rightMargin = xCenter + xSize//2
        upMargin = yCenter - ySize//2
        downMargin = yCenter + ySize//2

        if(leftMargin < 0):
            xCenter += (-leftMargin)
        if(rightMargin > width):
            xCenter -= (rightMargin - width)

        if(upMargin < 0):
            yCenter -= -upMargin
        if(downMargin > height):
            yCenter -= (downMargin - height)

        img = cv2.getRectSubPix(img, size, (xCenter, yCenter))

    return img


def calculateScaleFactor(img, size, poseSpanX, poseSpanY):

    poseSpanX = max(poseSpanX, size[0])

    scaleFactorX = 1

    if poseSpanX > size[0]:
        scaleFactorX = size[0] / poseSpanX

    scaleFactorY = 1
    if poseSpanY > size[1]:
        scaleFactorY = size[1] / poseSpanY

    scaleFactor = min(scaleFactorX, scaleFactorY)

    return scaleFactor


def scaleImg(img, size, poseSpanX, poseSpanY, scaleFactor):
    scaledImg = img

    scaledImg = cv2.resize(img, (0, 0), fx=scaleFactor, fy=scaleFactor)

    return scaledImg, scaleFactor

class AnimalPoseImage:
    def __init__(self, det_model_path: Optional[str] = None, pose_model_path: Optional[str] = None, torchscript_device="cuda"):
        self.det_filename = det_model_path and os.path.basename(det_model_path)
        self.pose_filename = pose_model_path and os.path.basename(pose_model_path)
        self.det, self.pose = None, None
        # return type: None ort cv2 torchscript
        self.det_model_type = get_model_type("AnimalPose",self.det_filename)
        self.pose_model_type = get_model_type("AnimalPose",self.pose_filename)
        # Always loads to CPU to avoid building OpenCV.
        cv2_device = 'cpu'
        cv2_backend = cv2.dnn.DNN_BACKEND_OPENCV if cv2_device == 'cpu' else cv2.dnn.DNN_BACKEND_CUDA
        # You need to manually build OpenCV through cmake to work with your GPU.
        cv2_providers = cv2.dnn.DNN_TARGET_CPU if cv2_device == 'cpu' else cv2.dnn.DNN_TARGET_CUDA
        ort_providers = get_ort_providers()

        if self.det_model_type is None:
            pass
        elif self.det_model_type == "ort":
            try:
                import onnxruntime as ort
                self.det = ort.InferenceSession(det_model_path, providers=ort_providers)
            except:
                print(f"Failed to load onnxruntime with {self.det.get_providers()}.\nPlease change EP_list in the config.yaml and restart ComfyUI")
                self.det = ort.InferenceSession(det_model_path, providers=["CPUExecutionProvider"])
        elif self.det_model_type == "cv2":
            try:
                self.det = cv2.dnn.readNetFromONNX(det_model_path)
                self.det.setPreferableBackend(cv2_backend)
                self.det.setPreferableTarget(cv2_providers)
            except:
                print("TopK operators may not work on your OpenCV, try use onnxruntime with CPUExecutionProvider")
                try:
                    import onnxruntime as ort
                    self.det = ort.InferenceSession(det_model_path, providers=["CPUExecutionProvider"])
                except:
                    print(f"Failed to load {det_model_path}, you can use other models instead")
        else:
            self.det = torch.jit.load(det_model_path)
            self.det.to(torchscript_device)

        if self.pose_model_type is None:
            pass
        elif self.pose_model_type == "ort":
            try:
                import onnxruntime as ort
                self.pose = ort.InferenceSession(pose_model_path, providers=ort_providers)
            except:
                print(f"Failed to load onnxruntime with {self.pose.get_providers()}.\nPlease change EP_list in the config.yaml and restart ComfyUI")
                self.pose = ort.InferenceSession(pose_model_path, providers=["CPUExecutionProvider"])
        elif self.pose_model_type == "cv2":
            self.pose = cv2.dnn.readNetFromONNX(pose_model_path)
            self.pose.setPreferableBackend(cv2_backend)
            self.pose.setPreferableTarget(cv2_providers)
        else:
            self.pose = torch.jit.load(pose_model_path)
            self.pose.to(torchscript_device)

        if self.pose_filename is not None:
            self.pose_input_size, _ = guess_onnx_input_shape_dtype(self.pose_filename)

    def __call__(self, oriImg):
        import torch.utils.benchmark.utils.timer as torch_timer
        detect_classes = list(range(14, 23 + 1)) #https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/coco.yaml

        if is_model_torchscript(self.det):
            det_start = torch_timer.timer()
            det_result = inference_jit_yolox(self.det, oriImg, detect_classes=detect_classes)
            print(f"AnimalPose: Bbox {((torch_timer.timer() - det_start) * 1000):.2f}ms")
        else:
            det_start = default_timer()
            det_onnx_dtype = np.float32 if "yolox" in self.det_filename else np.uint8
            if "yolox" in self.det_filename:
                det_result = inference_onnx_yolox(self.det, oriImg, detect_classes=detect_classes, dtype=det_onnx_dtype)
            else:
                #FP16 and INT8 YOLO NAS accept uint8 input
                det_result = inference_onnx_yolo_nas(self.det, oriImg, detect_classes=detect_classes, dtype=det_onnx_dtype)
            print(f"AnimalPose: Bbox {((default_timer() - det_start) * 1000):.2f}ms")
        if (det_result is None) or (det_result.shape[0] == 0):
            openpose_dict = {
                'version': 'ap10k',
                'animals': [],
                'canvas_height': oriImg.shape[0],
                'canvas_width': oriImg.shape[1]
            }
            return np.zeros_like(oriImg), openpose_dict

        if is_model_torchscript(self.pose):
            pose_start = torch_timer.timer()
            keypoint_sets, scores = inference_jit_pose(self.pose, det_result, oriImg, self.pose_input_size)
            print(f"AnimalPose: Pose {((torch_timer.timer() - pose_start) * 1000):.2f}ms on {det_result.shape[0]} animals\n")
        else:
            pose_start = default_timer()
            _, pose_onnx_dtype = guess_onnx_input_shape_dtype(self.pose_filename)
            keypoint_sets, scores = inference_onnx_pose(self.pose, det_result, oriImg, self.pose_input_size, dtype=pose_onnx_dtype)
            print(f"AnimalPose: Pose {((default_timer() - pose_start) * 1000):.2f}ms on {det_result.shape[0]} animals\n")

        animal_kps_scores = []
        pose_img = np.zeros((oriImg.shape[0], oriImg.shape[1], 3), dtype = np.uint8)
        for (idx, keypoints) in enumerate(keypoint_sets):
            # don't use keypoints that go outside the frame in calculations for the center
            interorKeypoints = keypoints[((keypoints[:,0] > 0) & (keypoints[:,0] < oriImg.shape[1])) & ((keypoints[:,1] > 0) & (keypoints[:,1] < oriImg.shape[0]))]

            xVals = interorKeypoints[:,0]
            yVals = interorKeypoints[:,1]

            minX = np.amin(xVals)
            minY = np.amin(yVals)
            maxX = np.amax(xVals)
            maxY = np.amax(yVals)

            poseSpanX = maxX - minX
            poseSpanY = maxY - minY

            # find mean center

            xSum = np.sum(xVals)
            ySum = np.sum(yVals)

            xCenter = xSum // xVals.shape[0]
            yCenter = ySum // yVals.shape[0]
            center_of_keypoints = (xCenter,yCenter)

            # order of the keypoints for AP10k and a standardized list of colors for limbs
            keypointPairsList = [(1,2), (2,3), (1,3), (3,4), (4,9), (9,10), (10,11), (4,6), (6,7), (7,8), (4,5), (5,15), (15,16), (16,17), (5,12), (12,13), (13,14)]
            colorsList = [(255,255,255), (100,255,100), (150,255,255), (100,50,255), (50,150,200), (0,255,255), (0,150,0), (0,0,255), (0,0,150), (255,50,255), (255,0,255), (255,0,0), (150,0,0), (255,255,100), (0,150,0), (255,255,0), (150,150,150)] # 16 colors needed

            drawBetweenKeypointsList(pose_img, keypoints, keypointPairsList, colorsList, scaleFactor=1.0)
            score = scores[idx, ..., None]
            score[score > 1.0] = 1.0
            score[score < 0.0] = 0.0
            animal_kps_scores.append(np.concatenate((keypoints, score), axis=-1))

        openpose_dict = {
            'version': 'ap10k',
            'animals': [keypoints.tolist() for keypoints in animal_kps_scores],
            'canvas_height': oriImg.shape[0],
            'canvas_width': oriImg.shape[1]
        }
        return pose_img, openpose_dict
custom_nodes/ComfyUI-tbox/src/dwpose/body.py
ADDED
@@ -0,0 +1,261 @@
import cv2
import numpy as np
import math
import time
from scipy.ndimage.filters import gaussian_filter
import matplotlib.pyplot as plt
import matplotlib
import torch
from torchvision import transforms
from typing import NamedTuple, List, Union

from . import util
from .model import bodypose_model
from .types import Keypoint, BodyResult

class Body(object):
    def __init__(self, model_path):
        self.model = bodypose_model()
        # if torch.cuda.is_available():
        #     self.model = self.model.cuda()
        #     print('cuda')
        model_dict = util.transfer(self.model, torch.load(model_path))
        self.model.load_state_dict(model_dict)
        self.model.eval()

    def __call__(self, oriImg):
        # scale_search = [0.5, 1.0, 1.5, 2.0]
        scale_search = [0.5]
        boxsize = 368
        stride = 8
        padValue = 128
        thre1 = 0.1
        thre2 = 0.05
        multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
        heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
        paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))

        for m in range(len(multiplier)):
            scale = multiplier[m]
            imageToTest = util.smart_resize_k(oriImg, fx=scale, fy=scale)
            imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
            im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
            im = np.ascontiguousarray(im)

            data = torch.from_numpy(im).float()
            if torch.cuda.is_available():
                data = data.cuda()
            # data = data.permute([2, 0, 1]).unsqueeze(0).float()
            with torch.no_grad():
                data = data.to(self.cn_device)
                Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
                Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
                Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()

            # extract outputs, resize, and remove padding
            # heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0))  # output 1 is heatmaps
            heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0))  # output 1 is heatmaps
            heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride)
            heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
            heatmap = util.smart_resize(heatmap, (oriImg.shape[0], oriImg.shape[1]))

            # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0))  # output 0 is PAFs
            paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0))  # output 0 is PAFs
            paf = util.smart_resize_k(paf, fx=stride, fy=stride)
            paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
            paf = util.smart_resize(paf, (oriImg.shape[0], oriImg.shape[1]))

            heatmap_avg += heatmap_avg + heatmap / len(multiplier)
            paf_avg += + paf / len(multiplier)

        all_peaks = []
        peak_counter = 0

        for part in range(18):
            map_ori = heatmap_avg[:, :, part]
            one_heatmap = gaussian_filter(map_ori, sigma=3)

            map_left = np.zeros(one_heatmap.shape)
            map_left[1:, :] = one_heatmap[:-1, :]
            map_right = np.zeros(one_heatmap.shape)
            map_right[:-1, :] = one_heatmap[1:, :]
            map_up = np.zeros(one_heatmap.shape)
            map_up[:, 1:] = one_heatmap[:, :-1]
            map_down = np.zeros(one_heatmap.shape)
            map_down[:, :-1] = one_heatmap[:, 1:]

            peaks_binary = np.logical_and.reduce(
                (one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1))
            peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0]))  # note reverse
            peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
            peak_id = range(peak_counter, peak_counter + len(peaks))
            peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))]

            all_peaks.append(peaks_with_score_and_id)
            peak_counter += len(peaks)

        # find connection in the specified sequence, center 29 is in the position 15
        limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
                   [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
                   [1, 16], [16, 18], [3, 17], [6, 18]]
        # the middle joints heatmap correpondence
        mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \
                  [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \
                  [55, 56], [37, 38], [45, 46]]

        connection_all = []
        special_k = []
        mid_num = 10

        for k in range(len(mapIdx)):
            score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
            candA = all_peaks[limbSeq[k][0] - 1]
            candB = all_peaks[limbSeq[k][1] - 1]
            nA = len(candA)
            nB = len(candB)
            indexA, indexB = limbSeq[k]
            if (nA != 0 and nB != 0):
                connection_candidate = []
                for i in range(nA):
                    for j in range(nB):
                        vec = np.subtract(candB[j][:2], candA[i][:2])
                        norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
                        norm = max(0.001, norm)
                        vec = np.divide(vec, norm)

                        startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \
                                            np.linspace(candA[i][1], candB[j][1], num=mid_num)))

                        vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] \
                                          for I in range(len(startend))])
                        vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] \
                                          for I in range(len(startend))])

                        score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
                        score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min(
                            0.5 * oriImg.shape[0] / norm - 1, 0)
                        criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts)
                        criterion2 = score_with_dist_prior > 0
                        if criterion1 and criterion2:
                            connection_candidate.append(
                                [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]])

                connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
                connection = np.zeros((0, 5))
                for c in range(len(connection_candidate)):
                    i, j, s = connection_candidate[c][0:3]
                    if (i not in connection[:, 3] and j not in connection[:, 4]):
                        connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
                        if (len(connection) >= min(nA, nB)):
                            break

                connection_all.append(connection)
            else:
                special_k.append(k)
                connection_all.append([])

        # last number in each row is the total parts number of that person
        # the second last number in each row is the score of the overall configuration
        subset = -1 * np.ones((0, 20))
        candidate = np.array([item for sublist in all_peaks for item in sublist])

        for k in range(len(mapIdx)):
            if k not in special_k:
                partAs = connection_all[k][:, 0]
                partBs = connection_all[k][:, 1]
                indexA, indexB = np.array(limbSeq[k]) - 1

                for i in range(len(connection_all[k])):  # = 1:size(temp,1)
                    found = 0
                    subset_idx = [-1, -1]
                    for j in range(len(subset)):  # 1:size(subset,1):
                        if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
                            subset_idx[found] = j
                            found += 1

                    if found == 1:
                        j = subset_idx[0]
                        if subset[j][indexB] != partBs[i]:
                            subset[j][indexB] = partBs[i]
                            subset[j][-1] += 1
                            subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
                    elif found == 2:  # if found 2 and disjoint, merge them
                        j1, j2 = subset_idx
                        membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
                        if len(np.nonzero(membership == 2)[0]) == 0:  # merge
                            subset[j1][:-2] += (subset[j2][:-2] + 1)
                            subset[j1][-2:] += subset[j2][-2:]
                            subset[j1][-2] += connection_all[k][i][2]
                            subset = np.delete(subset, j2, 0)
                        else:  # as like found == 1
                            subset[j1][indexB] = partBs[i]
                            subset[j1][-1] += 1
                            subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]

                    # if find no partA in the subset, create a new subset
                    elif not found and k < 17:
                        row = -1 * np.ones(20)
                        row[indexA] = partAs[i]
                        row[indexB] = partBs[i]
                        row[-1] = 2
                        row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2]
                        subset = np.vstack([subset, row])
        # delete some rows of subset which has few parts occur
        deleteIdx = []
        for i in range(len(subset)):
            if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
                deleteIdx.append(i)
        subset = np.delete(subset, deleteIdx, axis=0)

        # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
        # candidate: x, y, score, id
        return candidate, subset

    @staticmethod
    def format_body_result(candidate: np.ndarray, subset: np.ndarray) -> List[BodyResult]:
        """
        Format the body results from the candidate and subset arrays into a list of BodyResult objects.

        Args:
            candidate (np.ndarray): An array of candidates containing the x, y coordinates, score, and id
|
221 |
+
for each body part.
|
222 |
+
subset (np.ndarray): An array of subsets containing indices to the candidate array for each
|
223 |
+
person detected. The last two columns of each row hold the total score and total parts
|
224 |
+
of the person.
|
225 |
+
|
226 |
+
Returns:
|
227 |
+
List[BodyResult]: A list of BodyResult objects, where each object represents a person with
|
228 |
+
detected keypoints, total score, and total parts.
|
229 |
+
"""
|
230 |
+
return [
|
231 |
+
BodyResult(
|
232 |
+
keypoints=[
|
233 |
+
Keypoint(
|
234 |
+
x=candidate[candidate_index][0],
|
235 |
+
y=candidate[candidate_index][1],
|
236 |
+
score=candidate[candidate_index][2],
|
237 |
+
id=candidate[candidate_index][3]
|
238 |
+
) if candidate_index != -1 else None
|
239 |
+
for candidate_index in person[:18].astype(int)
|
240 |
+
],
|
241 |
+
total_score=person[18],
|
242 |
+
total_parts=person[19]
|
243 |
+
)
|
244 |
+
for person in subset
|
245 |
+
]
|
246 |
+
|
247 |
+
|
248 |
+
if __name__ == "__main__":
|
249 |
+
body_estimation = Body('../model/body_pose_model.pth')
|
250 |
+
|
251 |
+
test_image = '../images/ski.jpg'
|
252 |
+
oriImg = cv2.imread(test_image) # B,G,R order
|
253 |
+
candidate, subset = body_estimation(oriImg)
|
254 |
+
bodies = body_estimation.format_body_result(candidate, subset)
|
255 |
+
|
256 |
+
canvas = oriImg
|
257 |
+
for body in bodies:
|
258 |
+
canvas = util.draw_bodypose(canvas, body)
|
259 |
+
|
260 |
+
plt.imshow(canvas[:, :, [2, 1, 0]])
|
261 |
+
plt.show()
|
custom_nodes/ComfyUI-tbox/src/dwpose/dw_onnx/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
#Dummy file ensuring this package will be recognized
|
custom_nodes/ComfyUI-tbox/src/dwpose/dw_onnx/cv_ox_det.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
def nms(boxes, scores, nms_thr):
|
5 |
+
"""Single class NMS implemented in Numpy."""
|
6 |
+
x1 = boxes[:, 0]
|
7 |
+
y1 = boxes[:, 1]
|
8 |
+
x2 = boxes[:, 2]
|
9 |
+
y2 = boxes[:, 3]
|
10 |
+
|
11 |
+
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
12 |
+
order = scores.argsort()[::-1]
|
13 |
+
|
14 |
+
keep = []
|
15 |
+
while order.size > 0:
|
16 |
+
i = order[0]
|
17 |
+
keep.append(i)
|
18 |
+
xx1 = np.maximum(x1[i], x1[order[1:]])
|
19 |
+
yy1 = np.maximum(y1[i], y1[order[1:]])
|
20 |
+
xx2 = np.minimum(x2[i], x2[order[1:]])
|
21 |
+
yy2 = np.minimum(y2[i], y2[order[1:]])
|
22 |
+
|
23 |
+
w = np.maximum(0.0, xx2 - xx1 + 1)
|
24 |
+
h = np.maximum(0.0, yy2 - yy1 + 1)
|
25 |
+
inter = w * h
|
26 |
+
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
27 |
+
|
28 |
+
inds = np.where(ovr <= nms_thr)[0]
|
29 |
+
order = order[inds + 1]
|
30 |
+
|
31 |
+
return keep
|
32 |
+
|
33 |
+
def multiclass_nms(boxes, scores, nms_thr, score_thr):
|
34 |
+
"""Multiclass NMS implemented in Numpy. Class-aware version."""
|
35 |
+
final_dets = []
|
36 |
+
num_classes = scores.shape[1]
|
37 |
+
for cls_ind in range(num_classes):
|
38 |
+
cls_scores = scores[:, cls_ind]
|
39 |
+
valid_score_mask = cls_scores > score_thr
|
40 |
+
if valid_score_mask.sum() == 0:
|
41 |
+
continue
|
42 |
+
else:
|
43 |
+
valid_scores = cls_scores[valid_score_mask]
|
44 |
+
valid_boxes = boxes[valid_score_mask]
|
45 |
+
keep = nms(valid_boxes, valid_scores, nms_thr)
|
46 |
+
if len(keep) > 0:
|
47 |
+
cls_inds = np.ones((len(keep), 1)) * cls_ind
|
48 |
+
dets = np.concatenate(
|
49 |
+
[valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
|
50 |
+
)
|
51 |
+
final_dets.append(dets)
|
52 |
+
if len(final_dets) == 0:
|
53 |
+
return None
|
54 |
+
return np.concatenate(final_dets, 0)
|
55 |
+
|
56 |
+
def demo_postprocess(outputs, img_size, p6=False):
|
57 |
+
grids = []
|
58 |
+
expanded_strides = []
|
59 |
+
strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
|
60 |
+
|
61 |
+
hsizes = [img_size[0] // stride for stride in strides]
|
62 |
+
wsizes = [img_size[1] // stride for stride in strides]
|
63 |
+
|
64 |
+
for hsize, wsize, stride in zip(hsizes, wsizes, strides):
|
65 |
+
xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
|
66 |
+
grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
|
67 |
+
grids.append(grid)
|
68 |
+
shape = grid.shape[:2]
|
69 |
+
expanded_strides.append(np.full((*shape, 1), stride))
|
70 |
+
|
71 |
+
grids = np.concatenate(grids, 1)
|
72 |
+
expanded_strides = np.concatenate(expanded_strides, 1)
|
73 |
+
outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
|
74 |
+
outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
|
75 |
+
|
76 |
+
return outputs
|
77 |
+
|
78 |
+
def preprocess(img, input_size, swap=(2, 0, 1)):
|
79 |
+
if len(img.shape) == 3:
|
80 |
+
padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
|
81 |
+
else:
|
82 |
+
padded_img = np.ones(input_size, dtype=np.uint8) * 114
|
83 |
+
|
84 |
+
r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
|
85 |
+
resized_img = cv2.resize(
|
86 |
+
img,
|
87 |
+
(int(img.shape[1] * r), int(img.shape[0] * r)),
|
88 |
+
interpolation=cv2.INTER_LINEAR,
|
89 |
+
).astype(np.uint8)
|
90 |
+
padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
|
91 |
+
|
92 |
+
padded_img = padded_img.transpose(swap)
|
93 |
+
padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
|
94 |
+
return padded_img, r
|
95 |
+
|
96 |
+
def inference_detector(session, oriImg, detect_classes=[0], dtype=np.float32):
|
97 |
+
input_shape = (640,640)
|
98 |
+
img, ratio = preprocess(oriImg, input_shape)
|
99 |
+
|
100 |
+
input = img[None, :, :, :]
|
101 |
+
input = input.astype(dtype)
|
102 |
+
if "InferenceSession" in type(session).__name__:
|
103 |
+
input_name = session.get_inputs()[0].name
|
104 |
+
output = session.run(None, {input_name: input})
|
105 |
+
else:
|
106 |
+
outNames = session.getUnconnectedOutLayersNames()
|
107 |
+
session.setInput(input)
|
108 |
+
output = session.forward(outNames)
|
109 |
+
|
110 |
+
predictions = demo_postprocess(output[0], input_shape)[0]
|
111 |
+
|
112 |
+
boxes = predictions[:, :4]
|
113 |
+
scores = predictions[:, 4:5] * predictions[:, 5:]
|
114 |
+
|
115 |
+
boxes_xyxy = np.ones_like(boxes)
|
116 |
+
boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
|
117 |
+
boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
|
118 |
+
boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
|
119 |
+
boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
|
120 |
+
boxes_xyxy /= ratio
|
121 |
+
dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
|
122 |
+
if dets is None:
|
123 |
+
return None
|
124 |
+
final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
|
125 |
+
isscore = final_scores>0.3
|
126 |
+
iscat = np.isin(final_cls_inds, detect_classes)
|
127 |
+
isbbox = [ i and j for (i, j) in zip(isscore, iscat)]
|
128 |
+
final_boxes = final_boxes[isbbox]
|
129 |
+
return final_boxes
|
custom_nodes/ComfyUI-tbox/src/dwpose/dw_onnx/cv_ox_pose.py
ADDED
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
def preprocess(
|
7 |
+
img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
|
8 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
9 |
+
"""Do preprocessing for DWPose model inference.
|
10 |
+
|
11 |
+
Args:
|
12 |
+
img (np.ndarray): Input image in shape.
|
13 |
+
input_size (tuple): Input image size in shape (w, h).
|
14 |
+
|
15 |
+
Returns:
|
16 |
+
tuple:
|
17 |
+
- resized_img (np.ndarray): Preprocessed image.
|
18 |
+
- center (np.ndarray): Center of image.
|
19 |
+
- scale (np.ndarray): Scale of image.
|
20 |
+
"""
|
21 |
+
# get shape of image
|
22 |
+
img_shape = img.shape[:2]
|
23 |
+
out_img, out_center, out_scale = [], [], []
|
24 |
+
if len(out_bbox) == 0:
|
25 |
+
out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
|
26 |
+
for i in range(len(out_bbox)):
|
27 |
+
x0 = out_bbox[i][0]
|
28 |
+
y0 = out_bbox[i][1]
|
29 |
+
x1 = out_bbox[i][2]
|
30 |
+
y1 = out_bbox[i][3]
|
31 |
+
bbox = np.array([x0, y0, x1, y1])
|
32 |
+
|
33 |
+
# get center and scale
|
34 |
+
center, scale = bbox_xyxy2cs(bbox, padding=1.25)
|
35 |
+
|
36 |
+
# do affine transformation
|
37 |
+
resized_img, scale = top_down_affine(input_size, scale, center, img)
|
38 |
+
|
39 |
+
# normalize image
|
40 |
+
mean = np.array([123.675, 116.28, 103.53])
|
41 |
+
std = np.array([58.395, 57.12, 57.375])
|
42 |
+
resized_img = (resized_img - mean) / std
|
43 |
+
|
44 |
+
out_img.append(resized_img)
|
45 |
+
out_center.append(center)
|
46 |
+
out_scale.append(scale)
|
47 |
+
|
48 |
+
return out_img, out_center, out_scale
|
49 |
+
|
50 |
+
|
51 |
+
def inference(sess, img, dtype=np.float32):
|
52 |
+
"""Inference DWPose model. Processing all image segments at once to take advantage of GPU's parallelism ability if onnxruntime is installed
|
53 |
+
|
54 |
+
Args:
|
55 |
+
sess : ONNXRuntime session.
|
56 |
+
img : Input image in shape.
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
outputs : Output of DWPose model.
|
60 |
+
"""
|
61 |
+
all_out = []
|
62 |
+
# build input
|
63 |
+
input = np.stack(img, axis=0).transpose(0, 3, 1, 2)
|
64 |
+
input = input.astype(dtype)
|
65 |
+
if "InferenceSession" in type(sess).__name__:
|
66 |
+
input_name = sess.get_inputs()[0].name
|
67 |
+
all_outputs = sess.run(None, {input_name: input})
|
68 |
+
for batch_idx in range(len(all_outputs[0])):
|
69 |
+
outputs = [all_outputs[i][batch_idx:batch_idx+1,...] for i in range(len(all_outputs))]
|
70 |
+
all_out.append(outputs)
|
71 |
+
return all_out
|
72 |
+
|
73 |
+
#OpenCV doesn't support batch processing sadly
|
74 |
+
for i in range(len(img)):
|
75 |
+
input = img[i].transpose(2, 0, 1)
|
76 |
+
input = input[None, :, :, :]
|
77 |
+
|
78 |
+
outNames = sess.getUnconnectedOutLayersNames()
|
79 |
+
sess.setInput(input)
|
80 |
+
outputs = sess.forward(outNames)
|
81 |
+
all_out.append(outputs)
|
82 |
+
|
83 |
+
return all_out
|
84 |
+
|
85 |
+
def postprocess(outputs: List[np.ndarray],
|
86 |
+
model_input_size: Tuple[int, int],
|
87 |
+
center: Tuple[int, int],
|
88 |
+
scale: Tuple[int, int],
|
89 |
+
simcc_split_ratio: float = 2.0
|
90 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
91 |
+
"""Postprocess for DWPose model output.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
outputs (np.ndarray): Output of RTMPose model.
|
95 |
+
model_input_size (tuple): RTMPose model Input image size.
|
96 |
+
center (tuple): Center of bbox in shape (x, y).
|
97 |
+
scale (tuple): Scale of bbox in shape (w, h).
|
98 |
+
simcc_split_ratio (float): Split ratio of simcc.
|
99 |
+
|
100 |
+
Returns:
|
101 |
+
tuple:
|
102 |
+
- keypoints (np.ndarray): Rescaled keypoints.
|
103 |
+
- scores (np.ndarray): Model predict scores.
|
104 |
+
"""
|
105 |
+
all_key = []
|
106 |
+
all_score = []
|
107 |
+
for i in range(len(outputs)):
|
108 |
+
# use simcc to decode
|
109 |
+
simcc_x, simcc_y = outputs[i]
|
110 |
+
keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
|
111 |
+
|
112 |
+
# rescale keypoints
|
113 |
+
keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
|
114 |
+
all_key.append(keypoints[0])
|
115 |
+
all_score.append(scores[0])
|
116 |
+
|
117 |
+
return np.array(all_key), np.array(all_score)
|
118 |
+
|
119 |
+
|
120 |
+
def bbox_xyxy2cs(bbox: np.ndarray,
|
121 |
+
padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
|
122 |
+
"""Transform the bbox format from (x,y,w,h) into (center, scale)
|
123 |
+
|
124 |
+
Args:
|
125 |
+
bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
|
126 |
+
as (left, top, right, bottom)
|
127 |
+
padding (float): BBox padding factor that will be multilied to scale.
|
128 |
+
Default: 1.0
|
129 |
+
|
130 |
+
Returns:
|
131 |
+
tuple: A tuple containing center and scale.
|
132 |
+
- np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
|
133 |
+
(n, 2)
|
134 |
+
- np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
|
135 |
+
(n, 2)
|
136 |
+
"""
|
137 |
+
# convert single bbox from (4, ) to (1, 4)
|
138 |
+
dim = bbox.ndim
|
139 |
+
if dim == 1:
|
140 |
+
bbox = bbox[None, :]
|
141 |
+
|
142 |
+
# get bbox center and scale
|
143 |
+
x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
|
144 |
+
center = np.hstack([x1 + x2, y1 + y2]) * 0.5
|
145 |
+
scale = np.hstack([x2 - x1, y2 - y1]) * padding
|
146 |
+
|
147 |
+
if dim == 1:
|
148 |
+
center = center[0]
|
149 |
+
scale = scale[0]
|
150 |
+
|
151 |
+
return center, scale
|
152 |
+
|
153 |
+
|
154 |
+
def _fix_aspect_ratio(bbox_scale: np.ndarray,
|
155 |
+
aspect_ratio: float) -> np.ndarray:
|
156 |
+
"""Extend the scale to match the given aspect ratio.
|
157 |
+
|
158 |
+
Args:
|
159 |
+
scale (np.ndarray): The image scale (w, h) in shape (2, )
|
160 |
+
aspect_ratio (float): The ratio of ``w/h``
|
161 |
+
|
162 |
+
Returns:
|
163 |
+
np.ndarray: The reshaped image scale in (2, )
|
164 |
+
"""
|
165 |
+
w, h = np.hsplit(bbox_scale, [1])
|
166 |
+
bbox_scale = np.where(w > h * aspect_ratio,
|
167 |
+
np.hstack([w, w / aspect_ratio]),
|
168 |
+
np.hstack([h * aspect_ratio, h]))
|
169 |
+
return bbox_scale
|
170 |
+
|
171 |
+
|
172 |
+
def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
|
173 |
+
"""Rotate a point by an angle.
|
174 |
+
|
175 |
+
Args:
|
176 |
+
pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
|
177 |
+
angle_rad (float): rotation angle in radian
|
178 |
+
|
179 |
+
Returns:
|
180 |
+
np.ndarray: Rotated point in shape (2, )
|
181 |
+
"""
|
182 |
+
sn, cs = np.sin(angle_rad), np.cos(angle_rad)
|
183 |
+
rot_mat = np.array([[cs, -sn], [sn, cs]])
|
184 |
+
return rot_mat @ pt
|
185 |
+
|
186 |
+
|
187 |
+
def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
188 |
+
"""To calculate the affine matrix, three pairs of points are required. This
|
189 |
+
function is used to get the 3rd point, given 2D points a & b.
|
190 |
+
|
191 |
+
The 3rd point is defined by rotating vector `a - b` by 90 degrees
|
192 |
+
anticlockwise, using b as the rotation center.
|
193 |
+
|
194 |
+
Args:
|
195 |
+
a (np.ndarray): The 1st point (x,y) in shape (2, )
|
196 |
+
b (np.ndarray): The 2nd point (x,y) in shape (2, )
|
197 |
+
|
198 |
+
Returns:
|
199 |
+
np.ndarray: The 3rd point.
|
200 |
+
"""
|
201 |
+
direction = a - b
|
202 |
+
c = b + np.r_[-direction[1], direction[0]]
|
203 |
+
return c
|
204 |
+
|
205 |
+
|
206 |
+
def get_warp_matrix(center: np.ndarray,
|
207 |
+
scale: np.ndarray,
|
208 |
+
rot: float,
|
209 |
+
output_size: Tuple[int, int],
|
210 |
+
shift: Tuple[float, float] = (0., 0.),
|
211 |
+
inv: bool = False) -> np.ndarray:
|
212 |
+
"""Calculate the affine transformation matrix that can warp the bbox area
|
213 |
+
in the input image to the output size.
|
214 |
+
|
215 |
+
Args:
|
216 |
+
center (np.ndarray[2, ]): Center of the bounding box (x, y).
|
217 |
+
scale (np.ndarray[2, ]): Scale of the bounding box
|
218 |
+
wrt [width, height].
|
219 |
+
rot (float): Rotation angle (degree).
|
220 |
+
output_size (np.ndarray[2, ] | list(2,)): Size of the
|
221 |
+
destination heatmaps.
|
222 |
+
shift (0-100%): Shift translation ratio wrt the width/height.
|
223 |
+
Default (0., 0.).
|
224 |
+
inv (bool): Option to inverse the affine transform direction.
|
225 |
+
(inv=False: src->dst or inv=True: dst->src)
|
226 |
+
|
227 |
+
Returns:
|
228 |
+
np.ndarray: A 2x3 transformation matrix
|
229 |
+
"""
|
230 |
+
shift = np.array(shift)
|
231 |
+
src_w = scale[0]
|
232 |
+
dst_w = output_size[0]
|
233 |
+
dst_h = output_size[1]
|
234 |
+
|
235 |
+
# compute transformation matrix
|
236 |
+
rot_rad = np.deg2rad(rot)
|
237 |
+
src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
|
238 |
+
dst_dir = np.array([0., dst_w * -0.5])
|
239 |
+
|
240 |
+
# get four corners of the src rectangle in the original image
|
241 |
+
src = np.zeros((3, 2), dtype=np.float32)
|
242 |
+
src[0, :] = center + scale * shift
|
243 |
+
src[1, :] = center + src_dir + scale * shift
|
244 |
+
src[2, :] = _get_3rd_point(src[0, :], src[1, :])
|
245 |
+
|
246 |
+
# get four corners of the dst rectangle in the input image
|
247 |
+
dst = np.zeros((3, 2), dtype=np.float32)
|
248 |
+
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
|
249 |
+
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
|
250 |
+
dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
|
251 |
+
|
252 |
+
if inv:
|
253 |
+
warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
|
254 |
+
else:
|
255 |
+
warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
|
256 |
+
|
257 |
+
return warp_mat
|
258 |
+
|
259 |
+
|
260 |
+
def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict,
|
261 |
+
img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
262 |
+
"""Get the bbox image as the model input by affine transform.
|
263 |
+
|
264 |
+
Args:
|
265 |
+
input_size (dict): The input size of the model.
|
266 |
+
bbox_scale (dict): The bbox scale of the img.
|
267 |
+
bbox_center (dict): The bbox center of the img.
|
268 |
+
img (np.ndarray): The original image.
|
269 |
+
|
270 |
+
Returns:
|
271 |
+
tuple: A tuple containing center and scale.
|
272 |
+
- np.ndarray[float32]: img after affine transform.
|
273 |
+
- np.ndarray[float32]: bbox scale after affine transform.
|
274 |
+
"""
|
275 |
+
w, h = input_size
|
276 |
+
warp_size = (int(w), int(h))
|
277 |
+
|
278 |
+
# reshape bbox to fixed aspect ratio
|
279 |
+
bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
|
280 |
+
|
281 |
+
# get the affine matrix
|
282 |
+
center = bbox_center
|
283 |
+
scale = bbox_scale
|
284 |
+
rot = 0
|
285 |
+
warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
|
286 |
+
|
287 |
+
# do affine transform
|
288 |
+
img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
|
289 |
+
|
290 |
+
return img, bbox_scale
|
291 |
+
|
292 |
+
|
293 |
+
def get_simcc_maximum(simcc_x: np.ndarray,
|
294 |
+
simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
295 |
+
"""Get maximum response location and value from simcc representations.
|
296 |
+
|
297 |
+
Note:
|
298 |
+
instance number: N
|
299 |
+
num_keypoints: K
|
300 |
+
heatmap height: H
|
301 |
+
heatmap width: W
|
302 |
+
|
303 |
+
Args:
|
304 |
+
simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
|
305 |
+
simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
|
306 |
+
|
307 |
+
Returns:
|
308 |
+
tuple:
|
309 |
+
- locs (np.ndarray): locations of maximum heatmap responses in shape
|
310 |
+
(K, 2) or (N, K, 2)
|
311 |
+
- vals (np.ndarray): values of maximum heatmap responses in shape
|
312 |
+
(K,) or (N, K)
|
313 |
+
"""
|
314 |
+
N, K, Wx = simcc_x.shape
|
315 |
+
simcc_x = simcc_x.reshape(N * K, -1)
|
316 |
+
simcc_y = simcc_y.reshape(N * K, -1)
|
317 |
+
|
318 |
+
# get maximum value locations
|
319 |
+
x_locs = np.argmax(simcc_x, axis=1)
|
320 |
+
y_locs = np.argmax(simcc_y, axis=1)
|
321 |
+
locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
|
322 |
+
max_val_x = np.amax(simcc_x, axis=1)
|
323 |
+
max_val_y = np.amax(simcc_y, axis=1)
|
324 |
+
|
325 |
+
# get maximum value across x and y axis
|
326 |
+
mask = max_val_x > max_val_y
|
327 |
+
max_val_x[mask] = max_val_y[mask]
|
328 |
+
vals = max_val_x
|
329 |
+
locs[vals <= 0.] = -1
|
330 |
+
|
331 |
+
# reshape
|
332 |
+
locs = locs.reshape(N, K, 2)
|
333 |
+
vals = vals.reshape(N, K)
|
334 |
+
|
335 |
+
return locs, vals
|
336 |
+
|
337 |
+
|
338 |
+
def decode(simcc_x: np.ndarray, simcc_y: np.ndarray,
|
339 |
+
simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:
|
340 |
+
"""Modulate simcc distribution with Gaussian.
|
341 |
+
|
342 |
+
Args:
|
343 |
+
simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
|
344 |
+
simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
|
345 |
+
simcc_split_ratio (int): The split ratio of simcc.
|
346 |
+
|
347 |
+
Returns:
|
348 |
+
tuple: A tuple containing center and scale.
|
349 |
+
- np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
|
350 |
+
- np.ndarray[float32]: scores in shape (K,) or (n, K)
|
351 |
+
"""
|
352 |
+
keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
|
353 |
+
keypoints /= simcc_split_ratio
|
354 |
+
|
355 |
+
return keypoints, scores
|
356 |
+
|
357 |
+
|
358 |
+
def inference_pose(session, out_bbox, oriImg, model_input_size=(288, 384), dtype=np.float32):
|
359 |
+
resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
|
360 |
+
outputs = inference(session, resized_img, dtype)
|
361 |
+
keypoints, scores = postprocess(outputs, model_input_size, center, scale)
|
362 |
+
|
363 |
+
return keypoints, scores
|
custom_nodes/ComfyUI-tbox/src/dwpose/dw_onnx/cv_ox_yolo_nas.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Source: https://github.com/Hyuto/yolo-nas-onnx/tree/master/yolo-nas-py
|
2 |
+
# Inspired from: https://github.com/Deci-AI/super-gradients/blob/3.1.1/src/super_gradients/training/processing/processing.py
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import cv2
|
6 |
+
|
7 |
+
def preprocess(img, input_size, swap=(2, 0, 1)):
|
8 |
+
if len(img.shape) == 3:
|
9 |
+
padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
|
10 |
+
else:
|
11 |
+
padded_img = np.ones(input_size, dtype=np.uint8) * 114
|
12 |
+
|
13 |
+
r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
|
14 |
+
resized_img = cv2.resize(
|
15 |
+
img,
|
16 |
+
(int(img.shape[1] * r), int(img.shape[0] * r)),
|
17 |
+
interpolation=cv2.INTER_LINEAR,
|
18 |
+
).astype(np.uint8)
|
19 |
+
padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
|
20 |
+
|
21 |
+
padded_img = padded_img.transpose(swap)
|
22 |
+
padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
|
23 |
+
return padded_img, r
|
24 |
+
|
25 |
+
def inference_detector(session, oriImg, detect_classes=[0], dtype=np.uint8):
|
26 |
+
"""
|
27 |
+
This function is only compatible with onnx models exported from the new API with built-in NMS
|
28 |
+
```py
|
29 |
+
from super_gradients.conversion.conversion_enums import ExportQuantizationMode
|
30 |
+
from super_gradients.common.object_names import Models
|
31 |
+
from super_gradients.training import models
|
32 |
+
|
33 |
+
model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco")
|
34 |
+
|
35 |
+
export_result = model.export(
|
36 |
+
"yolo_nas/yolo_nas_l_fp16.onnx",
|
37 |
+
quantization_mode=ExportQuantizationMode.FP16,
|
38 |
+
device="cuda"
|
39 |
+
)
|
40 |
+
```
|
41 |
+
"""
|
42 |
+
input_shape = (640,640)
|
43 |
+
img, ratio = preprocess(oriImg, input_shape)
|
44 |
+
input = img[None, :, :, :]
|
45 |
+
input = input.astype(dtype)
|
46 |
+
if "InferenceSession" in type(session).__name__:
|
47 |
+
input_name = session.get_inputs()[0].name
|
48 |
+
output = session.run(None, {input_name: input})
|
49 |
+
else:
|
50 |
+
outNames = session.getUnconnectedOutLayersNames()
|
51 |
+
session.setInput(input)
|
52 |
+
output = session.forward(outNames)
|
53 |
+
num_preds, pred_boxes, pred_scores, pred_classes = output
|
54 |
+
num_preds = num_preds[0,0]
|
55 |
+
if num_preds == 0:
|
56 |
+
return None
|
57 |
+
idxs = np.where((np.isin(pred_classes[0, :num_preds], detect_classes)) & (pred_scores[0, :num_preds] > 0.3))
|
58 |
+
if (len(idxs) == 0) or (idxs[0].size == 0):
|
59 |
+
return None
|
60 |
+
return pred_boxes[0, idxs].squeeze(axis=0) / ratio
|
custom_nodes/ComfyUI-tbox/src/dwpose/dw_torchscript/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
#Dummy file ensuring this package will be recognized
|
custom_nodes/ComfyUI-tbox/src/dwpose/dw_torchscript/jit_det.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
|
5 |
+
def nms(boxes, scores, nms_thr):
|
6 |
+
"""Single class NMS implemented in Numpy."""
|
7 |
+
x1 = boxes[:, 0]
|
8 |
+
y1 = boxes[:, 1]
|
9 |
+
x2 = boxes[:, 2]
|
10 |
+
y2 = boxes[:, 3]
|
11 |
+
|
12 |
+
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
13 |
+
order = scores.argsort()[::-1]
|
14 |
+
|
15 |
+
keep = []
|
16 |
+
while order.size > 0:
|
17 |
+
i = order[0]
|
18 |
+
keep.append(i)
|
19 |
+
xx1 = np.maximum(x1[i], x1[order[1:]])
|
20 |
+
yy1 = np.maximum(y1[i], y1[order[1:]])
|
21 |
+
xx2 = np.minimum(x2[i], x2[order[1:]])
|
22 |
+
yy2 = np.minimum(y2[i], y2[order[1:]])
|
23 |
+
|
24 |
+
w = np.maximum(0.0, xx2 - xx1 + 1)
|
25 |
+
h = np.maximum(0.0, yy2 - yy1 + 1)
|
26 |
+
inter = w * h
|
27 |
+
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
28 |
+
|
29 |
+
inds = np.where(ovr <= nms_thr)[0]
|
30 |
+
order = order[inds + 1]
|
31 |
+
|
32 |
+
return keep
|
33 |
+
|
34 |
+
def multiclass_nms(boxes, scores, nms_thr, score_thr):
|
35 |
+
"""Multiclass NMS implemented in Numpy. Class-aware version."""
|
36 |
+
final_dets = []
|
37 |
+
num_classes = scores.shape[1]
|
38 |
+
for cls_ind in range(num_classes):
|
39 |
+
cls_scores = scores[:, cls_ind]
|
40 |
+
valid_score_mask = cls_scores > score_thr
|
41 |
+
if valid_score_mask.sum() == 0:
|
42 |
+
continue
|
43 |
+
else:
|
44 |
+
valid_scores = cls_scores[valid_score_mask]
|
45 |
+
valid_boxes = boxes[valid_score_mask]
|
46 |
+
keep = nms(valid_boxes, valid_scores, nms_thr)
|
47 |
+
if len(keep) > 0:
|
48 |
+
cls_inds = np.ones((len(keep), 1)) * cls_ind
|
49 |
+
dets = np.concatenate(
|
50 |
+
[valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
|
51 |
+
)
|
52 |
+
final_dets.append(dets)
|
53 |
+
if len(final_dets) == 0:
|
54 |
+
return None
|
55 |
+
return np.concatenate(final_dets, 0)
|
56 |
+
|
57 |
+
def demo_postprocess(outputs, img_size, p6=False):
|
58 |
+
grids = []
|
59 |
+
expanded_strides = []
|
60 |
+
strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
|
61 |
+
|
62 |
+
hsizes = [img_size[0] // stride for stride in strides]
|
63 |
+
wsizes = [img_size[1] // stride for stride in strides]
|
64 |
+
|
65 |
+
for hsize, wsize, stride in zip(hsizes, wsizes, strides):
|
66 |
+
xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
|
67 |
+
grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
|
68 |
+
grids.append(grid)
|
69 |
+
shape = grid.shape[:2]
|
70 |
+
expanded_strides.append(np.full((*shape, 1), stride))
|
71 |
+
|
72 |
+
grids = np.concatenate(grids, 1)
|
73 |
+
expanded_strides = np.concatenate(expanded_strides, 1)
|
74 |
+
outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
|
75 |
+
outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
|
76 |
+
|
77 |
+
return outputs
|
78 |
+
|
79 |
+
def preprocess(img, input_size, swap=(2, 0, 1)):
|
80 |
+
if len(img.shape) == 3:
|
81 |
+
padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
|
82 |
+
else:
|
83 |
+
padded_img = np.ones(input_size, dtype=np.uint8) * 114
|
84 |
+
|
85 |
+
r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
|
86 |
+
resized_img = cv2.resize(
|
87 |
+
img,
|
88 |
+
(int(img.shape[1] * r), int(img.shape[0] * r)),
|
89 |
+
interpolation=cv2.INTER_LINEAR,
|
90 |
+
).astype(np.uint8)
|
91 |
+
padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
|
92 |
+
|
93 |
+
padded_img = padded_img.transpose(swap)
|
94 |
+
padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
|
95 |
+
return padded_img, r
|
96 |
+
|
97 |
+
def inference_detector(model, oriImg, detect_classes=[0]):
|
98 |
+
input_shape = (640,640)
|
99 |
+
img, ratio = preprocess(oriImg, input_shape)
|
100 |
+
|
101 |
+
device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
|
102 |
+
input = img[None, :, :, :]
|
103 |
+
input = torch.from_numpy(input).to(device, dtype)
|
104 |
+
|
105 |
+
output = model(input).float().cpu().detach().numpy()
|
106 |
+
predictions = demo_postprocess(output[0], input_shape)
|
107 |
+
|
108 |
+
boxes = predictions[:, :4]
|
109 |
+
scores = predictions[:, 4:5] * predictions[:, 5:]
|
110 |
+
|
111 |
+
boxes_xyxy = np.ones_like(boxes)
|
112 |
+
boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
|
113 |
+
boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
|
114 |
+
boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
|
115 |
+
boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
|
116 |
+
boxes_xyxy /= ratio
|
117 |
+
dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
|
118 |
+
if dets is None:
|
119 |
+
return None
|
120 |
+
final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
|
121 |
+
isscore = final_scores>0.3
|
122 |
+
iscat = np.isin(final_cls_inds, detect_classes)
|
123 |
+
isbbox = [ i and j for (i, j) in zip(isscore, iscat)]
|
124 |
+
final_boxes = final_boxes[isbbox]
|
125 |
+
return final_boxes
|
custom_nodes/ComfyUI-tbox/src/dwpose/dw_torchscript/jit_pose.py
ADDED
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
|
7 |
+
def preprocess(
|
8 |
+
img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
|
9 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
10 |
+
"""Do preprocessing for DWPose model inference.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
img (np.ndarray): Input image in shape.
|
14 |
+
input_size (tuple): Input image size in shape (w, h).
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
tuple:
|
18 |
+
- resized_img (np.ndarray): Preprocessed image.
|
19 |
+
- center (np.ndarray): Center of image.
|
20 |
+
- scale (np.ndarray): Scale of image.
|
21 |
+
"""
|
22 |
+
# get shape of image
|
23 |
+
img_shape = img.shape[:2]
|
24 |
+
out_img, out_center, out_scale = [], [], []
|
25 |
+
if len(out_bbox) == 0:
|
26 |
+
out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
|
27 |
+
for i in range(len(out_bbox)):
|
28 |
+
x0 = out_bbox[i][0]
|
29 |
+
y0 = out_bbox[i][1]
|
30 |
+
x1 = out_bbox[i][2]
|
31 |
+
y1 = out_bbox[i][3]
|
32 |
+
bbox = np.array([x0, y0, x1, y1])
|
33 |
+
|
34 |
+
# get center and scale
|
35 |
+
center, scale = bbox_xyxy2cs(bbox, padding=1.25)
|
36 |
+
|
37 |
+
# do affine transformation
|
38 |
+
resized_img, scale = top_down_affine(input_size, scale, center, img)
|
39 |
+
|
40 |
+
# normalize image
|
41 |
+
mean = np.array([123.675, 116.28, 103.53])
|
42 |
+
std = np.array([58.395, 57.12, 57.375])
|
43 |
+
resized_img = (resized_img - mean) / std
|
44 |
+
|
45 |
+
out_img.append(resized_img)
|
46 |
+
out_center.append(center)
|
47 |
+
out_scale.append(scale)
|
48 |
+
|
49 |
+
return out_img, out_center, out_scale
|
50 |
+
|
51 |
+
def inference(model, img, bs=5):
|
52 |
+
"""Inference DWPose model implemented in TorchScript.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
model : TorchScript Model.
|
56 |
+
img : Input image in shape.
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
outputs : Output of DWPose model.
|
60 |
+
"""
|
61 |
+
all_out = []
|
62 |
+
# build input
|
63 |
+
orig_img_count = len(img)
|
64 |
+
#Pad zeros to fit batch size
|
65 |
+
for _ in range(bs - (orig_img_count % bs)):
|
66 |
+
img.append(np.zeros_like(img[0]))
|
67 |
+
input = np.stack(img, axis=0).transpose(0, 3, 1, 2)
|
68 |
+
device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
|
69 |
+
input = torch.from_numpy(input).to(device, dtype)
|
70 |
+
|
71 |
+
out1, out2 = [], []
|
72 |
+
for i in range(input.shape[0] // bs):
|
73 |
+
curr_batch_output = model(input[i*bs:(i+1)*bs])
|
74 |
+
out1.append(curr_batch_output[0].float())
|
75 |
+
out2.append(curr_batch_output[1].float())
|
76 |
+
out1, out2 = torch.cat(out1, dim=0)[:orig_img_count], torch.cat(out2, dim=0)[:orig_img_count]
|
77 |
+
out1, out2 = out1.float().cpu().detach().numpy(), out2.float().cpu().detach().numpy()
|
78 |
+
all_outputs = out1, out2
|
79 |
+
|
80 |
+
for batch_idx in range(len(all_outputs[0])):
|
81 |
+
outputs = [all_outputs[i][batch_idx:batch_idx+1,...] for i in range(len(all_outputs))]
|
82 |
+
all_out.append(outputs)
|
83 |
+
return all_out
|
84 |
+
def postprocess(outputs: List[np.ndarray],
|
85 |
+
model_input_size: Tuple[int, int],
|
86 |
+
center: Tuple[int, int],
|
87 |
+
scale: Tuple[int, int],
|
88 |
+
simcc_split_ratio: float = 2.0
|
89 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
90 |
+
"""Postprocess for DWPose model output.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
outputs (np.ndarray): Output of RTMPose model.
|
94 |
+
model_input_size (tuple): RTMPose model Input image size.
|
95 |
+
center (tuple): Center of bbox in shape (x, y).
|
96 |
+
scale (tuple): Scale of bbox in shape (w, h).
|
97 |
+
simcc_split_ratio (float): Split ratio of simcc.
|
98 |
+
|
99 |
+
Returns:
|
100 |
+
tuple:
|
101 |
+
- keypoints (np.ndarray): Rescaled keypoints.
|
102 |
+
- scores (np.ndarray): Model predict scores.
|
103 |
+
"""
|
104 |
+
all_key = []
|
105 |
+
all_score = []
|
106 |
+
for i in range(len(outputs)):
|
107 |
+
# use simcc to decode
|
108 |
+
simcc_x, simcc_y = outputs[i]
|
109 |
+
keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
|
110 |
+
|
111 |
+
# rescale keypoints
|
112 |
+
keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
|
113 |
+
all_key.append(keypoints[0])
|
114 |
+
all_score.append(scores[0])
|
115 |
+
|
116 |
+
return np.array(all_key), np.array(all_score)
|
117 |
+
|
118 |
+
|
119 |
+
def bbox_xyxy2cs(bbox: np.ndarray,
|
120 |
+
padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
|
121 |
+
"""Transform the bbox format from (x,y,w,h) into (center, scale)
|
122 |
+
|
123 |
+
Args:
|
124 |
+
bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
|
125 |
+
as (left, top, right, bottom)
|
126 |
+
padding (float): BBox padding factor that will be multilied to scale.
|
127 |
+
Default: 1.0
|
128 |
+
|
129 |
+
Returns:
|
130 |
+
tuple: A tuple containing center and scale.
|
131 |
+
- np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
|
132 |
+
(n, 2)
|
133 |
+
- np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
|
134 |
+
(n, 2)
|
135 |
+
"""
|
136 |
+
# convert single bbox from (4, ) to (1, 4)
|
137 |
+
dim = bbox.ndim
|
138 |
+
if dim == 1:
|
139 |
+
bbox = bbox[None, :]
|
140 |
+
|
141 |
+
# get bbox center and scale
|
142 |
+
x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
|
143 |
+
center = np.hstack([x1 + x2, y1 + y2]) * 0.5
|
144 |
+
scale = np.hstack([x2 - x1, y2 - y1]) * padding
|
145 |
+
|
146 |
+
if dim == 1:
|
147 |
+
center = center[0]
|
148 |
+
scale = scale[0]
|
149 |
+
|
150 |
+
return center, scale
|
151 |
+
|
152 |
+
|
153 |
+
def _fix_aspect_ratio(bbox_scale: np.ndarray,
|
154 |
+
aspect_ratio: float) -> np.ndarray:
|
155 |
+
"""Extend the scale to match the given aspect ratio.
|
156 |
+
|
157 |
+
Args:
|
158 |
+
scale (np.ndarray): The image scale (w, h) in shape (2, )
|
159 |
+
aspect_ratio (float): The ratio of ``w/h``
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
np.ndarray: The reshaped image scale in (2, )
|
163 |
+
"""
|
164 |
+
w, h = np.hsplit(bbox_scale, [1])
|
165 |
+
bbox_scale = np.where(w > h * aspect_ratio,
|
166 |
+
np.hstack([w, w / aspect_ratio]),
|
167 |
+
np.hstack([h * aspect_ratio, h]))
|
168 |
+
return bbox_scale
|
169 |
+
|
170 |
+
|
171 |
+
def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
|
172 |
+
"""Rotate a point by an angle.
|
173 |
+
|
174 |
+
Args:
|
175 |
+
pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
|
176 |
+
angle_rad (float): rotation angle in radian
|
177 |
+
|
178 |
+
Returns:
|
179 |
+
np.ndarray: Rotated point in shape (2, )
|
180 |
+
"""
|
181 |
+
sn, cs = np.sin(angle_rad), np.cos(angle_rad)
|
182 |
+
rot_mat = np.array([[cs, -sn], [sn, cs]])
|
183 |
+
return rot_mat @ pt
|
184 |
+
|
185 |
+
|
186 |
+
def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
187 |
+
"""To calculate the affine matrix, three pairs of points are required. This
|
188 |
+
function is used to get the 3rd point, given 2D points a & b.
|
189 |
+
|
190 |
+
The 3rd point is defined by rotating vector `a - b` by 90 degrees
|
191 |
+
anticlockwise, using b as the rotation center.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
a (np.ndarray): The 1st point (x,y) in shape (2, )
|
195 |
+
b (np.ndarray): The 2nd point (x,y) in shape (2, )
|
196 |
+
|
197 |
+
Returns:
|
198 |
+
np.ndarray: The 3rd point.
|
199 |
+
"""
|
200 |
+
direction = a - b
|
201 |
+
c = b + np.r_[-direction[1], direction[0]]
|
202 |
+
return c
|
203 |
+
|
204 |
+
|
205 |
+
def get_warp_matrix(center: np.ndarray,
|
206 |
+
scale: np.ndarray,
|
207 |
+
rot: float,
|
208 |
+
output_size: Tuple[int, int],
|
209 |
+
shift: Tuple[float, float] = (0., 0.),
|
210 |
+
inv: bool = False) -> np.ndarray:
|
211 |
+
"""Calculate the affine transformation matrix that can warp the bbox area
|
212 |
+
in the input image to the output size.
|
213 |
+
|
214 |
+
Args:
|
215 |
+
center (np.ndarray[2, ]): Center of the bounding box (x, y).
|
216 |
+
scale (np.ndarray[2, ]): Scale of the bounding box
|
217 |
+
wrt [width, height].
|
218 |
+
rot (float): Rotation angle (degree).
|
219 |
+
output_size (np.ndarray[2, ] | list(2,)): Size of the
|
220 |
+
destination heatmaps.
|
221 |
+
shift (0-100%): Shift translation ratio wrt the width/height.
|
222 |
+
Default (0., 0.).
|
223 |
+
inv (bool): Option to inverse the affine transform direction.
|
224 |
+
(inv=False: src->dst or inv=True: dst->src)
|
225 |
+
|
226 |
+
Returns:
|
227 |
+
np.ndarray: A 2x3 transformation matrix
|
228 |
+
"""
|
229 |
+
shift = np.array(shift)
|
230 |
+
src_w = scale[0]
|
231 |
+
dst_w = output_size[0]
|
232 |
+
dst_h = output_size[1]
|
233 |
+
|
234 |
+
# compute transformation matrix
|
235 |
+
rot_rad = np.deg2rad(rot)
|
236 |
+
src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
|
237 |
+
dst_dir = np.array([0., dst_w * -0.5])
|
238 |
+
|
239 |
+
# get four corners of the src rectangle in the original image
|
240 |
+
src = np.zeros((3, 2), dtype=np.float32)
|
241 |
+
src[0, :] = center + scale * shift
|
242 |
+
src[1, :] = center + src_dir + scale * shift
|
243 |
+
src[2, :] = _get_3rd_point(src[0, :], src[1, :])
|
244 |
+
|
245 |
+
# get four corners of the dst rectangle in the input image
|
246 |
+
dst = np.zeros((3, 2), dtype=np.float32)
|
247 |
+
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
|
248 |
+
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
|
249 |
+
dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
|
250 |
+
|
251 |
+
if inv:
|
252 |
+
warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
|
253 |
+
else:
|
254 |
+
warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
|
255 |
+
|
256 |
+
return warp_mat
|
257 |
+
|
258 |
+
|
259 |
+
def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict,
|
260 |
+
img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
261 |
+
"""Get the bbox image as the model input by affine transform.
|
262 |
+
|
263 |
+
Args:
|
264 |
+
input_size (dict): The input size of the model.
|
265 |
+
bbox_scale (dict): The bbox scale of the img.
|
266 |
+
bbox_center (dict): The bbox center of the img.
|
267 |
+
img (np.ndarray): The original image.
|
268 |
+
|
269 |
+
Returns:
|
270 |
+
tuple: A tuple containing center and scale.
|
271 |
+
- np.ndarray[float32]: img after affine transform.
|
272 |
+
- np.ndarray[float32]: bbox scale after affine transform.
|
273 |
+
"""
|
274 |
+
w, h = input_size
|
275 |
+
warp_size = (int(w), int(h))
|
276 |
+
|
277 |
+
# reshape bbox to fixed aspect ratio
|
278 |
+
bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
|
279 |
+
|
280 |
+
# get the affine matrix
|
281 |
+
center = bbox_center
|
282 |
+
scale = bbox_scale
|
283 |
+
rot = 0
|
284 |
+
warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
|
285 |
+
|
286 |
+
# do affine transform
|
287 |
+
img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
|
288 |
+
|
289 |
+
return img, bbox_scale
|
290 |
+
|
291 |
+
|
292 |
+
def get_simcc_maximum(simcc_x: np.ndarray,
|
293 |
+
simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
294 |
+
"""Get maximum response location and value from simcc representations.
|
295 |
+
|
296 |
+
Note:
|
297 |
+
instance number: N
|
298 |
+
num_keypoints: K
|
299 |
+
heatmap height: H
|
300 |
+
heatmap width: W
|
301 |
+
|
302 |
+
Args:
|
303 |
+
simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
|
304 |
+
simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
|
305 |
+
|
306 |
+
Returns:
|
307 |
+
tuple:
|
308 |
+
- locs (np.ndarray): locations of maximum heatmap responses in shape
|
309 |
+
(K, 2) or (N, K, 2)
|
310 |
+
- vals (np.ndarray): values of maximum heatmap responses in shape
|
311 |
+
(K,) or (N, K)
|
312 |
+
"""
|
313 |
+
N, K, Wx = simcc_x.shape
|
314 |
+
simcc_x = simcc_x.reshape(N * K, -1)
|
315 |
+
simcc_y = simcc_y.reshape(N * K, -1)
|
316 |
+
|
317 |
+
# get maximum value locations
|
318 |
+
x_locs = np.argmax(simcc_x, axis=1)
|
319 |
+
y_locs = np.argmax(simcc_y, axis=1)
|
320 |
+
locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
|
321 |
+
max_val_x = np.amax(simcc_x, axis=1)
|
322 |
+
max_val_y = np.amax(simcc_y, axis=1)
|
323 |
+
|
324 |
+
# get maximum value across x and y axis
|
325 |
+
mask = max_val_x > max_val_y
|
326 |
+
max_val_x[mask] = max_val_y[mask]
|
327 |
+
vals = max_val_x
|
328 |
+
locs[vals <= 0.] = -1
|
329 |
+
|
330 |
+
# reshape
|
331 |
+
locs = locs.reshape(N, K, 2)
|
332 |
+
vals = vals.reshape(N, K)
|
333 |
+
|
334 |
+
return locs, vals
|
335 |
+
|
336 |
+
|
337 |
+
def decode(simcc_x: np.ndarray, simcc_y: np.ndarray,
|
338 |
+
simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:
|
339 |
+
"""Modulate simcc distribution with Gaussian.
|
340 |
+
|
341 |
+
Args:
|
342 |
+
simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
|
343 |
+
simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
|
344 |
+
simcc_split_ratio (int): The split ratio of simcc.
|
345 |
+
|
346 |
+
Returns:
|
347 |
+
tuple: A tuple containing center and scale.
|
348 |
+
- np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
|
349 |
+
- np.ndarray[float32]: scores in shape (K,) or (n, K)
|
350 |
+
"""
|
351 |
+
keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
|
352 |
+
keypoints /= simcc_split_ratio
|
353 |
+
|
354 |
+
return keypoints, scores
|
355 |
+
|
356 |
+
def inference_pose(model, out_bbox, oriImg, model_input_size=(288, 384)):
|
357 |
+
resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
|
358 |
+
#outputs = inference(session, resized_img, dtype)
|
359 |
+
outputs = inference(model, resized_img)
|
360 |
+
|
361 |
+
keypoints, scores = postprocess(outputs, model_input_size, center, scale)
|
362 |
+
|
363 |
+
return keypoints, scores
|
custom_nodes/ComfyUI-tbox/src/dwpose/face.py
ADDED
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import numpy as np
|
3 |
+
from torchvision.transforms import ToTensor, ToPILImage
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import cv2
|
7 |
+
|
8 |
+
from . import util
|
9 |
+
from torch.nn import Conv2d, Module, ReLU, MaxPool2d, init
|
10 |
+
|
11 |
+
|
12 |
+
class FaceNet(Module):
|
13 |
+
"""Model the cascading heatmaps. """
|
14 |
+
def __init__(self):
|
15 |
+
super(FaceNet, self).__init__()
|
16 |
+
# cnn to make feature map
|
17 |
+
self.relu = ReLU()
|
18 |
+
self.max_pooling_2d = MaxPool2d(kernel_size=2, stride=2)
|
19 |
+
self.conv1_1 = Conv2d(in_channels=3, out_channels=64,
|
20 |
+
kernel_size=3, stride=1, padding=1)
|
21 |
+
self.conv1_2 = Conv2d(
|
22 |
+
in_channels=64, out_channels=64, kernel_size=3, stride=1,
|
23 |
+
padding=1)
|
24 |
+
self.conv2_1 = Conv2d(
|
25 |
+
in_channels=64, out_channels=128, kernel_size=3, stride=1,
|
26 |
+
padding=1)
|
27 |
+
self.conv2_2 = Conv2d(
|
28 |
+
in_channels=128, out_channels=128, kernel_size=3, stride=1,
|
29 |
+
padding=1)
|
30 |
+
self.conv3_1 = Conv2d(
|
31 |
+
in_channels=128, out_channels=256, kernel_size=3, stride=1,
|
32 |
+
padding=1)
|
33 |
+
self.conv3_2 = Conv2d(
|
34 |
+
in_channels=256, out_channels=256, kernel_size=3, stride=1,
|
35 |
+
padding=1)
|
36 |
+
self.conv3_3 = Conv2d(
|
37 |
+
in_channels=256, out_channels=256, kernel_size=3, stride=1,
|
38 |
+
padding=1)
|
39 |
+
self.conv3_4 = Conv2d(
|
40 |
+
in_channels=256, out_channels=256, kernel_size=3, stride=1,
|
41 |
+
padding=1)
|
42 |
+
self.conv4_1 = Conv2d(
|
43 |
+
in_channels=256, out_channels=512, kernel_size=3, stride=1,
|
44 |
+
padding=1)
|
45 |
+
self.conv4_2 = Conv2d(
|
46 |
+
            in_channels=512, out_channels=512, kernel_size=3, stride=1,
            padding=1)
        self.conv4_3 = Conv2d(
            in_channels=512, out_channels=512, kernel_size=3, stride=1,
            padding=1)
        self.conv4_4 = Conv2d(
            in_channels=512, out_channels=512, kernel_size=3, stride=1,
            padding=1)
        self.conv5_1 = Conv2d(
            in_channels=512, out_channels=512, kernel_size=3, stride=1,
            padding=1)
        self.conv5_2 = Conv2d(
            in_channels=512, out_channels=512, kernel_size=3, stride=1,
            padding=1)
        self.conv5_3_CPM = Conv2d(
            in_channels=512, out_channels=128, kernel_size=3, stride=1,
            padding=1)

        # stage1
        self.conv6_1_CPM = Conv2d(
            in_channels=128, out_channels=512, kernel_size=1, stride=1,
            padding=0)
        self.conv6_2_CPM = Conv2d(
            in_channels=512, out_channels=71, kernel_size=1, stride=1,
            padding=0)

        # stage2
        self.Mconv1_stage2 = Conv2d(
            in_channels=199, out_channels=128, kernel_size=7, stride=1,
            padding=3)
        self.Mconv2_stage2 = Conv2d(
            in_channels=128, out_channels=128, kernel_size=7, stride=1,
            padding=3)
        self.Mconv3_stage2 = Conv2d(
            in_channels=128, out_channels=128, kernel_size=7, stride=1,
            padding=3)
        self.Mconv4_stage2 = Conv2d(
            in_channels=128, out_channels=128, kernel_size=7, stride=1,
            padding=3)
        self.Mconv5_stage2 = Conv2d(
            in_channels=128, out_channels=128, kernel_size=7, stride=1,
            padding=3)
        self.Mconv6_stage2 = Conv2d(
            in_channels=128, out_channels=128, kernel_size=1, stride=1,
            padding=0)
        self.Mconv7_stage2 = Conv2d(
            in_channels=128, out_channels=71, kernel_size=1, stride=1,
            padding=0)

        # stage3
        self.Mconv1_stage3 = Conv2d(
            in_channels=199, out_channels=128, kernel_size=7, stride=1,
            padding=3)
        self.Mconv2_stage3 = Conv2d(
            in_channels=128, out_channels=128, kernel_size=7, stride=1,
            padding=3)
        self.Mconv3_stage3 = Conv2d(
            in_channels=128, out_channels=128, kernel_size=7, stride=1,
            padding=3)
        self.Mconv4_stage3 = Conv2d(
            in_channels=128, out_channels=128, kernel_size=7, stride=1,
            padding=3)
        self.Mconv5_stage3 = Conv2d(
            in_channels=128, out_channels=128, kernel_size=7, stride=1,
            padding=3)
        self.Mconv6_stage3 = Conv2d(
            in_channels=128, out_channels=128, kernel_size=1, stride=1,
            padding=0)
        self.Mconv7_stage3 = Conv2d(
            in_channels=128, out_channels=71, kernel_size=1, stride=1,
            padding=0)

        # stage4
        self.Mconv1_stage4 = Conv2d(
            in_channels=199, out_channels=128, kernel_size=7, stride=1,
            padding=3)
        self.Mconv2_stage4 = Conv2d(
            in_channels=128, out_channels=128, kernel_size=7, stride=1,
            padding=3)
        self.Mconv3_stage4 = Conv2d(
            in_channels=128, out_channels=128, kernel_size=7, stride=1,
            padding=3)
        self.Mconv4_stage4 = Conv2d(
            in_channels=128, out_channels=128, kernel_size=7, stride=1,
            padding=3)
        self.Mconv5_stage4 = Conv2d(
            in_channels=128, out_channels=128, kernel_size=7, stride=1,
            padding=3)
        self.Mconv6_stage4 = Conv2d(
            in_channels=128, out_channels=128, kernel_size=1, stride=1,
            padding=0)
        self.Mconv7_stage4 = Conv2d(
            in_channels=128, out_channels=71, kernel_size=1, stride=1,
            padding=0)

        # stage5
        self.Mconv1_stage5 = Conv2d(
            in_channels=199, out_channels=128, kernel_size=7, stride=1,
            padding=3)
        self.Mconv2_stage5 = Conv2d(
            in_channels=128, out_channels=128, kernel_size=7, stride=1,
            padding=3)
        self.Mconv3_stage5 = Conv2d(
            in_channels=128, out_channels=128, kernel_size=7, stride=1,
            padding=3)
        self.Mconv4_stage5 = Conv2d(
            in_channels=128, out_channels=128, kernel_size=7, stride=1,
            padding=3)
        self.Mconv5_stage5 = Conv2d(
            in_channels=128, out_channels=128, kernel_size=7, stride=1,
            padding=3)
        self.Mconv6_stage5 = Conv2d(
            in_channels=128, out_channels=128, kernel_size=1, stride=1,
            padding=0)
        self.Mconv7_stage5 = Conv2d(
            in_channels=128, out_channels=71, kernel_size=1, stride=1,
            padding=0)

        # stage6
        self.Mconv1_stage6 = Conv2d(
            in_channels=199, out_channels=128, kernel_size=7, stride=1,
            padding=3)
        self.Mconv2_stage6 = Conv2d(
            in_channels=128, out_channels=128, kernel_size=7, stride=1,
            padding=3)
        self.Mconv3_stage6 = Conv2d(
            in_channels=128, out_channels=128, kernel_size=7, stride=1,
            padding=3)
        self.Mconv4_stage6 = Conv2d(
            in_channels=128, out_channels=128, kernel_size=7, stride=1,
            padding=3)
        self.Mconv5_stage6 = Conv2d(
            in_channels=128, out_channels=128, kernel_size=7, stride=1,
            padding=3)
        self.Mconv6_stage6 = Conv2d(
            in_channels=128, out_channels=128, kernel_size=1, stride=1,
            padding=0)
        self.Mconv7_stage6 = Conv2d(
            in_channels=128, out_channels=71, kernel_size=1, stride=1,
            padding=0)

        for m in self.modules():
            if isinstance(m, Conv2d):
                init.constant_(m.bias, 0)

    def forward(self, x):
        """Return a list of heatmaps."""
        heatmaps = []

        h = self.relu(self.conv1_1(x))
        h = self.relu(self.conv1_2(h))
        h = self.max_pooling_2d(h)
        h = self.relu(self.conv2_1(h))
        h = self.relu(self.conv2_2(h))
        h = self.max_pooling_2d(h)
        h = self.relu(self.conv3_1(h))
        h = self.relu(self.conv3_2(h))
        h = self.relu(self.conv3_3(h))
        h = self.relu(self.conv3_4(h))
        h = self.max_pooling_2d(h)
        h = self.relu(self.conv4_1(h))
        h = self.relu(self.conv4_2(h))
        h = self.relu(self.conv4_3(h))
        h = self.relu(self.conv4_4(h))
        h = self.relu(self.conv5_1(h))
        h = self.relu(self.conv5_2(h))
        h = self.relu(self.conv5_3_CPM(h))
        feature_map = h

        # stage1
        h = self.relu(self.conv6_1_CPM(h))
        h = self.conv6_2_CPM(h)
        heatmaps.append(h)

        # stage2
        h = torch.cat([h, feature_map], dim=1)  # channel concat
        h = self.relu(self.Mconv1_stage2(h))
        h = self.relu(self.Mconv2_stage2(h))
        h = self.relu(self.Mconv3_stage2(h))
        h = self.relu(self.Mconv4_stage2(h))
        h = self.relu(self.Mconv5_stage2(h))
        h = self.relu(self.Mconv6_stage2(h))
        h = self.Mconv7_stage2(h)
        heatmaps.append(h)

        # stage3
        h = torch.cat([h, feature_map], dim=1)  # channel concat
        h = self.relu(self.Mconv1_stage3(h))
        h = self.relu(self.Mconv2_stage3(h))
        h = self.relu(self.Mconv3_stage3(h))
        h = self.relu(self.Mconv4_stage3(h))
        h = self.relu(self.Mconv5_stage3(h))
        h = self.relu(self.Mconv6_stage3(h))
        h = self.Mconv7_stage3(h)
        heatmaps.append(h)

        # stage4
        h = torch.cat([h, feature_map], dim=1)  # channel concat
        h = self.relu(self.Mconv1_stage4(h))
        h = self.relu(self.Mconv2_stage4(h))
        h = self.relu(self.Mconv3_stage4(h))
        h = self.relu(self.Mconv4_stage4(h))
        h = self.relu(self.Mconv5_stage4(h))
        h = self.relu(self.Mconv6_stage4(h))
        h = self.Mconv7_stage4(h)
        heatmaps.append(h)

        # stage5
        h = torch.cat([h, feature_map], dim=1)  # channel concat
        h = self.relu(self.Mconv1_stage5(h))
        h = self.relu(self.Mconv2_stage5(h))
        h = self.relu(self.Mconv3_stage5(h))
        h = self.relu(self.Mconv4_stage5(h))
        h = self.relu(self.Mconv5_stage5(h))
        h = self.relu(self.Mconv6_stage5(h))
        h = self.Mconv7_stage5(h)
        heatmaps.append(h)

        # stage6
        h = torch.cat([h, feature_map], dim=1)  # channel concat
        h = self.relu(self.Mconv1_stage6(h))
        h = self.relu(self.Mconv2_stage6(h))
        h = self.relu(self.Mconv3_stage6(h))
        h = self.relu(self.Mconv4_stage6(h))
        h = self.relu(self.Mconv5_stage6(h))
        h = self.relu(self.Mconv6_stage6(h))
        h = self.Mconv7_stage6(h)
        heatmaps.append(h)

        return heatmaps


LOG = logging.getLogger(__name__)
TOTEN = ToTensor()
TOPIL = ToPILImage()


params = {
    'gaussian_sigma': 2.5,
    'inference_img_size': 736,  # 368, 736, 1312
    'heatmap_peak_thresh': 0.1,
    'crop_scale': 1.5,
    'line_indices': [
        [0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6],
        [6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12], [12, 13],
        [13, 14], [14, 15], [15, 16],
        [17, 18], [18, 19], [19, 20], [20, 21],
        [22, 23], [23, 24], [24, 25], [25, 26],
        [27, 28], [28, 29], [29, 30],
        [31, 32], [32, 33], [33, 34], [34, 35],
        [36, 37], [37, 38], [38, 39], [39, 40], [40, 41], [41, 36],
        [42, 43], [43, 44], [44, 45], [45, 46], [46, 47], [47, 42],
        [48, 49], [49, 50], [50, 51], [51, 52], [52, 53], [53, 54],
        [54, 55], [55, 56], [56, 57], [57, 58], [58, 59], [59, 48],
        [60, 61], [61, 62], [62, 63], [63, 64], [64, 65], [65, 66],
        [66, 67], [67, 60]
    ],
}


class Face(object):
    """
    The OpenPose face landmark detector model.

    Args:
        inference_size: set the size of the inference image size, suggested:
            368, 736, 1312, default 736
        gaussian_sigma: blur the heatmaps, default 2.5
        heatmap_peak_thresh: return landmark if over threshold, default 0.1

    """
    def __init__(self, face_model_path,
                 inference_size=None,
                 gaussian_sigma=None,
                 heatmap_peak_thresh=None):
        self.inference_size = inference_size or params["inference_img_size"]
        self.sigma = gaussian_sigma or params['gaussian_sigma']
        self.threshold = heatmap_peak_thresh or params["heatmap_peak_thresh"]
        self.model = FaceNet()
        self.model.load_state_dict(torch.load(face_model_path))
        # if torch.cuda.is_available():
        #     self.model = self.model.cuda()
        #     print('cuda')
        self.model.eval()

    def __call__(self, face_img):
        H, W, C = face_img.shape

        w_size = 384
        x_data = torch.from_numpy(util.smart_resize(face_img, (w_size, w_size))).permute([2, 0, 1]) / 256.0 - 0.5

        x_data = x_data.to(self.cn_device)

        with torch.no_grad():
            hs = self.model(x_data[None, ...])
            heatmaps = F.interpolate(
                hs[-1],
                (H, W),
                mode='bilinear', align_corners=True).cpu().numpy()[0]
        return heatmaps

    def compute_peaks_from_heatmaps(self, heatmaps):
        all_peaks = []
        for part in range(heatmaps.shape[0]):
            map_ori = heatmaps[part].copy()
            binary = np.ascontiguousarray(map_ori > 0.05, dtype=np.uint8)

            if np.sum(binary) == 0:
                continue

            positions = np.where(binary > 0.5)
            intensities = map_ori[positions]
            mi = np.argmax(intensities)
            y, x = positions[0][mi], positions[1][mi]
            all_peaks.append([x, y])

        return np.array(all_peaks)
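Editor's note: a minimal usage sketch for the Face detector above, not part of the commit. The model path, the input image, and the explicit device assignment are assumptions; __call__ reads self.cn_device, which is never set in __init__ as shown here, so this sketch assumes the caller attaches it before inference.

    import cv2
    import torch

    face_model_path = "models/facenet.pth"        # hypothetical checkpoint path
    detector = Face(face_model_path)
    detector.cn_device = torch.device("cpu")      # assumption: attribute attached by the caller

    face_crop = cv2.imread("face_crop.png")       # hypothetical H x W x 3 uint8 crop
    heatmaps = detector(face_crop)                # (71, H, W): one heatmap per landmark channel
    peaks = detector.compute_peaks_from_heatmaps(heatmaps)  # N x 2 array of [x, y] peaks
    print(peaks.shape)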
custom_nodes/ComfyUI-tbox/src/dwpose/hand.py
ADDED
@@ -0,0 +1,94 @@
import cv2
import json
import numpy as np
import math
import time
from scipy.ndimage.filters import gaussian_filter
import matplotlib.pyplot as plt
import matplotlib
import torch
from skimage.measure import label

from .model import handpose_model
from . import util

class Hand(object):
    def __init__(self, model_path):
        self.model = handpose_model()
        # if torch.cuda.is_available():
        #     self.model = self.model.cuda()
        #     print('cuda')
        model_dict = util.transfer(self.model, torch.load(model_path))
        self.model.load_state_dict(model_dict)
        self.model.eval()

    def __call__(self, oriImgRaw):
        scale_search = [0.5, 1.0, 1.5, 2.0]
        # scale_search = [0.5]
        boxsize = 368
        stride = 8
        padValue = 128
        thre = 0.05
        multiplier = [x * boxsize for x in scale_search]

        wsize = 128
        heatmap_avg = np.zeros((wsize, wsize, 22))

        Hr, Wr, Cr = oriImgRaw.shape

        oriImg = cv2.GaussianBlur(oriImgRaw, (0, 0), 0.8)

        for m in range(len(multiplier)):
            scale = multiplier[m]
            imageToTest = util.smart_resize(oriImg, (scale, scale))

            imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
            im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
            im = np.ascontiguousarray(im)

            data = torch.from_numpy(im).float()
            if torch.cuda.is_available():
                data = data.cuda()

            with torch.no_grad():
                data = data.to(self.cn_device)
                output = self.model(data).cpu().numpy()

            # extract outputs, resize, and remove padding
            heatmap = np.transpose(np.squeeze(output), (1, 2, 0))  # output 1 is heatmaps
            heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride)
            heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
            heatmap = util.smart_resize(heatmap, (wsize, wsize))

            heatmap_avg += heatmap / len(multiplier)

        all_peaks = []
        for part in range(21):
            map_ori = heatmap_avg[:, :, part]
            one_heatmap = gaussian_filter(map_ori, sigma=3)
            binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8)

            if np.sum(binary) == 0:
                all_peaks.append([0, 0])
                continue
            label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim)
            max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1
            label_img[label_img != max_index] = 0
            map_ori[label_img == 0] = 0

            y, x = util.npmax(map_ori)
            y = int(float(y) * float(Hr) / float(wsize))
            x = int(float(x) * float(Wr) / float(wsize))
            all_peaks.append([x, y])
        return np.array(all_peaks)

if __name__ == "__main__":
    hand_estimation = Hand('../model/hand_pose_model.pth')

    # test_image = '../images/hand.jpg'
    test_image = '../images/hand.jpg'
    oriImg = cv2.imread(test_image)  # B,G,R order
    peaks = hand_estimation(oriImg)
    canvas = util.draw_handpose(oriImg, peaks, True)
    cv2.imshow('', canvas)
    cv2.waitKey(0)
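Editor's note: like Face, Hand.__call__ references self.cn_device inside the scale loop, so the bundled __main__ example appears to rely on that attribute being attached first. A hedged sketch under that assumption, using the same paths as the __main__ block:

    import cv2
    import torch

    hand_estimation = Hand('../model/hand_pose_model.pth')   # path as in the __main__ example
    hand_estimation.cn_device = torch.device('cpu')          # assumption: attribute attached by caller

    oriImg = cv2.imread('../images/hand.jpg')                # B,G,R order
    peaks = hand_estimation(oriImg)                          # 21 x 2 array of [x, y]; [0, 0] for missing joints
    print(peaks)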
custom_nodes/ComfyUI-tbox/src/dwpose/model.py
ADDED
@@ -0,0 +1,218 @@
import torch
from collections import OrderedDict

import torch
import torch.nn as nn

def make_layers(block, no_relu_layers):
    layers = []
    for layer_name, v in block.items():
        if 'pool' in layer_name:
            layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1],
                                 padding=v[2])
            layers.append((layer_name, layer))
        else:
            conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1],
                               kernel_size=v[2], stride=v[3],
                               padding=v[4])
            layers.append((layer_name, conv2d))
            if layer_name not in no_relu_layers:
                layers.append(('relu_'+layer_name, nn.ReLU(inplace=True)))

    return nn.Sequential(OrderedDict(layers))

class bodypose_model(nn.Module):
    def __init__(self):
        super(bodypose_model, self).__init__()

        # these layers have no relu layer
        no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',
                          'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',
                          'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',
                          'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L1']
        blocks = {}
        block0 = OrderedDict([
            ('conv1_1', [3, 64, 3, 1, 1]),
            ('conv1_2', [64, 64, 3, 1, 1]),
            ('pool1_stage1', [2, 2, 0]),
            ('conv2_1', [64, 128, 3, 1, 1]),
            ('conv2_2', [128, 128, 3, 1, 1]),
            ('pool2_stage1', [2, 2, 0]),
            ('conv3_1', [128, 256, 3, 1, 1]),
            ('conv3_2', [256, 256, 3, 1, 1]),
            ('conv3_3', [256, 256, 3, 1, 1]),
            ('conv3_4', [256, 256, 3, 1, 1]),
            ('pool3_stage1', [2, 2, 0]),
            ('conv4_1', [256, 512, 3, 1, 1]),
            ('conv4_2', [512, 512, 3, 1, 1]),
            ('conv4_3_CPM', [512, 256, 3, 1, 1]),
            ('conv4_4_CPM', [256, 128, 3, 1, 1])
        ])


        # Stage 1
        block1_1 = OrderedDict([
            ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]),
            ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]),
            ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]),
            ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]),
            ('conv5_5_CPM_L1', [512, 38, 1, 1, 0])
        ])

        block1_2 = OrderedDict([
            ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]),
            ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]),
            ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]),
            ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]),
            ('conv5_5_CPM_L2', [512, 19, 1, 1, 0])
        ])
        blocks['block1_1'] = block1_1
        blocks['block1_2'] = block1_2

        self.model0 = make_layers(block0, no_relu_layers)

        # Stages 2 - 6
        for i in range(2, 7):
            blocks['block%d_1' % i] = OrderedDict([
                ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]),
                ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]),
                ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0])
            ])

            blocks['block%d_2' % i] = OrderedDict([
                ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]),
                ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]),
                ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0])
            ])

        for k in blocks.keys():
            blocks[k] = make_layers(blocks[k], no_relu_layers)

        self.model1_1 = blocks['block1_1']
        self.model2_1 = blocks['block2_1']
        self.model3_1 = blocks['block3_1']
        self.model4_1 = blocks['block4_1']
        self.model5_1 = blocks['block5_1']
        self.model6_1 = blocks['block6_1']

        self.model1_2 = blocks['block1_2']
        self.model2_2 = blocks['block2_2']
        self.model3_2 = blocks['block3_2']
        self.model4_2 = blocks['block4_2']
        self.model5_2 = blocks['block5_2']
        self.model6_2 = blocks['block6_2']


    def forward(self, x):

        out1 = self.model0(x)

        out1_1 = self.model1_1(out1)
        out1_2 = self.model1_2(out1)
        out2 = torch.cat([out1_1, out1_2, out1], 1)

        out2_1 = self.model2_1(out2)
        out2_2 = self.model2_2(out2)
        out3 = torch.cat([out2_1, out2_2, out1], 1)

        out3_1 = self.model3_1(out3)
        out3_2 = self.model3_2(out3)
        out4 = torch.cat([out3_1, out3_2, out1], 1)

        out4_1 = self.model4_1(out4)
        out4_2 = self.model4_2(out4)
        out5 = torch.cat([out4_1, out4_2, out1], 1)

        out5_1 = self.model5_1(out5)
        out5_2 = self.model5_2(out5)
        out6 = torch.cat([out5_1, out5_2, out1], 1)

        out6_1 = self.model6_1(out6)
        out6_2 = self.model6_2(out6)

        return out6_1, out6_2

class handpose_model(nn.Module):
    def __init__(self):
        super(handpose_model, self).__init__()

        # these layers have no relu layer
        no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',
                          'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6']
        # stage 1
        block1_0 = OrderedDict([
            ('conv1_1', [3, 64, 3, 1, 1]),
            ('conv1_2', [64, 64, 3, 1, 1]),
            ('pool1_stage1', [2, 2, 0]),
            ('conv2_1', [64, 128, 3, 1, 1]),
            ('conv2_2', [128, 128, 3, 1, 1]),
            ('pool2_stage1', [2, 2, 0]),
            ('conv3_1', [128, 256, 3, 1, 1]),
            ('conv3_2', [256, 256, 3, 1, 1]),
            ('conv3_3', [256, 256, 3, 1, 1]),
            ('conv3_4', [256, 256, 3, 1, 1]),
            ('pool3_stage1', [2, 2, 0]),
            ('conv4_1', [256, 512, 3, 1, 1]),
            ('conv4_2', [512, 512, 3, 1, 1]),
            ('conv4_3', [512, 512, 3, 1, 1]),
            ('conv4_4', [512, 512, 3, 1, 1]),
            ('conv5_1', [512, 512, 3, 1, 1]),
            ('conv5_2', [512, 512, 3, 1, 1]),
            ('conv5_3_CPM', [512, 128, 3, 1, 1])
        ])

        block1_1 = OrderedDict([
            ('conv6_1_CPM', [128, 512, 1, 1, 0]),
            ('conv6_2_CPM', [512, 22, 1, 1, 0])
        ])

        blocks = {}
        blocks['block1_0'] = block1_0
        blocks['block1_1'] = block1_1

        # stage 2-6
        for i in range(2, 7):
            blocks['block%d' % i] = OrderedDict([
                ('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]),
                ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]),
                ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]),
                ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]),
                ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]),
                ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]),
                ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0])
            ])

        for k in blocks.keys():
            blocks[k] = make_layers(blocks[k], no_relu_layers)

        self.model1_0 = blocks['block1_0']
        self.model1_1 = blocks['block1_1']
        self.model2 = blocks['block2']
        self.model3 = blocks['block3']
        self.model4 = blocks['block4']
        self.model5 = blocks['block5']
        self.model6 = blocks['block6']

    def forward(self, x):
        out1_0 = self.model1_0(x)
        out1_1 = self.model1_1(out1_0)
        concat_stage2 = torch.cat([out1_1, out1_0], 1)
        out_stage2 = self.model2(concat_stage2)
        concat_stage3 = torch.cat([out_stage2, out1_0], 1)
        out_stage3 = self.model3(concat_stage3)
        concat_stage4 = torch.cat([out_stage3, out1_0], 1)
        out_stage4 = self.model4(concat_stage4)
        concat_stage5 = torch.cat([out_stage4, out1_0], 1)
        out_stage5 = self.model5(concat_stage5)
        concat_stage6 = torch.cat([out_stage5, out1_0], 1)
        out_stage6 = self.model6(concat_stage6)
        return out_stage6
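Editor's note: a quick shape-check sketch for the two models defined above, not part of the commit; the 368x368 input size is just an illustrative value. The three pooling layers reduce the spatial resolution by 8x, the body model returns a (PAF, heatmap) pair with 38 and 19 channels, and the hand model returns 22 heatmap channels.

    import torch

    body = bodypose_model()
    hand = handpose_model()

    x = torch.randn(1, 3, 368, 368)        # arbitrary test resolution (assumption)
    with torch.no_grad():
        pafs, heatmaps = body(x)           # PAF branch (38 ch) and keypoint branch (19 ch)
        hand_maps = hand(x)                # 22 hand heatmap channels

    print(pafs.shape, heatmaps.shape, hand_maps.shape)
    # expected: (1, 38, 46, 46) (1, 19, 46, 46) (1, 22, 46, 46)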
custom_nodes/ComfyUI-tbox/src/dwpose/types.py
ADDED
@@ -0,0 +1,30 @@
from typing import NamedTuple, List, Optional

class Keypoint(NamedTuple):
    x: float
    y: float
    score: float = 1.0
    id: int = -1


class BodyResult(NamedTuple):
    # Note: Using `Optional` instead of the `|` operator, as the latter is a
    # Python 3.10 feature.
    # Annotator code should be Python 3.8 compatible, as the controlnet repo
    # uses a Python 3.8 environment.
    # https://github.com/lllyasviel/ControlNet/blob/d3284fcd0972c510635a4f5abe2eeb71dc0de524/environment.yaml#L6
    keypoints: List[Optional[Keypoint]]
    total_score: float = 0.0
    total_parts: int = 0


HandResult = List[Keypoint]
FaceResult = List[Keypoint]
AnimalPoseResult = List[Keypoint]


class PoseResult(NamedTuple):
    body: BodyResult
    left_hand: Optional[HandResult]
    right_hand: Optional[HandResult]
    face: Optional[FaceResult]
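Editor's note: a small sketch of how these result types compose, not part of the commit; all coordinate and score values below are made up for illustration.

    # Build a minimal PoseResult from the NamedTuples above.
    nose = Keypoint(x=0.52, y=0.31, score=0.97, id=0)
    neck = Keypoint(x=0.50, y=0.45, score=0.93, id=1)

    body = BodyResult(keypoints=[nose, neck] + [None] * 16,  # undetected joints stay None
                      total_score=1.90, total_parts=2)

    pose = PoseResult(body=body, left_hand=None, right_hand=None, face=None)
    print(pose.body.total_parts)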