Dreamspire committed
Commit f2dbf59 · 1 Parent(s): 6a1c163

custom_nodes

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.

Files changed (50):
  1. .gitattributes +9 -0
  2. .gitignore +1 -1
  3. custom_nodes/ComfyUI-tbox/.gitignore +3 -0
  4. custom_nodes/ComfyUI-tbox/README.md +5 -0
  5. custom_nodes/ComfyUI-tbox/__init__.py +82 -0
  6. custom_nodes/ComfyUI-tbox/config.yaml +4 -0
  7. custom_nodes/ComfyUI-tbox/nodes/face/__init__.py +8 -0
  8. custom_nodes/ComfyUI-tbox/nodes/face/face_enhance_node.py +81 -0
  9. custom_nodes/ComfyUI-tbox/nodes/image/load_node.py +81 -0
  10. custom_nodes/ComfyUI-tbox/nodes/image/save_node.py +79 -0
  11. custom_nodes/ComfyUI-tbox/nodes/image/size_node.py +121 -0
  12. custom_nodes/ComfyUI-tbox/nodes/image/watermark_node.py +58 -0
  13. custom_nodes/ComfyUI-tbox/nodes/mask/mask_node.py +86 -0
  14. custom_nodes/ComfyUI-tbox/nodes/other/vram_node.py +45 -0
  15. custom_nodes/ComfyUI-tbox/nodes/preprocessor/canny_node.py +22 -0
  16. custom_nodes/ComfyUI-tbox/nodes/preprocessor/densepose_node.py +22 -0
  17. custom_nodes/ComfyUI-tbox/nodes/preprocessor/dwpose_node.py +158 -0
  18. custom_nodes/ComfyUI-tbox/nodes/preprocessor/lineart_node.py +44 -0
  19. custom_nodes/ComfyUI-tbox/nodes/preprocessor/midas_node.py +25 -0
  20. custom_nodes/ComfyUI-tbox/nodes/utils.py +165 -0
  21. custom_nodes/ComfyUI-tbox/nodes/video/batch_node.py +69 -0
  22. custom_nodes/ComfyUI-tbox/nodes/video/ffmpeg.py +129 -0
  23. custom_nodes/ComfyUI-tbox/nodes/video/info_node.py +39 -0
  24. custom_nodes/ComfyUI-tbox/nodes/video/load_node.py +261 -0
  25. custom_nodes/ComfyUI-tbox/nodes/video/save_node.py +415 -0
  26. custom_nodes/ComfyUI-tbox/nodes/video/video_formats/h264-mp4.json +10 -0
  27. custom_nodes/ComfyUI-tbox/nodes/video/video_formats/h265-mp4.json +13 -0
  28. custom_nodes/ComfyUI-tbox/nodes/video/video_formats/nvenc_h264-mp4.json +12 -0
  29. custom_nodes/ComfyUI-tbox/nodes/video/video_formats/nvenc_h265-mp4.json +10 -0
  30. custom_nodes/ComfyUI-tbox/nodes/video/video_formats/webm.json +11 -0
  31. custom_nodes/ComfyUI-tbox/requirements.txt +11 -0
  32. custom_nodes/ComfyUI-tbox/src/canny/__init__.py +17 -0
  33. custom_nodes/ComfyUI-tbox/src/common.py +186 -0
  34. custom_nodes/ComfyUI-tbox/src/densepose/__init__.py +67 -0
  35. custom_nodes/ComfyUI-tbox/src/densepose/densepose.py +347 -0
  36. custom_nodes/ComfyUI-tbox/src/dwpose/LICENSE +108 -0
  37. custom_nodes/ComfyUI-tbox/src/dwpose/__init__.py +328 -0
  38. custom_nodes/ComfyUI-tbox/src/dwpose/animalpose.py +273 -0
  39. custom_nodes/ComfyUI-tbox/src/dwpose/body.py +261 -0
  40. custom_nodes/ComfyUI-tbox/src/dwpose/dw_onnx/__init__.py +1 -0
  41. custom_nodes/ComfyUI-tbox/src/dwpose/dw_onnx/cv_ox_det.py +129 -0
  42. custom_nodes/ComfyUI-tbox/src/dwpose/dw_onnx/cv_ox_pose.py +363 -0
  43. custom_nodes/ComfyUI-tbox/src/dwpose/dw_onnx/cv_ox_yolo_nas.py +60 -0
  44. custom_nodes/ComfyUI-tbox/src/dwpose/dw_torchscript/__init__.py +1 -0
  45. custom_nodes/ComfyUI-tbox/src/dwpose/dw_torchscript/jit_det.py +125 -0
  46. custom_nodes/ComfyUI-tbox/src/dwpose/dw_torchscript/jit_pose.py +363 -0
  47. custom_nodes/ComfyUI-tbox/src/dwpose/face.py +362 -0
  48. custom_nodes/ComfyUI-tbox/src/dwpose/hand.py +94 -0
  49. custom_nodes/ComfyUI-tbox/src/dwpose/model.py +218 -0
  50. custom_nodes/ComfyUI-tbox/src/dwpose/types.py +30 -0
.gitattributes CHANGED
@@ -15,3 +15,12 @@ disaster_girl.png filter=lfs diff=lfs merge=lfs -text
  istambul.jpg filter=lfs diff=lfs merge=lfs -text
  mona.png filter=lfs diff=lfs merge=lfs -text
  natasha.png filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.ttf filter=lfs diff=lfs merge=lfs -text
+ *.whl filter=lfs diff=lfs merge=lfs -text
+ *.task filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.woff2 filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -5,7 +5,7 @@ __pycache__/
  !/input/example.png
  /models/
  /temp/
- /custom_nodes/
+ #/custom_nodes/
  !custom_nodes/example_node.py.example
  extra_model_paths.yaml
  /.vs
custom_nodes/ComfyUI-tbox/.gitignore ADDED
@@ -0,0 +1,3 @@
+ .idea
+ **/__pycache__
+
custom_nodes/ComfyUI-tbox/README.md ADDED
@@ -0,0 +1,5 @@
+ ## Image Node:
+
+ ## Video Node
+
+ ## ControlNet PreProcessor
custom_nodes/ComfyUI-tbox/__init__.py ADDED
@@ -0,0 +1,82 @@
+ import sys
+ from pathlib import Path
+ from .utils import here
+ import platform
+
+ sys.path.insert(0, str(Path(here, "src").resolve()))
+
+ from .nodes.image.load_node import LoadImageNode
+ from .nodes.image.save_node import SaveImageNode
+ from .nodes.image.save_node import SaveImagesNode
+ from .nodes.image.size_node import ImageResizeNode
+ from .nodes.image.size_node import ImageSizeNode
+ from .nodes.image.size_node import ConstrainImageNode
+ from .nodes.image.watermark_node import WatermarkNode
+ from .nodes.mask.mask_node import MaskAddNode
+ from .nodes.video.load_node import LoadVideoNode
+ from .nodes.video.save_node import SaveVideoNode
+ from .nodes.video.info_node import VideoInfoNode
+ from .nodes.video.batch_node import BatchManagerNode
+ from .nodes.preprocessor.canny_node import Canny_Preprocessor
+ from .nodes.preprocessor.lineart_node import LineArt_Preprocessor
+ from .nodes.preprocessor.lineart_node import Lineart_Standard_Preprocessor
+ from .nodes.preprocessor.midas_node import MIDAS_Depth_Map_Preprocessor
+ from .nodes.preprocessor.dwpose_node import DWPose_Preprocessor, AnimalPose_Preprocessor
+ from .nodes.preprocessor.densepose_node import DensePose_Preprocessor
+ from .nodes.face.face_enhance_node import GFPGANNode
+ from .nodes.other.vram_node import PurgeVRAMNode
+
+ NODE_CLASS_MAPPINGS = {
+     "PurgeVRAMNode": PurgeVRAMNode,
+     "GFPGANNode": GFPGANNode,
+     "MaskAddNode": MaskAddNode,
+     "ImageLoader": LoadImageNode,
+     "ImageSaver": SaveImageNode,
+     "ImagesSaver": SaveImagesNode,
+     "ImageResize": ImageResizeNode,
+     "ImageSize": ImageSizeNode,
+     "WatermarkNode": WatermarkNode,
+     "VideoLoader": LoadVideoNode,
+     "VideoSaver": SaveVideoNode,
+     "VideoInfo": VideoInfoNode,
+     "BatchManager": BatchManagerNode,
+     "ConstrainImageNode": ConstrainImageNode,
+     "DensePosePreprocessor": DensePose_Preprocessor,
+     "DWPosePreprocessor": DWPose_Preprocessor,
+     "AnimalPosePreprocessor": AnimalPose_Preprocessor,
+     "MiDaSDepthPreprocessor": MIDAS_Depth_Map_Preprocessor,
+     "CannyPreprocessor": Canny_Preprocessor,
+     "LineArtPreprocessor": LineArt_Preprocessor,
+     "LineartStandardPreprocessor": Lineart_Standard_Preprocessor,
+ }
+
+ NODE_DISPLAY_NAME_MAPPINGS = {
+     "PurgeVRAMNode": "PurgeVRAMNode",
+     "GFPGANNode": "GFPGANNode",
+     "MaskAddNode": "MaskAddNode",
+     "ImageLoader": "Image Load",
+     "ImageSaver": "Image Save",
+     "ImagesSaver": "Image List Save",
+     "ImageResize": "Image Resize",
+     "ImageSize": "Image Size",
+     "WatermarkNode": "Watermark",
+     "VideoLoader": "Video Load",
+     "VideoSaver": "Video Save",
+     "VideoInfo": "Video Info",
+     "BatchManager": "Batch Manager",
+     "ConstrainImageNode": "Image Constrain",
+     "DensePosePreprocessor": "DensePose Estimator",
+     "DWPosePreprocessor": "DWPose Estimator",
+     "AnimalPosePreprocessor": "AnimalPose Estimator",
+     "MiDaSDepthPreprocessor": "MiDaS Depth Estimator",
+     "CannyPreprocessor": "Canny Edge Estimator",
+     "LineArtPreprocessor": "Realistic Lineart",
+     "LineartStandardPreprocessor": "Standard Lineart",
+ }
+
+
+ if platform.system() == "Darwin":
+     WEB_DIRECTORY = "./web"
+     __all__ = ["WEB_DIRECTORY", "NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
+ else:
+     __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
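Note: every class registered in NODE_CLASS_MAPPINGS above follows ComfyUI's standard custom-node contract; ComfyUI imports the package and reads NODE_CLASS_MAPPINGS / NODE_DISPLAY_NAME_MAPPINGS (and WEB_DIRECTORY, when present). A minimal sketch of that contract, for orientation only; the node name and category below are illustrative and not part of this commit:

    class ExampleNode:
        @classmethod
        def INPUT_TYPES(cls):
            # keys declared here become the keyword arguments of the FUNCTION method
            return {"required": {"image": ("IMAGE",)}}

        RETURN_TYPES = ("IMAGE",)
        FUNCTION = "process"        # name of the method ComfyUI invokes
        CATEGORY = "tbox/Example"   # menu placement in the UI

        def process(self, image):
            return (image,)

    NODE_CLASS_MAPPINGS = {"ExampleNode": ExampleNode}
    NODE_DISPLAY_NAME_MAPPINGS = {"ExampleNode": "Example"}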
custom_nodes/ComfyUI-tbox/config.yaml ADDED
@@ -0,0 +1,4 @@
+ annotator_ckpts_path: "../../models/annotator"
+ custom_temp_path: "../../temp"
+ USE_SYMLINKS: False
+ EP_list: ["CUDAExecutionProvider", "DirectMLExecutionProvider", "OpenVINOExecutionProvider", "ROCMExecutionProvider", "CPUExecutionProvider"]
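EP_list above is an ordered preference list of ONNX Runtime execution providers. A minimal sketch (assuming onnxruntime and PyYAML are installed; the config path and variable names below are illustrative, not part of this commit) of how such a list is typically narrowed to the providers the local build actually supports:

    import onnxruntime as ort
    import yaml

    with open("custom_nodes/ComfyUI-tbox/config.yaml") as f:  # path is an assumption
        cfg = yaml.safe_load(f)

    available = set(ort.get_available_providers())
    # keep the configured preference order, dropping providers this onnxruntime build lacks
    providers = [ep for ep in cfg["EP_list"] if ep in available] or ["CPUExecutionProvider"]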
custom_nodes/ComfyUI-tbox/nodes/face/__init__.py ADDED
@@ -0,0 +1,8 @@
+
+ import os
+ import folder_paths
+
+
+ model_path = os.path.join(folder_paths.models_dir, "facefusion")
+ folder_paths.add_model_folder_path('facefusion', model_path)
+
custom_nodes/ComfyUI-tbox/nodes/face/face_enhance_node.py ADDED
@@ -0,0 +1,81 @@
+ import os
+ import torch
+ import cv2
+ import numpy as np
+ from PIL import Image
+ from facefusion.gfpgan_onnx import GFPGANOnnx
+ from facefusion.yoloface_onnx import YoloFaceOnnx
+ from facefusion.affine import create_box_mask, warp_face_by_landmark, paste_back
+
+ import folder_paths
+ from ..utils import tensor2pil, pil2tensor
+
+ # class GFPGANProvider:
+ #     @classmethod
+ #     def INPUT_TYPES(s):
+ #         return {
+ #             "required": {
+ #                 "model_name": ("IMAGE", ["gfpgan_1.4.onnx"]),
+ #             },
+ #         }
+
+ #     RETURN_TYPES = ("GFPGAN_MODEL",)
+ #     RETURN_NAMES = ("model",)
+ #     FUNCTION = "load_model"
+ #     CATEGORY = "tbox/facefusion"
+
+ #     def load_model(self, model_name):
+ #         return (model_name,)
+
+
+ class GFPGANNode:
+     @classmethod
+     def INPUT_TYPES(cls):
+         return {
+             "required": {
+                 "images": ("IMAGE",),
+                 "model_name": (['gfpgan_1.3', 'gfpgan_1.4'], {"default": 'gfpgan_1.4'}),
+                 "device": (['CPU', 'CUDA', 'CoreML', 'ROCM'], {"default": 'CPU'}),
+                 "weight": ("FLOAT", {"default": 0.8}),
+             }
+         }
+
+     RETURN_TYPES = ("IMAGE", )
+     FUNCTION = "process"
+     CATEGORY = "tbox/FaceFusion"
+
+     def process(self, images, model_name, device='CPU', weight=0.8):
+         providers = ['CPUExecutionProvider']
+         if device == 'CUDA':
+             providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
+         elif device == 'CoreML':
+             providers = ['CoreMLExecutionProvider', 'CPUExecutionProvider']
+         elif device == 'ROCM':
+             providers = ['ROCMExecutionProvider', 'CPUExecutionProvider']
+
+         gfpgan_path = folder_paths.get_full_path("facefusion", f'{model_name}.onnx')
+         yolo_path = folder_paths.get_full_path("facefusion", 'yoloface_8n.onnx')
+
+         detector = YoloFaceOnnx(model_path=yolo_path, providers=providers)
+         enhancer = GFPGANOnnx(model_path=gfpgan_path, providers=providers)
+
+         image_list = []
+         for i, img in enumerate(images):
+             pil = tensor2pil(img)
+             image = np.ascontiguousarray(pil)
+             image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+             output = image
+             face_list = detector.detect(image=image, conf=0.7)
+             for index, face in enumerate(face_list):
+                 cropped, affine_matrix = warp_face_by_landmark(image, face.landmarks, enhancer.input_size)
+                 box_mask = create_box_mask(enhancer.input_size, 0.3, (0, 0, 0, 0))
+                 crop_mask = np.minimum.reduce([box_mask]).clip(0, 1)
+                 result = enhancer.run(cropped)
+                 output = paste_back(output, result, crop_mask, affine_matrix)
+             image = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
+             pil = Image.fromarray(image)
+             image_list.append(pil2tensor(pil))
+         image_list = torch.stack([tensor.squeeze() for tensor in image_list])
+         return (image_list,)
+
+
custom_nodes/ComfyUI-tbox/nodes/image/load_node.py ADDED
@@ -0,0 +1,81 @@
+
+ import os
+ import cv2
+ import torch
+ import requests
+ import itertools
+ import folder_paths
+ import psutil
+ import numpy as np
+ from comfy.utils import common_upscale
+ from io import BytesIO
+ from PIL import Image, ImageSequence, ImageOps
+
+
+
+ def pil2tensor(img):
+     output_images = []
+     output_masks = []
+     for i in ImageSequence.Iterator(img):
+         i = ImageOps.exif_transpose(i)
+         if i.mode == 'I':
+             i = i.point(lambda i: i * (1 / 255))
+         image = i.convert("RGB")
+         image = np.array(image).astype(np.float32) / 255.0
+         image = torch.from_numpy(image)[None,]
+         if 'A' in i.getbands():
+             mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
+             mask = 1. - torch.from_numpy(mask)
+         else:
+             mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
+         output_images.append(image)
+         output_masks.append(mask.unsqueeze(0))
+
+     if len(output_images) > 1:
+         output_image = torch.cat(output_images, dim=0)
+         output_mask = torch.cat(output_masks, dim=0)
+     else:
+         output_image = output_images[0]
+         output_mask = output_masks[0]
+
+     return (output_image, output_mask)
+
+
+ def load_image(image_source):
+     if image_source.startswith('http'):
+         print(image_source)
+         response = requests.get(image_source)
+         img = Image.open(BytesIO(response.content))
+         file_name = image_source.split('/')[-1]
+     else:
+         img = Image.open(image_source)
+         file_name = os.path.basename(image_source)
+     return img, file_name
+
+
+ class LoadImageNode:
+     @classmethod
+     def INPUT_TYPES(cls):
+         return {
+             "required": {
+                 "path": ("STRING", {"multiline": True, "dynamicPrompts": False})
+             }
+         }
+
+     RETURN_TYPES = ("IMAGE", "MASK")
+     FUNCTION = "load_image"
+     CATEGORY = "tbox/Image"
+
+     def load_image(self, path):
+         filepath = path.split('\n')[0]
+         img, name = load_image(filepath)
+         img_out, mask_out = pil2tensor(img)
+         return (img_out, mask_out)
+
+
+
+ if __name__ == "__main__":
+     img, name = load_image("https://creativestorage.blob.core.chinacloudapi.cn/test/bird.png")
+     img_out, mask_out = pil2tensor(img)
+
custom_nodes/ComfyUI-tbox/nodes/image/save_node.py ADDED
@@ -0,0 +1,79 @@
+ import os
+ import time
+ import numpy as np
+ from PIL import Image, ImageSequence, ImageOps
+
+ #from load_node import load_image, pil2tensor
+
+ def save_image(img, filepath, format, quality):
+     try:
+         if format in ["jpg", "jpeg"]:
+             img.convert("RGB").save(filepath, format="JPEG", quality=quality, subsampling=0)
+         elif format == "webp":
+             img.save(filepath, format="WEBP", quality=quality, method=6)
+         elif format == "bmp":
+             img.save(filepath, format="BMP")
+         else:
+             img.save(filepath, format="PNG", optimize=True)
+     except Exception as e:
+         print(f"Error saving {filepath}: {str(e)}")
+
+ class SaveImageNode:
+     @classmethod
+     def INPUT_TYPES(cls):
+         return {
+             "required": {
+                 "images": ("IMAGE",),
+                 "path": ("STRING", {"multiline": True, "dynamicPrompts": False}),
+                 "quality": ([100, 95, 90, 85, 80, 75, 70, 60, 50], {"default": 100}),
+             }
+         }
+     RETURN_TYPES = ()
+     FUNCTION = "save_image"
+     CATEGORY = "tbox/Image"
+     OUTPUT_NODE = True
+
+     def save_image(self, images, path, quality):
+         filepath = path.split('\n')[0]
+         format = os.path.splitext(filepath)[1][1:]
+         image = images[0]
+         img = Image.fromarray((255. * image.cpu().numpy()).astype(np.uint8))
+         save_image(img, filepath, format, quality)
+         return {}
+
+ class SaveImagesNode:
+     @classmethod
+     def INPUT_TYPES(cls):
+         return {
+             "required": {
+                 "images": ("IMAGE",),
+                 "path": ("STRING", {"multiline": False, "dynamicPrompts": False}),
+                 "prefix": ("STRING", {"default": "image"}),
+                 "format": (["PNG", "JPG", "WEBP", "BMP"],),
+                 "quality": ([100, 95, 90, 85, 80, 75, 70, 60, 50], {"default": 100}),
+             }
+         }
+     RETURN_TYPES = ()
+     FUNCTION = "save_image"
+     CATEGORY = "tbox/Image"
+     OUTPUT_NODE = True
+
+     def save_image(self, images, path, prefix, format, quality):
+         format = format.lower()
+         for i, image in enumerate(images):
+             img = Image.fromarray((255. * image.cpu().numpy()).astype(np.uint8))
+             filepath = self.generate_filename(path, prefix, i, format)
+             save_image(img, filepath, format, quality)
+         return {}
+
+     def IS_CHANGED(s, images):
+         return time.time()
+
+     def generate_filename(self, save_dir, prefix, index, format):
+         base_filename = f"{prefix}_{index+1}.{format}"
+         filename = os.path.join(save_dir, base_filename)
+         counter = 1
+         while os.path.exists(filename):
+             filename = os.path.join(save_dir, f"{prefix}_{index+1}_{counter}.{format}")
+             counter += 1
+         return filename
custom_nodes/ComfyUI-tbox/nodes/image/size_node.py ADDED
@@ -0,0 +1,121 @@
+
+
+ import torch
+ import comfy.utils
+ import numpy as np
+ from PIL import Image, ImageSequence, ImageOps
+
+ class ConstrainImageNode:
+     """
+     A node that constrains an image to a maximum and minimum size while maintaining aspect ratio.
+     """
+
+     @classmethod
+     def INPUT_TYPES(cls):
+         return {
+             "required": {
+                 "images": ("IMAGE",),
+                 "max_width": ("INT", {"default": 1024, "min": 0}),
+                 "max_height": ("INT", {"default": 1024, "min": 0}),
+                 "min_width": ("INT", {"default": 0, "min": 0}),
+                 "min_height": ("INT", {"default": 0, "min": 0}),
+                 "crop_if_required": (["yes", "no"], {"default": "no"}),
+             },
+         }
+
+     RETURN_TYPES = ("IMAGE",)
+     FUNCTION = "constrain_image"
+     CATEGORY = "tbox/Image"
+     OUTPUT_IS_LIST = (True,)
+
+     def constrain_image(self, images, max_width, max_height, min_width, min_height, crop_if_required):
+         crop_if_required = crop_if_required == "yes"
+         results = []
+         for image in images:
+             i = 255. * image.cpu().numpy()
+             img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)).convert("RGB")
+
+             current_width, current_height = img.size
+             aspect_ratio = current_width / current_height
+
+             constrained_width = max(min(current_width, min_width), max_width)
+             constrained_height = max(min(current_height, min_height), max_height)
+
+             if constrained_width / constrained_height > aspect_ratio:
+                 constrained_width = max(int(constrained_height * aspect_ratio), min_width)
+                 if crop_if_required:
+                     constrained_height = int(current_height / (current_width / constrained_width))
+             else:
+                 constrained_height = max(int(constrained_width / aspect_ratio), min_height)
+                 if crop_if_required:
+                     constrained_width = int(current_width / (current_height / constrained_height))
+
+             resized_image = img.resize((constrained_width, constrained_height), Image.LANCZOS)
+
+             if crop_if_required and (constrained_width > max_width or constrained_height > max_height):
+                 left = max((constrained_width - max_width) // 2, 0)
+                 top = max((constrained_height - max_height) // 2, 0)
+                 right = min(constrained_width, max_width) + left
+                 bottom = min(constrained_height, max_height) + top
+                 resized_image = resized_image.crop((left, top, right, bottom))
+
+             resized_image = np.array(resized_image).astype(np.float32) / 255.0
+             resized_image = torch.from_numpy(resized_image)[None,]
+             results.append(resized_image)
+
+         return (results,)
+
+ # https://github.com/bronkula/comfyui-fitsize
+ class ImageSizeNode:
+     @classmethod
+     def INPUT_TYPES(cls):
+         return {
+             "required": {
+                 "image": ("IMAGE", ),
+             }
+         }
+
+     RETURN_TYPES = ("INT", "INT", "INT")
+     RETURN_NAMES = ("width", "height", "count")
+     FUNCTION = "get_size"
+     CATEGORY = "tbox/Image"
+     def get_size(self, image):
+         print(f'shape of image:{image.shape}')
+         return (image.shape[2], image.shape[1], image.shape[0])
+
+
+
+ class ImageResizeNode:
+     @classmethod
+     def INPUT_TYPES(cls):
+         return {
+             "required": {
+                 "image": ("IMAGE", ),
+                 "method": (["nearest", "bilinear", "bicubic", "area", "nearest-exact", "lanczos"],),
+             },
+             "optional": {
+                 "width": ("INT,FLOAT", { "default": 0.0, "step": 0.1 }),
+                 "height": ("INT,FLOAT", { "default": 0.0, "step": 0.1 }),
+             },
+         }
+
+     RETURN_TYPES = ("IMAGE",)
+
+     FUNCTION = "resize"
+     CATEGORY = "tbox/Image"
+
+     def resize(self, image, method, width, height):
+         print(f'shape of image:{image.shape}, resolution:{width}x{height} type: {type(width)}, {type(height)}')
+         if width == 0 and height == 0:
+             s = image
+         else:
+             samples = image.movedim(-1, 1)
+             if width == 0:
+                 width = max(1, round(samples.shape[3] * height / samples.shape[2]))
+             elif height == 0:
+                 height = max(1, round(samples.shape[2] * width / samples.shape[3]))
+
+             s = comfy.utils.common_upscale(samples, width, height, method, True)
+             s = s.movedim(1, -1)
+         return (s,)
+
custom_nodes/ComfyUI-tbox/nodes/image/watermark_node.py ADDED
@@ -0,0 +1,58 @@
+
+ import torch
+ import numpy as np
+ from PIL import Image, ImageSequence, ImageOps
+ from ..utils import tensor2pil, pil2tensor
+
+ PADDING = 4
+
+ class WatermarkNode:
+
+     @classmethod
+     def INPUT_TYPES(cls):
+         return {
+             "required": {
+                 "images": ("IMAGE",),
+                 "logo_list": ("IMAGE",),
+             },
+             "optional": {
+                 "logo_mask": ("MASK",),
+                 "enabled": ("BOOLEAN", {"default": True}),}
+         }
+     RETURN_TYPES = ("IMAGE",)
+     FUNCTION = "watermark"
+     CATEGORY = "tbox/Image"
+
+     def watermark(self, images, logo_list, logo_mask, enabled):
+         outputs = []
+         if enabled == False:
+             return (images,)
+         print(f'logo shape: {logo_list.shape}')
+         print(f'images shape: {images.shape}')
+         logo = tensor2pil(logo_list[0])
+         if logo_mask is not None:
+             logo_mask = tensor2pil(logo_mask)
+         for i, image in enumerate(images):
+             img = tensor2pil(image)  # Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
+             dst = self.add_watermark2(img, logo, logo_mask, 85)
+             result = pil2tensor(dst)
+             outputs.append(result)
+         base_image = torch.stack([tensor.squeeze() for tensor in outputs])
+         return (base_image,)
+
+     def add_watermark2(self, image, logo, logo_mask, opacity=None):
+         logo_width, logo_height = logo.size
+         image_width, image_height = image.size
+         if image_height <= logo_height + PADDING * 2 or image_width <= logo_width + PADDING * 2:
+             return image
+         y = image_height - logo_height - PADDING * 1
+         x = PADDING
+         logo = logo.convert('RGBA')
+         opacity = int(opacity / 100 * 255)
+         logo.putalpha(Image.new("L", logo.size, opacity))
+         if logo_mask is not None:
+             logo.putalpha(ImageOps.invert(logo_mask))
+
+         position = (x, y)
+         image.paste(logo, position, logo)
+         return image
custom_nodes/ComfyUI-tbox/nodes/mask/mask_node.py ADDED
@@ -0,0 +1,86 @@
+
+ import torch
+ import numpy as np
+
+ class MaskSubNode:
+     @classmethod
+     def INPUT_TYPES(cls):
+         return {
+             "required": {
+                 "mask": ("MASK",),
+             },
+             "optional": {
+                 "src1": ("MASK",),
+                 "src2": ("MASK",),
+                 "src3": ("MASK",),
+                 "src4": ("MASK",),
+                 "src5": ("MASK",),
+                 "src6": ("MASK",),
+             }
+         }
+
+     RETURN_TYPES = ("MASK",)
+
+     FUNCTION = "sub"
+     CATEGORY = "tbox/Mask"
+
+     def sub_mask(self, dst, src):
+         if src is not None:
+             mask = src.reshape((-1, src.shape[-2], src.shape[-1]))
+             return dst - mask
+         return dst
+
+     # method name must match FUNCTION above so ComfyUI can invoke it
+     def sub(self, mask, src1=None, src2=None, src3=None, src4=None, src5=None, src6=None):
+         print(f'mask shape: {mask.shape}')
+         output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()
+         output[:, :, :] = self.sub_mask(output, src1)
+         output[:, :, :] = self.sub_mask(output, src2)
+         output[:, :, :] = self.sub_mask(output, src3)
+         output[:, :, :] = self.sub_mask(output, src4)
+         output[:, :, :] = self.sub_mask(output, src5)
+         output[:, :, :] = self.sub_mask(output, src6)
+         output = torch.clamp(output, 0.0, 1.0)
+         return (output, )
+
+ class MaskAddNode:
+     @classmethod
+     def INPUT_TYPES(cls):
+         return {
+             "required": {
+                 "mask": ("MASK",),
+             },
+             "optional": {
+                 "src1": ("MASK",),
+                 "src2": ("MASK",),
+                 "src3": ("MASK",),
+                 "src4": ("MASK",),
+                 "src5": ("MASK",),
+                 "src6": ("MASK",),
+             }
+         }
+
+     RETURN_TYPES = ("MASK",)
+
+     FUNCTION = "add"
+     CATEGORY = "tbox/Mask"
+
+     def add_mask(self, dst, src):
+         if src is not None:
+             mask = src.reshape((-1, src.shape[-2], src.shape[-1]))
+             return dst + mask
+         return dst
+
+     def add(self, mask, src1=None, src2=None, src3=None, src4=None, src5=None, src6=None):
+         print(f'mask shape: {mask.shape}')
+         output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()
+         output[:, :, :] = self.add_mask(output, src1)
+         output[:, :, :] = self.add_mask(output, src2)
+         output[:, :, :] = self.add_mask(output, src3)
+         output[:, :, :] = self.add_mask(output, src4)
+         output[:, :, :] = self.add_mask(output, src5)
+         output[:, :, :] = self.add_mask(output, src6)
+         output = torch.clamp(output, 0.0, 1.0)
+         return (output, )
+
custom_nodes/ComfyUI-tbox/nodes/other/vram_node.py ADDED
@@ -0,0 +1,45 @@
+ import gc
+ import torch.cuda
+ import comfy.model_management
+
+ class AnyType(str):
+     """A special class that is always equal in not equal comparisons. Credit to pythongosssss"""
+     def __eq__(self, __value: object) -> bool:
+         return True
+     def __ne__(self, __value: object) -> bool:
+         return False
+
+ any = AnyType("*")
+
+ class PurgeVRAMNode:
+
+     def __init__(self):
+         pass
+
+     @classmethod
+     def INPUT_TYPES(cls):
+         return {
+             "required": {
+                 "anything": (any, {}),
+                 "purge_cache": ("BOOLEAN", {"default": True}),
+                 "purge_models": ("BOOLEAN", {"default": True}),
+             },
+             "optional": {
+             }
+         }
+
+     RETURN_TYPES = ()
+     FUNCTION = "purge_vram"
+     CATEGORY = "tbox/other"
+     OUTPUT_NODE = True
+
+     def purge_vram(self, anything, purge_cache, purge_models):
+
+         gc.collect()
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+             torch.cuda.ipc_collect()
+         if purge_models:
+             comfy.model_management.unload_all_models()
+             comfy.model_management.soft_empty_cache()
+         return (None,)
custom_nodes/ComfyUI-tbox/nodes/preprocessor/canny_node.py ADDED
@@ -0,0 +1,22 @@
+ from ..utils import common_annotator_call, create_node_input_types
+ import comfy.model_management as model_management
+ import nodes
+
+ class Canny_Preprocessor:
+     @classmethod
+     def INPUT_TYPES(s):
+         return create_node_input_types(
+             low_threshold=("INT", {"default": 100, "min": 0, "max": 255}),
+             high_threshold=("INT", {"default": 100, "min": 0, "max": 255}),
+             resolution=("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 64})
+         )
+
+     RETURN_TYPES = ("IMAGE",)
+     FUNCTION = "execute"
+
+     CATEGORY = "tbox/ControlNet Preprocessors"
+
+     def execute(self, image, low_threshold=100, high_threshold=200, resolution=512, **kwargs):
+         from canny import CannyDetector
+
+         return (common_annotator_call(CannyDetector(), image, low_threshold=low_threshold, high_threshold=high_threshold, resolution=resolution), )
custom_nodes/ComfyUI-tbox/nodes/preprocessor/densepose_node.py ADDED
@@ -0,0 +1,22 @@
+ from ..utils import common_annotator_call, create_node_input_types
+ import comfy.model_management as model_management
+
+ class DensePose_Preprocessor:
+     @classmethod
+     def INPUT_TYPES(s):
+         return create_node_input_types(
+             model=(["densepose_r50_fpn_dl.torchscript", "densepose_r101_fpn_dl.torchscript"], {"default": "densepose_r50_fpn_dl.torchscript"}),
+             cmap=(["Viridis (MagicAnimate)", "Parula (CivitAI)"], {"default": "Viridis (MagicAnimate)"})
+         )
+
+     RETURN_TYPES = ("IMAGE",)
+     FUNCTION = "execute"
+
+     CATEGORY = "tbox/ControlNet Preprocessors"
+
+     def execute(self, image, model, cmap, resolution=512):
+         from densepose import DenseposeDetector
+         model = DenseposeDetector \
+             .from_pretrained(filename=model) \
+             .to(model_management.get_torch_device())
+         return (common_annotator_call(model, image, cmap="viridis" if "Viridis" in cmap else "parula", resolution=resolution), )
custom_nodes/ComfyUI-tbox/nodes/preprocessor/dwpose_node.py ADDED
@@ -0,0 +1,158 @@
+ from ..utils import common_annotator_call, create_node_input_types
+ import comfy.model_management as model_management
+ import numpy as np
+ import warnings
+ from dwpose import DwposeDetector, AnimalposeDetector
+ import os
+ import json
+
+
+ DWPOSE_MODEL_NAME = "yzd-v/DWPose"
+ #Trigger startup caching for onnxruntime
+ GPU_PROVIDERS = ["CUDAExecutionProvider", "DirectMLExecutionProvider", "OpenVINOExecutionProvider", "ROCMExecutionProvider", "CoreMLExecutionProvider"]
+
+ def check_ort_gpu():
+     try:
+         import onnxruntime as ort
+         for provider in GPU_PROVIDERS:
+             if provider in ort.get_available_providers():
+                 return True
+         return False
+     except:
+         return False
+
+ if not os.environ.get("DWPOSE_ONNXRT_CHECKED"):
+     if check_ort_gpu():
+         print("DWPose: Onnxruntime with acceleration providers detected")
+     else:
+         warnings.warn("DWPose: Onnxruntime not found or doesn't come with acceleration providers, switch to OpenCV with CPU device. DWPose might run very slowly")
+         os.environ['AUX_ORT_PROVIDERS'] = ''
+     os.environ["DWPOSE_ONNXRT_CHECKED"] = '1'
+
+ class DWPose_Preprocessor:
+     @classmethod
+     def INPUT_TYPES(s):
+         input_types = create_node_input_types(
+             detect_hand=(["enable", "disable"], {"default": "enable"}),
+             detect_body=(["enable", "disable"], {"default": "enable"}),
+             detect_face=(["enable", "disable"], {"default": "enable"})
+         )
+         input_types["optional"] = {
+             **input_types["optional"],
+             "bbox_detector": (
+                 ["yolox_l.torchscript.pt", "yolox_l.onnx", "yolo_nas_l_fp16.onnx", "yolo_nas_m_fp16.onnx", "yolo_nas_s_fp16.onnx"],
+                 {"default": "yolox_l.onnx"}
+             ),
+             "pose_estimator": (["dw-ll_ucoco_384_bs5.torchscript.pt", "dw-ll_ucoco_384.onnx", "dw-ll_ucoco.onnx"], {"default": "dw-ll_ucoco_384_bs5.torchscript.pt"})
+         }
+         return input_types
+
+     RETURN_TYPES = ("IMAGE", "POSE_KEYPOINT")
+     FUNCTION = "estimate_pose"
+
+     CATEGORY = "tbox/ControlNet Preprocessors"
+
+     def estimate_pose(self, image, detect_hand, detect_body, detect_face, resolution=512, bbox_detector="yolox_l.onnx", pose_estimator="dw-ll_ucoco_384.onnx", **kwargs):
+         if bbox_detector == "yolox_l.onnx":
+             yolo_repo = DWPOSE_MODEL_NAME
+         elif "yolox" in bbox_detector:
+             yolo_repo = "hr16/yolox-onnx"
+         elif "yolo_nas" in bbox_detector:
+             yolo_repo = "hr16/yolo-nas-fp16"
+         else:
+             raise NotImplementedError(f"Download mechanism for {bbox_detector}")
+
+         if pose_estimator == "dw-ll_ucoco_384.onnx":
+             pose_repo = DWPOSE_MODEL_NAME
+         elif pose_estimator.endswith(".onnx"):
+             pose_repo = "hr16/UnJIT-DWPose"
+         elif pose_estimator.endswith(".torchscript.pt"):
+             pose_repo = "hr16/DWPose-TorchScript-BatchSize5"
+         else:
+             raise NotImplementedError(f"Download mechanism for {pose_estimator}")
+
+         model = DwposeDetector.from_pretrained(
+             pose_repo,
+             yolo_repo,
+             det_filename=bbox_detector, pose_filename=pose_estimator,
+             torchscript_device=model_management.get_torch_device()
+         )
+         detect_hand = detect_hand == "enable"
+         detect_body = detect_body == "enable"
+         detect_face = detect_face == "enable"
+         self.openpose_dicts = []
+         def func(image, **kwargs):
+             pose_img, openpose_dict = model(image, **kwargs)
+             self.openpose_dicts.append(openpose_dict)
+             return pose_img
+
+         out = common_annotator_call(func, image, include_hand=detect_hand, include_face=detect_face, include_body=detect_body, image_and_json=True, resolution=resolution)
+         del model
+         return {
+             'ui': { "openpose_json": [json.dumps(self.openpose_dicts, indent=4)] },
+             "result": (out, self.openpose_dicts)
+         }
+
+ class AnimalPose_Preprocessor:
+     @classmethod
+     def INPUT_TYPES(s):
+         return create_node_input_types(
+             bbox_detector = (
+                 ["yolox_l.torchscript.pt", "yolox_l.onnx", "yolo_nas_l_fp16.onnx", "yolo_nas_m_fp16.onnx", "yolo_nas_s_fp16.onnx"],
+                 {"default": "yolox_l.torchscript.pt"}
+             ),
+             pose_estimator = (["rtmpose-m_ap10k_256_bs5.torchscript.pt", "rtmpose-m_ap10k_256.onnx"], {"default": "rtmpose-m_ap10k_256_bs5.torchscript.pt"})
+         )
+
+     RETURN_TYPES = ("IMAGE", "POSE_KEYPOINT")
+     FUNCTION = "estimate_pose"
+
+     CATEGORY = "tbox/ControlNet Preprocessors"
+
+     def estimate_pose(self, image, resolution=512, bbox_detector="yolox_l.onnx", pose_estimator="rtmpose-m_ap10k_256.onnx", **kwargs):
+         if bbox_detector == "yolox_l.onnx":
+             yolo_repo = DWPOSE_MODEL_NAME
+         elif "yolox" in bbox_detector:
+             yolo_repo = "hr16/yolox-onnx"
+         elif "yolo_nas" in bbox_detector:
+             yolo_repo = "hr16/yolo-nas-fp16"
+         else:
+             raise NotImplementedError(f"Download mechanism for {bbox_detector}")
+
+         if pose_estimator == "dw-ll_ucoco_384.onnx":
+             pose_repo = DWPOSE_MODEL_NAME
+         elif pose_estimator.endswith(".onnx"):
+             pose_repo = "hr16/UnJIT-DWPose"
+         elif pose_estimator.endswith(".torchscript.pt"):
+             pose_repo = "hr16/DWPose-TorchScript-BatchSize5"
+         else:
+             raise NotImplementedError(f"Download mechanism for {pose_estimator}")
+
+         model = AnimalposeDetector.from_pretrained(
+             pose_repo,
+             yolo_repo,
+             det_filename=bbox_detector, pose_filename=pose_estimator,
+             torchscript_device=model_management.get_torch_device()
+         )
+
+         self.openpose_dicts = []
+         def func(image, **kwargs):
+             pose_img, openpose_dict = model(image, **kwargs)
+             self.openpose_dicts.append(openpose_dict)
+             return pose_img
+
+         out = common_annotator_call(func, image, image_and_json=True, resolution=resolution)
+         del model
+         return {
+             'ui': { "openpose_json": [json.dumps(self.openpose_dicts, indent=4)] },
+             "result": (out, self.openpose_dicts)
+         }
+
+ NODE_CLASS_MAPPINGS = {
+     "DWPreprocessor": DWPose_Preprocessor,
+     "AnimalPosePreprocessor": AnimalPose_Preprocessor
+ }
+ NODE_DISPLAY_NAME_MAPPINGS = {
+     "DWPreprocessor": "DWPose Estimator",
+     "AnimalPosePreprocessor": "AnimalPose Estimator (AP10K)"
+ }
custom_nodes/ComfyUI-tbox/nodes/preprocessor/lineart_node.py ADDED
@@ -0,0 +1,44 @@
+ from ..utils import common_annotator_call, create_node_input_types
+ import comfy.model_management as model_management
+ import nodes
+
+ class LineArt_Preprocessor:
+     @classmethod
+     def INPUT_TYPES(s):
+         return create_node_input_types(
+             coarse=(["disable", "enable"], {"default": "enable"}),
+             resolution=("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 64})
+         )
+
+     RETURN_TYPES = ("IMAGE",)
+     FUNCTION = "execute"
+
+     CATEGORY = "tbox/ControlNet Preprocessors"
+
+
+     def execute(self, image, resolution=512, **kwargs):
+         from lineart import LineartDetector
+
+         model = LineartDetector.from_pretrained().to(model_management.get_torch_device())
+         out = common_annotator_call(model, image, resolution=resolution, coarse = kwargs["coarse"] == "enable")
+         del model
+         return (out, )
+
+ class Lineart_Standard_Preprocessor:
+     @classmethod
+     def INPUT_TYPES(s):
+         return create_node_input_types(
+             guassian_sigma=("FLOAT", {"default": 6.0, "max": 100.0}),
+             intensity_threshold=("INT", {"default": 8, "max": 16}),
+             resolution=("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 64})
+         )
+
+     RETURN_TYPES = ("IMAGE",)
+     FUNCTION = "execute"
+
+     CATEGORY = "tbox/ControlNet Preprocessors"
+
+
+     def execute(self, image, guassian_sigma=6, intensity_threshold=8, resolution=512, **kwargs):
+         from lineart import LineartStandardDetector
+         return (common_annotator_call(LineartStandardDetector(), image, guassian_sigma=guassian_sigma, intensity_threshold=intensity_threshold, resolution=resolution), )
custom_nodes/ComfyUI-tbox/nodes/preprocessor/midas_node.py ADDED
@@ -0,0 +1,25 @@
+ from ..utils import common_annotator_call, create_node_input_types
+ import comfy.model_management as model_management
+ import numpy as np
+
+ class MIDAS_Depth_Map_Preprocessor:
+     @classmethod
+     def INPUT_TYPES(s):
+         return create_node_input_types(
+             a = ("FLOAT", {"default": np.pi * 2.0, "min": 0.0, "max": np.pi * 5.0, "step": 0.05}),
+             bg_threshold = ("FLOAT", {"default": 0.1, "min": 0, "max": 1, "step": 0.05})
+         )
+
+     RETURN_TYPES = ("IMAGE",)
+     FUNCTION = "execute"
+
+     CATEGORY = "tbox/ControlNet Preprocessors"
+
+     def execute(self, image, a, bg_threshold, resolution=512, **kwargs):
+         from midas import MidasDetector
+
+         # Ref: https://github.com/lllyasviel/ControlNet/blob/main/gradio_depth2image.py
+         model = MidasDetector.from_pretrained().to(model_management.get_torch_device())
+         out = common_annotator_call(model, image, resolution=resolution, a=a, bg_th=bg_threshold)
+         del model
+         return (out, )
custom_nodes/ComfyUI-tbox/nodes/utils.py ADDED
@@ -0,0 +1,165 @@
+ import hashlib
+ import os
+ import torch
+ import nodes
+ import server
+ import folder_paths
+ import numpy as np
+ from typing import Iterable
+ from PIL import Image
+
+ BIGMIN = -(2**53-1)
+ BIGMAX = (2**53-1)
+
+ DIMMAX = 8192
+
+
+ def tensor_to_int(tensor, bits):
+     #TODO: investigate benefit of rounding by adding 0.5 before clip/cast
+     tensor = tensor.cpu().numpy() * (2**bits-1)
+     return np.clip(tensor, 0, (2**bits-1))
+ def tensor_to_shorts(tensor):
+     return tensor_to_int(tensor, 16).astype(np.uint16)
+ def tensor_to_bytes(tensor):
+     return tensor_to_int(tensor, 8).astype(np.uint8)
+ def tensor2pil(x):
+     return Image.fromarray(np.clip(255. * x.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
+ def pil2tensor(image: Image.Image) -> torch.Tensor:
+     return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
+
+
+ def is_url(url):
+     return url.split("://")[0] in ["http", "https"]
+
+ def strip_path(path):
+     #This leaves whitespace inside quotes and only a single "
+     #thus ' ""test"' -> '"test'
+     #consider path.strip(string.whitespace+"\"")
+     #or weightier re.fullmatch("[\\s\"]*(.+?)[\\s\"]*", path).group(1)
+     path = path.strip()
+     if path.startswith("\""):
+         path = path[1:]
+     if path.endswith("\""):
+         path = path[:-1]
+     return path
+ def hash_path(path):
+     if path is None:
+         return "input"
+     if is_url(path):
+         return "url"
+     return calculate_file_hash(strip_path(path))
+
+ # modified from https://stackoverflow.com/questions/22058048/hashing-a-file-in-python
+ def calculate_file_hash(filename: str, hash_every_n: int = 1):
+     #Larger video files were taking >.5 seconds to hash even when cached,
+     #so instead the modified time from the filesystem is used as a hash
+     h = hashlib.sha256()
+     h.update(filename.encode())
+     h.update(str(os.path.getmtime(filename)).encode())
+     return h.hexdigest()
+
+ def is_safe_path(path):
+     if "VHS_STRICT_PATHS" not in os.environ:
+         return True
+     basedir = os.path.abspath('.')
+     try:
+         common_path = os.path.commonpath([basedir, path])
+     except:
+         #Different drive on windows
+         return False
+     return common_path == basedir
+
+ def validate_path(path, allow_none=False, allow_url=True):
+     if path is None:
+         return allow_none
+     if is_url(path):
+         #Probably not feasible to check if url resolves here
+         if not allow_url:
+             return "URLs are unsupported for this path"
+         return is_safe_path(path)
+     if not os.path.isfile(strip_path(path)):
+         return "Invalid file path: {}".format(path)
+     return is_safe_path(path)
+
+ def common_annotator_call(model, tensor_image, input_batch=False, show_pbar=False, **kwargs):
+     if "detect_resolution" in kwargs:
+         del kwargs["detect_resolution"] #Prevent weird case?
+
+     if "resolution" in kwargs:
+         detect_resolution = kwargs["resolution"] if type(kwargs["resolution"]) == int and kwargs["resolution"] >= 64 else 512
+         del kwargs["resolution"]
+     else:
+         detect_resolution = 512
+
+     if input_batch:
+         np_images = np.asarray(tensor_image * 255., dtype=np.uint8)
+         np_results = model(np_images, output_type="np", detect_resolution=detect_resolution, **kwargs)
+         return torch.from_numpy(np_results.astype(np.float32) / 255.0)
+
+     batch_size = tensor_image.shape[0]
+
+     out_tensor = None
+     for i, image in enumerate(tensor_image):
+         np_image = np.asarray(image.cpu() * 255., dtype=np.uint8)
+         np_result = model(np_image, output_type="np", detect_resolution=detect_resolution, **kwargs)
+         out = torch.from_numpy(np_result.astype(np.float32) / 255.0)
+         if out_tensor is None:
+             out_tensor = torch.zeros(batch_size, *out.shape, dtype=torch.float32)
+         out_tensor[i] = out
+
+     return out_tensor
+
+ def create_node_input_types(**extra_kwargs):
+     return {
+         "required": {
+             "image": ("IMAGE",)
+         },
+         "optional": {
+             **extra_kwargs,
+             "resolution": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 64})
+         }
+     }
+
+
+ prompt_queue = server.PromptServer.instance.prompt_queue
+ def requeue_workflow_unchecked():
+     """Requeues the current workflow without checking for multiple requeues"""
+     currently_running = prompt_queue.currently_running
+     print(f'requeue_workflow_unchecked >>>>>> ')
+     (_, _, prompt, extra_data, outputs_to_execute) = next(iter(currently_running.values()))
+
+     #Ensure batch_managers are marked stale
+     prompt = prompt.copy()
+     for uid in prompt:
+         if prompt[uid]['class_type'] == 'BatchManager':
+             prompt[uid]['inputs']['requeue'] = prompt[uid]['inputs'].get('requeue', 0)+1
+
+     #execution.py has guards for concurrency, but server doesn't.
+     #TODO: Check that this won't be an issue
+     number = -server.PromptServer.instance.number
+     server.PromptServer.instance.number += 1
+     prompt_id = str(server.uuid.uuid4())
+     prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
+     print(f'requeue_workflow_unchecked <<<<<<<<<< prompt_id:{prompt_id}, number:{number}')
+
+ requeue_guard = [None, 0, 0, {}]
+ def requeue_workflow(requeue_required=(-1,True)):
+     assert(len(prompt_queue.currently_running) == 1)
+     global requeue_guard
+     (run_number, _, prompt, _, _) = next(iter(prompt_queue.currently_running.values()))
+     print(f'requeue_workflow >> run_number:{run_number}\n')
+     if requeue_guard[0] != run_number:
+         #Calculate a count of how many outputs are managed by a batch manager
+         managed_outputs = 0
+         for bm_uid in prompt:
+             if prompt[bm_uid]['class_type'] == 'BatchManager':
+                 for output_uid in prompt:
+                     if prompt[output_uid]['class_type'] in ["VideoSaver"]:
+                         for inp in prompt[output_uid]['inputs'].values():
+                             if inp == [bm_uid, 0]:
+                                 managed_outputs += 1
+         requeue_guard = [run_number, 0, managed_outputs, {}]
+     requeue_guard[1] = requeue_guard[1]+1
+     requeue_guard[3][requeue_required[0]] = requeue_required[1]
+     if requeue_guard[1] == requeue_guard[2] and max(requeue_guard[3].values()):
+         requeue_workflow_unchecked()
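For orientation: the preprocessor nodes in this commit build their INPUT_TYPES through create_node_input_types, which wraps node-specific widgets in a shared image-plus-resolution schema. An illustrative expansion, derived directly from the function above using the call from canny_node.py (comments only, not executable code added to the commit):

    # create_node_input_types(low_threshold=("INT", {"default": 100, "min": 0, "max": 255}), ...)
    # returns:
    # {
    #     "required": {"image": ("IMAGE",)},
    #     "optional": {
    #         "low_threshold": ("INT", {"default": 100, "min": 0, "max": 255}),
    #         ...,
    #         "resolution": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 64}),
    #     },
    # }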
custom_nodes/ComfyUI-tbox/nodes/video/batch_node.py ADDED
@@ -0,0 +1,69 @@
+
+ import hashlib
+ import os
+
+ class BatchManagerNode:
+     def __init__(self, frames_per_batch=-1):
+         print("BatchNode init")
+         self.frames_per_batch = frames_per_batch
+         self.inputs = {}
+         self.outputs = {}
+         self.unique_id = None
+         self.has_closed_inputs = False
+         self.total_frames = float('inf')
+
+     def reset(self):
+         print("BatchNode reset")
+         self.close_inputs()
+         for key in self.outputs:
+             if getattr(self.outputs[key][-1], "gi_suspended", False):
+                 try:
+                     self.outputs[key][-1].send(None)
+                 except StopIteration:
+                     pass
+         self.__init__(self.frames_per_batch)
+     def has_open_inputs(self):
+         return len(self.inputs) > 0
+     def close_inputs(self):
+         for key in self.inputs:
+             if getattr(self.inputs[key][-1], "gi_suspended", False):
+                 try:
+                     self.inputs[key][-1].send(1)
+                 except StopIteration:
+                     pass
+         self.inputs = {}
+
+     @classmethod
+     def INPUT_TYPES(s):
+         return {
+             "required": {"frames_per_batch": ("INT", {"default": 16, "min": 1, "max": 128, "step": 1})},
+             "hidden": {"prompt": "PROMPT", "unique_id": "UNIQUE_ID"},
+         }
+
+     RETURN_TYPES = ("BatchManager",)
+     RETURN_NAMES = ("meta_batch",)
+     CATEGORY = "tbox/Video"
+     FUNCTION = "update_batch"
+
+     def update_batch(self, frames_per_batch, prompt=None, unique_id=None):
+         if unique_id is not None and prompt is not None:
+             requeue = prompt[unique_id]['inputs'].get('requeue', 0)
+         else:
+             requeue = 0
+         print(f'update_batch >> unique_id: {unique_id}; requeue: {requeue}')
+         if requeue == 0:
+             self.reset()
+             self.frames_per_batch = frames_per_batch
+             self.unique_id = unique_id
+         else:
+             num_batches = (self.total_frames+self.frames_per_batch-1)//frames_per_batch
+             print(f'Meta-Batch {requeue}/{num_batches}')
+         #onExecuted seems to not be called unless some message is sent
+         return (self,)
+
+     @classmethod
+     def IS_CHANGED(self, frames_per_batch, prompt=None, unique_id=None):
+         random_bytes = os.urandom(32)
+         result = hashlib.sha256(random_bytes).hexdigest()
+         print(f"BatchManagerNode >>> IS_CHANGED : {result}")
+         return result
custom_nodes/ComfyUI-tbox/nodes/video/ffmpeg.py ADDED
@@ -0,0 +1,129 @@
+ import re
+ import os
+ import torch
+ import shutil
+ import subprocess
+ import folder_paths
+ from torch import Tensor
+ from collections.abc import Mapping
+
+ audio_extensions = ['mp3', 'mp4', 'wav', 'ogg']
+ video_extensions = ['webm', 'mp4', 'mov']
+
+ def ffmpeg_suitability(path):
+     try:
+         version = subprocess.run([path, "-version"], check=True,
+                                  capture_output=True).stdout.decode("utf-8")
+     except:
+         return 0
+     score = 0
+     #rough layout of the importance of various features
+     simple_criterion = [("libvpx", 20), ("264", 10), ("265", 3),
+                         ("svtav1", 5), ("libopus", 1)]
+     for criterion in simple_criterion:
+         if version.find(criterion[0]) >= 0:
+             score += criterion[1]
+     #obtain rough compile year from copyright information
+     copyright_index = version.find('2000-2')
+     if copyright_index >= 0:
+         copyright_year = version[copyright_index+6:copyright_index+9]
+         if copyright_year.isnumeric():
+             score += int(copyright_year)
+     return score
+
+ folder_paths.folder_names_and_paths["VHS_video_formats"] = (
+     [
+         os.path.join(os.path.dirname(os.path.abspath(__file__)), "video_formats"),
+     ],
+     [".json"]
+ )
+
+
+ if "VHS_FORCE_FFMPEG_PATH" in os.environ:
+     ffmpeg_path = os.environ.get("VHS_FORCE_FFMPEG_PATH")
+ else:
+     ffmpeg_paths = []
+     try:
+         from imageio_ffmpeg import get_ffmpeg_exe
+         imageio_ffmpeg_path = get_ffmpeg_exe()
+         ffmpeg_paths.append(imageio_ffmpeg_path)
+     except:
+         if "VHS_USE_IMAGEIO_FFMPEG" in os.environ:
+             raise
+         print("Failed to import imageio_ffmpeg")
+     if "VHS_USE_IMAGEIO_FFMPEG" in os.environ:
+         ffmpeg_path = imageio_ffmpeg_path
+     else:
+         system_ffmpeg = shutil.which("ffmpeg")
+         if system_ffmpeg is not None:
+             ffmpeg_paths.append(system_ffmpeg)
+         if os.path.isfile("ffmpeg"):
+             ffmpeg_paths.append(os.path.abspath("ffmpeg"))
+         if os.path.isfile("ffmpeg.exe"):
+             ffmpeg_paths.append(os.path.abspath("ffmpeg.exe"))
+         if len(ffmpeg_paths) == 0:
+             print("No valid ffmpeg found.")
+             ffmpeg_path = None
+         elif len(ffmpeg_paths) == 1:
+             #Evaluation of suitability isn't required, can take sole option
+             #to reduce startup time
+             ffmpeg_path = ffmpeg_paths[0]
+         else:
+             ffmpeg_path = max(ffmpeg_paths, key=ffmpeg_suitability)
+
+ gifski_path = os.environ.get("VHS_GIFSKI", None)
+ if gifski_path is None:
+     gifski_path = os.environ.get("JOV_GIFSKI", None)
+     if gifski_path is None:
+         gifski_path = shutil.which("gifski")
+
+ ytdl_path = os.environ.get("VHS_YTDL", None) or shutil.which('yt-dlp') \
+     or shutil.which('youtube-dl')
+
+ def get_audio(file, start_time=0, duration=0):
+     args = [ffmpeg_path, "-i", file]
+     if start_time > 0:
+         args += ["-ss", str(start_time)]
+     if duration > 0:
+         args += ["-t", str(duration)]
+     try:
+         #TODO: scan for sample rate and maintain
+         res = subprocess.run(args + ["-f", "f32le", "-"],
+                              capture_output=True, check=True)
+         audio = torch.frombuffer(bytearray(res.stdout), dtype=torch.float32)
+         match = re.search(', (\\d+) Hz, (\\w+), ', res.stderr.decode('utf-8'))
+     except subprocess.CalledProcessError as e:
+         raise Exception(f"VHS failed to extract audio from {file}:\n" \
+                         + e.stderr.decode("utf-8"))
+     if match:
+         ar = int(match.group(1))
+         #NOTE: Just throwing an error for other channel types right now
+         #Will deal with issues if they come
+         ac = {"mono": 1, "stereo": 2}[match.group(2)]
+     else:
+         ar = 44100
+         ac = 2
+     audio = audio.reshape((-1, ac)).transpose(0, 1).unsqueeze(0)
+     return {'waveform': audio, 'sample_rate': ar}
+
+ class LazyAudioMap(Mapping):
+     def __init__(self, file, start_time, duration):
+         self.file = file
+         self.start_time = start_time
+         self.duration = duration
+         self._dict = None
+     def __getitem__(self, key):
+         if self._dict is None:
+             self._dict = get_audio(self.file, self.start_time, self.duration)
+         return self._dict[key]
+     def __iter__(self):
+         if self._dict is None:
+             self._dict = get_audio(self.file, self.start_time, self.duration)
+         return iter(self._dict)
+     def __len__(self):
+         if self._dict is None:
+             self._dict = get_audio(self.file, self.start_time, self.duration)
+         return len(self._dict)
+
+ def lazy_get_audio(file, start_time=0, duration=0):
+     return LazyAudioMap(file, start_time, duration)
custom_nodes/ComfyUI-tbox/nodes/video/info_node.py ADDED
@@ -0,0 +1,39 @@
+
+
+ class VideoInfoNode:
+     @classmethod
+     def INPUT_TYPES(s):
+         return {
+             "required": {
+                 "video_info": ("VHS_VIDEOINFO",),
+             }
+         }
+
+     CATEGORY = "tbox/Video"
+
+     RETURN_TYPES = ("FLOAT", "INT", "FLOAT", "INT", "INT", "FLOAT", "INT", "FLOAT", "INT", "INT")
+     RETURN_NAMES = (
+         "source_fps",
+         "source_frame_count",
+         "source_duration",
+         "source_width",
+         "source_height",
+         "loaded_fps",
+         "loaded_frame_count",
+         "loaded_duration",
+         "loaded_width",
+         "loaded_height",
+     )
+     FUNCTION = "get_video_info"
+
+     def get_video_info(self, video_info):
+         keys = ["fps", "frame_count", "duration", "width", "height"]
+
+         source_info = []
+         loaded_info = []
+
+         for key in keys:
+             source_info.append(video_info[f"source_{key}"])
+             loaded_info.append(video_info[f"loaded_{key}"])
+
+         return (*source_info, *loaded_info)
custom_nodes/ComfyUI-tbox/nodes/video/load_node.py ADDED
@@ -0,0 +1,261 @@
+
+ import os
+ import cv2
+ import torch
+ import requests
+ import itertools
+ import folder_paths
+ import psutil
+ import numpy as np
+ from comfy.utils import common_upscale
+ from io import BytesIO
+ from PIL import Image, ImageSequence, ImageOps
+ from .ffmpeg import lazy_get_audio, video_extensions
+ from ..utils import BIGMAX, DIMMAX, strip_path, validate_path
+
+
+
+
+ def is_gif(filename) -> bool:
+     file_parts = filename.split('.')
+     return len(file_parts) > 1 and file_parts[-1] == "gif"
+
+ def target_size(width, height, force_size, custom_width, custom_height, downscale_ratio=8) -> tuple[int, int]:
+     if force_size == "Disabled":
+         pass
+     elif force_size == "Custom Width" or force_size.endswith('x?'):
+         height *= custom_width/width
+         width = custom_width
+     elif force_size == "Custom Height" or force_size.startswith('?x'):
+         width *= custom_height/height
+         height = custom_height
+     else:
+         width = custom_width
+         height = custom_height
+     width = int(width/downscale_ratio + 0.5) * downscale_ratio
+     height = int(height/downscale_ratio + 0.5) * downscale_ratio
+     return (width, height)
+
+ def cv_frame_generator(path, force_rate, frame_load_cap, skip_first_frames,
+                        select_every_nth, meta_batch=None, unique_id=None):
+     video_cap = cv2.VideoCapture(strip_path(path))
+     if not video_cap.isOpened():
+         raise ValueError(f"{path} could not be loaded with cv.")
+
+     # extract video metadata
+     fps = video_cap.get(cv2.CAP_PROP_FPS)
+     width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+     total_frames = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     duration = total_frames / fps
+
+     # set video_cap to look at start_index frame
+     total_frame_count = 0
+     total_frames_evaluated = -1
+     frames_added = 0
+     base_frame_time = 1 / fps
+     prev_frame = None
+
+     if force_rate == 0:
+         target_frame_time = base_frame_time
+     else:
+         target_frame_time = 1/force_rate
+
+     yield (width, height, fps, duration, total_frames, target_frame_time)
+     if total_frames > 0:
+         if force_rate != 0:
+             yieldable_frames = int(total_frames / fps * force_rate)
+         else:
+             yieldable_frames = total_frames
+         if frame_load_cap != 0:
+             yieldable_frames = min(frame_load_cap, yieldable_frames)
+     else:
+         yieldable_frames = 0
+
+     if meta_batch is not None:
+         yield yieldable_frames
+
+     time_offset = target_frame_time - base_frame_time
+     while video_cap.isOpened():
+         if time_offset < target_frame_time:
+             is_returned = video_cap.grab()
+             # if didn't return frame, video has ended
+             if not is_returned:
+                 break
+             time_offset += base_frame_time
+         if time_offset < target_frame_time:
+             continue
+         time_offset -= target_frame_time
+         # if not at start_index, skip doing anything with frame
+         total_frame_count += 1
+         if total_frame_count <= skip_first_frames:
+             continue
+         else:
+             total_frames_evaluated += 1
+
+         # if should not be selected, skip doing anything with frame
+         if total_frames_evaluated%select_every_nth != 0:
+             continue
+
+         # opencv loads images in BGR format (yuck), so need to convert to RGB for ComfyUI use
+         # follow up: can videos ever have an alpha channel?
+         # To my testing: No. opencv has no support for alpha
+         unused, frame = video_cap.retrieve()
+         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+         # convert frame to comfyui's expected format
+         # TODO: frame contains no exif information. Check if opencv2 has already applied
+         frame = np.array(frame, dtype=np.float32)
+         # torch.from_numpy shares memory with `frame`, so this in-place divide
+         # normalizes the numpy array to 0..1 without an assignment
+         torch.from_numpy(frame).div_(255)
+         if prev_frame is not None:
+             inp = yield prev_frame
+             if inp is not None:
+                 #ensure the finally block is called
+                 return
+         prev_frame = frame
+         frames_added += 1
+
+         # if cap exists and we've reached it, stop processing frames
+         if frame_load_cap > 0 and frames_added >= frame_load_cap:
+             break
+
+     if meta_batch is not None:
+         meta_batch.inputs.pop(unique_id)
+         meta_batch.has_closed_inputs = True
+     if prev_frame is not None:
+         yield prev_frame
+
+ def batched(it, n):
+     while batch := tuple(itertools.islice(it, n)):
+         yield batch
+
+ def load_video_cv(path: str, force_rate: int, force_size: str,
+                   custom_width: int, custom_height: int, frame_load_cap: int,
+                   skip_first_frames: int, select_every_nth: int,
+                   meta_batch=None, unique_id=None,
+                   memory_limit_mb=None):
+
+     if meta_batch is None or unique_id not in meta_batch.inputs:
+         gen = cv_frame_generator(path, force_rate, frame_load_cap, skip_first_frames,
+                                  select_every_nth, meta_batch, unique_id)
+         (width, height, fps, duration, total_frames, target_frame_time) = next(gen)
+
+         if meta_batch is not None:
+             meta_batch.inputs[unique_id] = (gen, width, height, fps, duration, total_frames, target_frame_time)
+             yieldable_frames = next(gen)
+             if yieldable_frames:
+                 meta_batch.total_frames = min(meta_batch.total_frames, yieldable_frames)
147
+ else:
148
+ (gen, width, height, fps, duration, total_frames, target_frame_time) = meta_batch.inputs[unique_id]
149
+
150
+ print(f'[{width}x{height}]@{fps} - duration:{duration}, total_frames: {total_frames}')
151
+
152
+ memory_limit = memory_limit_mb
153
+ if memory_limit_mb is not None:
154
+ memory_limit *= 2 ** 20
155
+ else:
156
+ #TODO: verify if garbage collection should be performed here.
157
+ #leaves ~128 MB unreserved for safety
158
+ try:
159
+ memory_limit = (psutil.virtual_memory().available + psutil.swap_memory().free) - 2 ** 27
160
+ except:
161
+ print("Failed to calculate available memory. Memory load limit has been disabled")
162
+
163
+ if memory_limit is not None:
164
+ #TODO: use better estimate for when vae is not None
165
+ #Consider completely ignoring for load_latent case?
166
+ max_loadable_frames = int(memory_limit//(width*height*3*(.1)))
167
+
168
+ if meta_batch is not None:
169
+ if meta_batch.frames_per_batch > max_loadable_frames:
170
+ raise RuntimeError(f"Meta Batch set to {meta_batch.frames_per_batch} frames but only {max_loadable_frames} can fit in memory")
171
+ gen = itertools.islice(gen, meta_batch.frames_per_batch)
172
+ else:
173
+ original_gen = gen
174
+ gen = itertools.islice(gen, max_loadable_frames)
175
+
176
+ downscale_ratio = 8
177
+ frames_per_batch = (1920 * 1080 * 16) // (width * height) or 1
178
+ if force_size != "Disabled":
179
+ new_size = target_size(width, height, force_size, custom_width, custom_height, downscale_ratio)
180
+ if new_size[0] != width or new_size[1] != height:
181
+ def rescale(frame):
182
+ s = torch.from_numpy(np.fromiter(frame, np.dtype((np.float32, (height, width, 3)))))
183
+ s = s.movedim(-1,1)
184
+ s = common_upscale(s, new_size[0], new_size[1], "lanczos", "center")
185
+ return s.movedim(1,-1).numpy()
186
+ gen = itertools.chain.from_iterable(map(rescale, batched(gen, frames_per_batch)))
187
+ else:
188
+ new_size = width, height
189
+
190
+ #Some minor wizardry to eliminate a copy and reduce max memory by a factor of ~2
191
+ images = torch.from_numpy(np.fromiter(gen, np.dtype((np.float32, (new_size[1], new_size[0], 3)))))
192
+ if meta_batch is None and memory_limit is not None:
193
+ try:
194
+ next(original_gen)
195
+ raise RuntimeError(f"Memory limit hit after loading {len(images)} frames. Stopping execution.")
196
+ except StopIteration:
197
+ pass
198
+ if len(images) == 0:
199
+ raise RuntimeError("No frames generated")
200
+
201
+ #Setup lambda for lazy audio capture
202
+ audio = lazy_get_audio(path, skip_first_frames * target_frame_time,
203
+ frame_load_cap*target_frame_time*select_every_nth)
204
+ #Adjust target_frame_time for select_every_nth
205
+ target_frame_time *= select_every_nth
206
+ video_info = {
207
+ "source_fps": fps,
208
+ "source_frame_count": total_frames,
209
+ "source_duration": duration,
210
+ "source_width": width,
211
+ "source_height": height,
212
+ "loaded_fps": 1/target_frame_time,
213
+ "loaded_frame_count": len(images),
214
+ "loaded_duration": len(images) * target_frame_time,
215
+ "loaded_width": new_size[0],
216
+ "loaded_height": new_size[1],
217
+ }
218
+
219
+ return (images, len(images), audio, video_info)
220
+
221
+
222
+
223
+ class LoadVideoNode:
224
+ @classmethod
225
+ def INPUT_TYPES(s):
226
+ return {
227
+ "required": {
228
+ "path": ("STRING", {"default": "/Users/wadahana/Desktop/live-motion2.mp4", "multiline": True, "vhs_path_extensions": video_extensions}),
229
+ "force_rate": ("INT", {"default": 0, "min": 0, "max": 60, "step": 1}),
230
+ "force_size": (["Disabled", "Custom Height", "Custom Width", "Custom", "256x?", "?x256", "256x256", "512x?", "?x512", "512x512"],),
231
+ "custom_width": ("INT", {"default": 512, "min": 0, "max": DIMMAX, "step": 8}),
232
+ "custom_height": ("INT", {"default": 512, "min": 0, "max": DIMMAX, "step": 8}),
233
+ "frame_load_cap": ("INT", {"default": 0, "min": 0, "max": BIGMAX, "step": 1}),
234
+ "skip_first_frames": ("INT", {"default": 0, "min": 0, "max": BIGMAX, "step": 1}),
235
+ "select_every_nth": ("INT", {"default": 1, "min": 1, "max": BIGMAX, "step": 1}),
236
+ },
237
+ "optional": {
238
+ "meta_batch": ("BatchManager",),
239
+ },
240
+ "hidden": {
241
+ "unique_id": "UNIQUE_ID"
242
+ },
243
+ }
244
+
245
+ CATEGORY = "tbox/Video"
246
+
247
+ RETURN_TYPES = ("IMAGE", "INT", "AUDIO", "VHS_VIDEOINFO")
248
+ RETURN_NAMES = ("IMAGE", "frame_count", "audio", "video_info")
249
+
250
+ FUNCTION = "load_video"
251
+
252
+ def load_video(self, **kwargs):
253
+ if kwargs['path'] is None:
254
+ raise Exception("video path is not provided")
255
+
256
+ kwargs['path'] = kwargs['path'].split('\n')[0]
257
+ if validate_path(kwargs['path']) != True:
258
+ raise Exception("video is not a valid path: " + kwargs['path'])
259
+ # if is_url(kwargs['video']):
260
+ # kwargs['video'] = try_download_video(kwargs['video']) or kwargs['video']
261
+ return load_video_cv(**kwargs)
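As a quick check of the sizing logic above: target_size scales the missing dimension from the aspect ratio and then rounds both sides to the nearest multiple of downscale_ratio (8 by default). A small illustrative sketch, not part of the node code:

    # "Custom Width": height follows the aspect ratio, then both sides snap to multiples of 8.
    assert target_size(1920, 1080, "Custom Width", custom_width=512, custom_height=512) == (512, 288)
    # "?x256": width becomes 1280 * 256 / 720 ≈ 455.1, rounded to the nearest multiple of 8 -> 456.
    assert target_size(1280, 720, "?x256", custom_width=512, custom_height=256) == (456, 256)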
custom_nodes/ComfyUI-tbox/nodes/video/save_node.py ADDED
@@ -0,0 +1,415 @@
1
+ import os
2
+ import re
3
+ import cv2
4
+ import sys
5
+ import json
6
+ import torch
7
+ import datetime
8
+ import itertools
9
+ import subprocess
10
+ import folder_paths
11
+ import numpy as np
12
+ from string import Template
13
+ from pathlib import Path
14
+ from PIL import Image, ExifTags
15
+ from PIL.PngImagePlugin import PngInfo
16
+ from .ffmpeg import ffmpeg_path, gifski_path
17
+ from ..utils import tensor_to_bytes, tensor_to_shorts, requeue_workflow
18
+
19
+
20
+
21
+ def gen_format_widgets(video_format):
22
+ for k in video_format:
23
+ if k.endswith("_pass"):
24
+ for i in range(len(video_format[k])):
25
+ if isinstance(video_format[k][i], list):
26
+ item = [video_format[k][i]]
27
+ yield item
28
+ video_format[k][i] = item[0]
29
+ else:
30
+ if isinstance(video_format[k], list):
31
+ item = [video_format[k]]
32
+ yield item
33
+ video_format[k] = item[0]
34
+
35
+ def get_format_widget_defaults(format_name):
36
+ video_format_path = folder_paths.get_full_path("VHS_video_formats", format_name + ".json")
37
+ with open(video_format_path, 'r') as stream:
38
+ video_format = json.load(stream)
39
+ results = {}
40
+ for w in gen_format_widgets(video_format):
41
+ if len(w[0]) > 2 and 'default' in w[0][2]:
42
+ default = w[0][2]['default']
43
+ else:
44
+ if type(w[0][1]) is list:
45
+ default = w[0][1][0]
46
+ else:
47
+ #NOTE: This doesn't respect max/min, but should be good enough as a fallback to a fallback to a fallback
48
+ default = {"BOOLEAN": False, "INT": 0, "FLOAT": 0, "STRING": ""}[w[0][1]]
49
+ results[w[0][0]] = default
50
+ return results
51
+
52
+ def get_video_formats():
53
+ formats = []
54
+ for format_name in folder_paths.get_filename_list("VHS_video_formats"):
55
+ format_name = format_name[:-5]
56
+ formats.append("video/" + format_name)
57
+ return formats
58
+
59
+ def gifski_process(args, video_format, file_path, env):
60
+ frame_data = yield
61
+ with subprocess.Popen(args + video_format['main_pass'] + ['-f', 'yuv4mpegpipe', '-'],
62
+ stderr=subprocess.PIPE, stdin=subprocess.PIPE,
63
+ stdout=subprocess.PIPE, env=env) as procff:
64
+ with subprocess.Popen([gifski_path] + video_format['gifski_pass']
65
+ + ['-q', '-o', file_path, '-'], stderr=subprocess.PIPE,
66
+ stdin=procff.stdout, stdout=subprocess.PIPE,
67
+ env=env) as procgs:
68
+ try:
69
+ while frame_data is not None:
70
+ procff.stdin.write(frame_data)
71
+ frame_data = yield
72
+ procff.stdin.flush()
73
+ procff.stdin.close()
74
+ resff = procff.stderr.read()
75
+ resgs = procgs.stderr.read()
76
+ outgs = procgs.stdout.read()
77
+ except BrokenPipeError as e:
78
+ procff.stdin.close()
79
+ resff = procff.stderr.read()
80
+ resgs = procgs.stderr.read()
81
+ raise Exception("An error occurred while creating gifski output\n" \
82
+ + "Make sure you are using gifski --version >=1.32.0\nffmpeg: " \
83
+ + resff.decode("utf-8") + '\ngifski: ' + resgs.decode("utf-8"))
84
+ if len(resff) > 0:
85
+ print(resff.decode("utf-8"), end="", file=sys.stderr)
86
+ if len(resgs) > 0:
87
+ print(resgs.decode("utf-8"), end="", file=sys.stderr)
88
+ #should always be empty as the quiet flag is passed
89
+ if len(outgs) > 0:
90
+ print(outgs.decode("utf-8"))
91
+
92
+ def ffmpeg_process(args, video_format, video_metadata, file_path, env):
93
+
94
+ res = None
95
+ frame_data = yield
96
+ total_frames_output = 0
97
+ if video_format.get('save_metadata', 'False') != 'False':
98
+ os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
99
+ metadata = json.dumps(video_metadata)
100
+ metadata_path = os.path.join(folder_paths.get_temp_directory(), "metadata.txt")
101
+ #metadata from file should escape = ; # \ and newline
102
+ metadata = metadata.replace("\\","\\\\")
103
+ metadata = metadata.replace(";","\\;")
104
+ metadata = metadata.replace("#","\\#")
105
+ metadata = metadata.replace("=","\\=")
106
+ metadata = metadata.replace("\n","\\\n")
107
+ metadata = "comment=" + metadata
108
+ with open(metadata_path, "w") as f:
109
+ f.write(";FFMETADATA1\n")
110
+ f.write(metadata)
111
+ m_args = args[:1] + ["-i", metadata_path] + args[1:] + ["-metadata", "creation_time=now"]
112
+ print(f'ffmpeg: {m_args}')
113
+ with subprocess.Popen(m_args + [file_path], stderr=subprocess.PIPE,
114
+ stdin=subprocess.PIPE, env=env) as proc:
115
+ try:
116
+ while frame_data is not None:
117
+ proc.stdin.write(frame_data)
118
+ #TODO: skip flush for increased speed
119
+ frame_data = yield
120
+ total_frames_output+=1
121
+ proc.stdin.flush()
122
+ proc.stdin.close()
123
+ res = proc.stderr.read()
124
+ except BrokenPipeError as e:
125
+ err = proc.stderr.read()
126
+ #Check if output file exists. If it does, the re-execution
127
+ #will also fail. This obscures the cause of the error
128
+ #and seems to never occur concurrent to the metadata issue
129
+ if os.path.exists(file_path):
130
+ raise Exception("An error occurred in the ffmpeg subprocess:\n" \
131
+ + err.decode("utf-8"))
132
+ #Res was not set
133
+ print(err.decode("utf-8"), end="", file=sys.stderr)
134
+ print("An error occurred when saving with metadata")
135
+ if res != b'':
136
+ with subprocess.Popen(args + [file_path], stderr=subprocess.PIPE,
137
+ stdin=subprocess.PIPE, env=env) as proc:
138
+ try:
139
+ while frame_data is not None:
140
+ proc.stdin.write(frame_data)
141
+ frame_data = yield
142
+ total_frames_output+=1
143
+ proc.stdin.flush()
144
+ proc.stdin.close()
145
+ res = proc.stderr.read()
146
+ except BrokenPipeError as e:
147
+ res = proc.stderr.read()
148
+ raise Exception("An error occurred in the ffmpeg subprocess:\n" \
149
+ + res.decode("utf-8"))
150
+ yield total_frames_output
151
+ if len(res) > 0:
152
+ print(res.decode("utf-8"), end="", file=sys.stderr)
153
+
154
+ def to_pingpong(inp):
155
+ if not hasattr(inp, "__getitem__"):
156
+ inp = list(inp)
157
+ yield from inp
158
+ for i in range(len(inp)-2,0,-1):
159
+ yield inp[i]
160
+
161
+ def apply_format_widgets(format_name, kwargs):
162
+ video_format_path = folder_paths.get_full_path("VHS_video_formats", format_name + ".json")
163
+ with open(video_format_path, 'r') as stream:
164
+ video_format = json.load(stream)
165
+
166
+ for w in gen_format_widgets(video_format):
167
+ assert(w[0][0] in kwargs)
168
+ if len(w[0]) > 3:
169
+ w[0] = Template(w[0][3]).substitute(val=kwargs[w[0][0]])
170
+ else:
171
+ w[0] = str(kwargs[w[0][0]])
172
+ return video_format
173
+
174
+ class SaveVideoNode:
175
+ @classmethod
176
+ def INPUT_TYPES(s):
177
+ ffmpeg_formats = get_video_formats()
178
+ return {
179
+ "required": {
180
+ "path": ("STRING", {"multiline": True, "dynamicPrompts": False}),
181
+ "format": (ffmpeg_formats,),
182
+ "quality": ([100, 95, 90, 85, 80, 75, 70, 60, 50], {"default": 100}),
183
+ "pingpong": ("BOOLEAN", {"default": False}),
184
+ },
185
+ "optional": {
186
+ "images": ("IMAGE",),
187
+ "audio": ("AUDIO",),
188
+ "frame_rate": ("INT,FLOAT", { "default": 25.0, "step": 1.0, "min": 1.0, "max": 60.0 }),
189
+ "meta_batch": ("BatchManager",),
190
+ },
191
+ "hidden": {
192
+ "prompt": "PROMPT",
193
+ "unique_id": "UNIQUE_ID"
194
+ },
195
+ }
196
+
197
+ RETURN_TYPES = ()
198
+ CATEGORY = "tbox/Video"
199
+ FUNCTION = "save_video"
200
+ OUTPUT_NODE = True
201
+
202
+ def save_video(
203
+ self,
204
+ path,
205
+ frame_rate=25,
206
+ images=None,
207
+ format="video/h264-mp4",
208
+ quality=85,
209
+ pingpong=False,
210
+ audio=None,
211
+ prompt=None,
212
+ meta_batch=None,
213
+ unique_id=None,
214
+ manual_format_widgets=None,
215
+ ):
216
+ if images is None:
217
+ return {}
218
+ if isinstance(images, torch.Tensor) and images.size(0) == 0:
219
+ return {}
220
+
221
+ if frame_rate < 1:
222
+ frame_rate = 1
223
+ elif frame_rate > 120:
224
+ frame_rate = 120
225
+
226
+ num_frames = len(images)
227
+
228
+ first_image = images[0]
229
+ images = iter(images)
230
+
231
+ file_path = os.path.abspath(path.split('\n')[0])
232
+ output_dir = os.path.dirname(file_path)
233
+ filename = os.path.basename(file_path)
234
+ name, extension = os.path.splitext(filename)
235
+
236
+ output_process = None
237
+
238
+ video_metadata = {}
239
+ if prompt is not None:
240
+ video_metadata["prompt"] = prompt
241
+
242
+ if meta_batch is not None and unique_id in meta_batch.outputs:
243
+ (counter, output_process) = meta_batch.outputs[unique_id]
244
+ else:
245
+ counter = 0
246
+ output_process = None
247
+
248
+ format_type, format_ext = format.split("/")
249
+
250
+ # Use ffmpeg to save a video
251
+ if ffmpeg_path is None:
252
+ raise ProcessLookupError(f"ffmpeg is required for video outputs and could not be found.\nIn order to use video outputs, you must either:\n- Install imageio-ffmpeg with pip,\n- Place a ffmpeg executable in {os.path.abspath('')}, or\n- Install ffmpeg and add it to the system path.")
253
+
254
+ #Acquire additional format_widget values
255
+ kwargs = None
256
+ if manual_format_widgets is None:
257
+ if prompt is not None:
258
+ kwargs = prompt[unique_id]['inputs']
259
+ else:
260
+ manual_format_widgets = {}
261
+
262
+ if kwargs is None:
263
+ kwargs = get_format_widget_defaults(format_ext)
264
+ missing = {}
265
+ for k in kwargs.keys():
266
+ if k in manual_format_widgets:
267
+ kwargs[k] = manual_format_widgets[k]
268
+ else:
269
+ missing[k] = kwargs[k]
270
+ if len(missing) > 0:
271
+ print("Extra format values were not provided, the following defaults will be used: " + str(kwargs) + "\nThis is likely due to usage of ComfyUI-to-python. These values can be manually set by supplying a manual_format_widgets argument")
272
+
273
+ video_format = apply_format_widgets(format_ext, kwargs)
274
+ has_alpha = first_image.shape[-1] == 4
275
+ dim_alignment = video_format.get("dim_alignment", 8)
276
+ if (first_image.shape[1] % dim_alignment) or (first_image.shape[0] % dim_alignment):
277
+ #output frames must be padded
278
+ to_pad = (-first_image.shape[1] % dim_alignment,
279
+ -first_image.shape[0] % dim_alignment)
280
+ padding = (to_pad[0]//2, to_pad[0] - to_pad[0]//2,
281
+ to_pad[1]//2, to_pad[1] - to_pad[1]//2)
282
+ padfunc = torch.nn.ReplicationPad2d(padding)
283
+ def pad(image):
284
+ image = image.permute((2,0,1))#HWC to CHW
285
+ padded = padfunc(image.to(dtype=torch.float32))
286
+ return padded.permute((1,2,0))
287
+ images = map(pad, images)
288
+ new_dims = (-first_image.shape[1] % dim_alignment + first_image.shape[1],
289
+ -first_image.shape[0] % dim_alignment + first_image.shape[0])
290
+ dimensions = f"{new_dims[0]}x{new_dims[1]}"
291
+ print(f"Output images were not of valid resolution and have had padding applied: {dimensions}")
292
+ else:
293
+ dimensions = f"{first_image.shape[1]}x{first_image.shape[0]}"
294
+
295
+ if pingpong:
296
+ if meta_batch is not None:
297
+ print("pingpong is incompatible with batched output")
298
+ images = to_pingpong(images)
299
+
300
+ images = map(tensor_to_bytes, images)
301
+ if has_alpha:
302
+ i_pix_fmt = 'rgba'
303
+ else:
304
+ i_pix_fmt = 'rgb24'
305
+
306
+ args = [ffmpeg_path, "-v", "error", "-f", "rawvideo", "-pix_fmt", i_pix_fmt,
307
+ "-s", dimensions, "-r", str(frame_rate), "-i", "-"]
308
+
309
+ images = map(lambda x: x.tobytes(), images)
310
+ env=os.environ.copy()
311
+ if "environment" in video_format:
312
+ env.update(video_format["environment"])
313
+
314
+ if "pre_pass" in video_format:
315
+ images = [b''.join(images)]
316
+ os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
317
+ pre_pass_args = args[:13] + video_format['pre_pass']
318
+ try:
319
+ subprocess.run(pre_pass_args, input=images[0], env=env,
320
+ capture_output=True, check=True)
321
+ except subprocess.CalledProcessError as e:
322
+ raise Exception("An error occurred in the ffmpeg prepass:\n" \
323
+ + e.stderr.decode("utf-8"))
324
+ if "inputs_main_pass" in video_format:
325
+ args = args[:13] + video_format['inputs_main_pass'] + args[13:]
326
+
327
+ if output_process is None:
328
+ args += video_format['main_pass']
329
+ output_process = ffmpeg_process(args, video_format, video_metadata, file_path, env)
330
+ #Proceed to first yield
331
+ output_process.send(None)
332
+ if meta_batch is not None:
333
+ meta_batch.outputs[unique_id] = (0, output_process)
334
+
335
+ for image in images:
336
+ output_process.send(image)
337
+ if meta_batch is not None:
338
+ requeue_workflow((meta_batch.unique_id, not meta_batch.has_closed_inputs))
339
+ if meta_batch is None or meta_batch.has_closed_inputs:
340
+ #Close pipe and wait for termination.
341
+ try:
342
+ total_frames_output = output_process.send(None)
343
+ output_process.send(None)
344
+ except StopIteration:
345
+ pass
346
+ if meta_batch is not None:
347
+ meta_batch.outputs.pop(unique_id)
348
+ #if len(meta_batch.outputs) == 0:
349
+ # meta_batch.reset()
350
+ else:
351
+ return {}
352
+
353
+ a_waveform = None
354
+ if audio is not None:
355
+ try:
356
+ #safely check if audio produced by VHS_LoadVideo actually exists
357
+ a_waveform = audio['waveform']
358
+ except:
359
+ print('save audio >> audio input has no waveform')
360
+ pass
361
+ if a_waveform is not None:
362
+ # Create audio file if input was provided
363
+ output_file_with_audio = f"{name}-audio{extension}"
364
+ output_file_with_audio_path = os.path.join(output_dir, output_file_with_audio)
365
+ if "audio_pass" not in video_format:
366
+ print("Selected video format does not have explicit audio support")
367
+ video_format["audio_pass"] = ["-c:a", "libopus"]
368
+
369
+
370
+ # FFmpeg command with audio re-encoding
371
+ #TODO: expose audio quality options if format widgets makes it in
372
+ #Reconsider forcing apad/shortest
373
+ channels = audio['waveform'].size(1)
374
+ min_audio_dur = total_frames_output / frame_rate + 1
375
+ mux_args = [ffmpeg_path, "-v", "error", "-i", file_path,
376
+ "-ar", str(audio['sample_rate']), "-ac", str(channels),
377
+ "-y","-f", "f32le", "-i", "-", "-c:v", "copy"] \
378
+ + video_format["audio_pass"] \
379
+ + ["-af", "apad=whole_dur="+str(min_audio_dur),
380
+ "-shortest", output_file_with_audio_path]
381
+
382
+ audio_data = audio['waveform'].squeeze(0).transpose(0,1) \
383
+ .numpy().tobytes()
384
+ try:
385
+ res = subprocess.run(mux_args, input=audio_data,
386
+ env=env, capture_output=True, check=True)
387
+ if res.returncode == 0:
388
+ self.replace_file(output_file_with_audio_path, file_path)
389
+ except subprocess.CalledProcessError as e:
390
+ raise Exception("An error occured in the ffmpeg subprocess:\n" \
391
+ + e.stderr.decode("utf-8"))
392
+ if res.stderr:
393
+ print(res.stderr.decode("utf-8"), end="", file=sys.stderr)
394
+
395
+
396
+ return {}
397
+
398
+ @classmethod
399
+ def VALIDATE_INPUTS(self, format, **kwargs):
400
+ return True
401
+
402
+ def replace_file(self, audio_path, file_path):
403
+ try:
404
+ # Delete file_path if it exists
405
+ if os.path.exists(file_path):
406
+ os.remove(file_path)
407
+ print(f"Deleted file: {file_path}")
408
+ else:
409
+ print(f"File not found, skipping deletion: {file_path}")
410
+
411
+ # Rename output_file_with_audio_path to file_path
412
+ os.rename(audio_path, file_path)
413
+ print(f"Renamed {audio_path} to {file_path}")
414
+ except Exception as e:
415
+ print(f"An error occurred: {e}")
custom_nodes/ComfyUI-tbox/nodes/video/video_formats/h264-mp4.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "main_pass":
3
+ [
4
+ "-y", "-c:v", "libx264",
5
+ "-pix_fmt", "yuv420p",
6
+ "-crf", "20"
7
+ ],
8
+ "audio_pass": ["-c:a", "aac"],
9
+ "extension": "mp4"
10
+ }
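As context for the JSON files in this folder: save_node.py appends main_pass to the rawvideo input arguments it builds, and audio_pass is used for the optional audio mux step. Roughly, for this h264 preset the assembled arguments look like the sketch below; dimensions, frame rate and path are placeholders:

    # Sketch of the argument assembly in save_node.py for this format.
    args = [ffmpeg_path, "-v", "error", "-f", "rawvideo", "-pix_fmt", "rgb24",
            "-s", "512x288", "-r", "25", "-i", "-"]
    args += video_format["main_pass"]   # ["-y", "-c:v", "libx264", "-pix_fmt", "yuv420p", "-crf", "20"]
    # ffmpeg then receives args + [file_path], with frames written to its stdin.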
custom_nodes/ComfyUI-tbox/nodes/video/video_formats/h265-mp4.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "main_pass":
3
+ [
4
+ "-y", "-c:v", "libx265",
5
+ "-vtag", "hvc1",
6
+ "-pix_fmt", "yuv420p10le",
7
+ "-crf", "22",
8
+ "-preset", "medium",
9
+ "-x265-params", "log-level=quiet"
10
+ ],
11
+ "audio_pass": ["-c:a", "aac"],
12
+ "extension": "mp4"
13
+ }
custom_nodes/ComfyUI-tbox/nodes/video/video_formats/nvenc_h264-mp4.json ADDED
@@ -0,0 +1,12 @@
1
+ {
2
+ "main_pass":
3
+ [
4
+ "-y", "-c:v", "h264_nvenc",
5
+ "-pix_fmt", "yuv420p",
6
+ "-qp", "20"
7
+ ],
8
+ "audio_pass": ["-c:a", "aac"],
9
+ "bitrate": ["bitrate","INT", {"default": 10, "min": 1, "max": 999, "step": 1 }],
10
+ "megabit": ["megabit","BOOLEAN", {"default": true}],
11
+ "extension": "mp4"
12
+ }
custom_nodes/ComfyUI-tbox/nodes/video/video_formats/nvenc_h265-mp4.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "main_pass":
3
+ [
4
+ "-y", "-c:v", "hevc_nvenc",
5
+ "-vtag", "hvc1",
6
+ "-qp", "22"
7
+ ],
8
+ "audio_pass": ["-c:a", "aac"],
9
+ "extension": "mp4"
10
+ }
custom_nodes/ComfyUI-tbox/nodes/video/video_formats/webm.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "main_pass":
3
+ [
4
+ "-y",
5
+ "-crf", 20,
6
+ "-pix_fmt", "yuv420p",
7
+ "-b:v", "0"
8
+ ],
9
+ "audio_pass": ["-c:a", "libvorbis"],
10
+ "extension": "webm"
11
+ }
custom_nodes/ComfyUI-tbox/requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ torch
2
+ torchvision
3
+ opencv-python-headless>=4.7.0.72
4
+ scikit-learn
5
+ scikit-image
6
+ insightface
7
+ ultralytics
8
+ onnxruntime-gpu==1.18.0
9
+ onnxruntime==1.18.0
10
+ imageio_ffmpeg
11
+ pykalman
custom_nodes/ComfyUI-tbox/src/canny/__init__.py ADDED
1
+ import warnings
2
+ import cv2
3
+ import numpy as np
4
+ from PIL import Image
5
+ from common import resize_image_with_pad, common_input_validate, HWC3
6
+
7
+ class CannyDetector:
8
+ def __call__(self, input_image=None, low_threshold=100, high_threshold=200, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
9
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
10
+ detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
11
+ detected_map = cv2.Canny(detected_map, low_threshold, high_threshold)
12
+ detected_map = HWC3(remove_pad(detected_map))
13
+
14
+ if output_type == "pil":
15
+ detected_map = Image.fromarray(detected_map)
16
+
17
+ return detected_map
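A short usage sketch of CannyDetector; the input file name is hypothetical:

    import numpy as np
    from PIL import Image

    detector = CannyDetector()
    img = np.array(Image.open("example.jpg").convert("RGB"))   # hypothetical input image
    edges = detector(img, low_threshold=100, high_threshold=200, detect_resolution=512)
    Image.fromarray(edges).save("example_canny.png")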
custom_nodes/ComfyUI-tbox/src/common.py ADDED
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ import os
5
+ import tempfile
6
+ import warnings
7
+ from contextlib import suppress
8
+ from pathlib import Path
9
+ from huggingface_hub import constants, hf_hub_download
10
+ from ast import literal_eval
11
+
12
+ TEMP_DIR = tempfile.gettempdir()
13
+ ANNOTATOR_CKPTS_PATH = os.path.join(Path(__file__).parents[2], 'ckpts')
14
+ USE_SYMLINKS = False
15
+
16
+
17
+ BIGMIN = -(2**53-1)
18
+ BIGMAX = (2**53-1)
19
+
20
+ DIMMAX = 8192
21
+
22
+ try:
23
+ ANNOTATOR_CKPTS_PATH = os.environ['AUX_ANNOTATOR_CKPTS_PATH']
24
+ except:
25
+ warnings.warn("Custom pressesor model path not set successfully.")
26
+ pass
27
+
28
+ try:
29
+ USE_SYMLINKS = literal_eval(os.environ['AUX_USE_SYMLINKS'])
30
+ except:
31
+ warnings.warn("USE_SYMLINKS not set successfully. Using default value: False to download models.")
32
+ pass
33
+
34
+ try:
35
+ TEMP_DIR = os.environ['AUX_TEMP_DIR']
36
+ if len(TEMP_DIR) >= 60:
37
+ warnings.warn(f"custom temp dir is too long. Using default")
38
+ TEMP_DIR = tempfile.gettempdir()
39
+ except:
40
+ warnings.warn(f"custom temp dir not set successfully")
41
+ pass
42
+
43
+ here = Path(__file__).parent.resolve()
44
+
45
+ def safer_memory(x):
46
+ # Fix many MAC/AMD problems
47
+ return np.ascontiguousarray(x.copy()).copy()
48
+
49
+ UPSCALE_METHODS = ["INTER_NEAREST", "INTER_LINEAR", "INTER_AREA", "INTER_CUBIC", "INTER_LANCZOS4"]
50
+ def get_upscale_method(method_str):
51
+ assert method_str in UPSCALE_METHODS, f"Method {method_str} not found in {UPSCALE_METHODS}"
52
+ return getattr(cv2, method_str)
53
+
54
+ def pad64(x):
55
+ return int(np.ceil(float(x) / 64.0) * 64 - x)
56
+
57
+ def resize_image_with_pad(input_image, resolution, upscale_method = "", skip_hwc3=False, mode='edge'):
58
+ if skip_hwc3:
59
+ img = input_image
60
+ else:
61
+ img = HWC3(input_image)
62
+ H_raw, W_raw, _ = img.shape
63
+ if resolution == 0:
64
+ return img, lambda x: x
65
+ k = float(resolution) / float(min(H_raw, W_raw))
66
+ H_target = int(np.round(float(H_raw) * k))
67
+ W_target = int(np.round(float(W_raw) * k))
68
+ img = cv2.resize(img, (W_target, H_target), interpolation=get_upscale_method(upscale_method) if k > 1 else cv2.INTER_AREA)
69
+ H_pad, W_pad = pad64(H_target), pad64(W_target)
70
+ img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode)
71
+
72
+ def remove_pad(x):
73
+ return safer_memory(x[:H_target, :W_target, ...])
74
+
75
+ return safer_memory(img_padded), remove_pad
76
+
77
+
78
+ def common_input_validate(input_image, output_type, **kwargs):
79
+ if "img" in kwargs:
80
+ warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning)
81
+ input_image = kwargs.pop("img")
82
+
83
+ if "return_pil" in kwargs:
84
+ warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning)
85
+ output_type = "pil" if kwargs["return_pil"] else "np"
86
+
87
+ if type(output_type) is bool:
88
+ warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions")
89
+ if output_type:
90
+ output_type = "pil"
91
+
92
+ if input_image is None:
93
+ raise ValueError("input_image must be defined.")
94
+
95
+ if not isinstance(input_image, np.ndarray):
96
+ input_image = np.array(input_image, dtype=np.uint8)
97
+ output_type = output_type or "pil"
98
+ else:
99
+ output_type = output_type or "np"
100
+
101
+ return (input_image, output_type)
102
+
103
+ def custom_hf_download(pretrained_model_or_path, filename, cache_dir=TEMP_DIR, ckpts_dir=ANNOTATOR_CKPTS_PATH, subfolder=str(""), use_symlinks=USE_SYMLINKS, repo_type="model"):
104
+
105
+ print(f'cache_dir: {cache_dir}')
106
+ print(f'ckpts_dir: {ckpts_dir}')
107
+ print(f'use_symlinks: {use_symlinks}')
108
+ local_dir = os.path.join(ckpts_dir, pretrained_model_or_path)
109
+ model_path = os.path.join(local_dir, *subfolder.split('/'), filename)
110
+
111
+ if len(str(model_path)) >= 255:
112
+ warnings.warn(f"Path {model_path} is too long, \n please change annotator_ckpts_path in config.yaml")
113
+
114
+ if not os.path.exists(model_path):
115
+ print(f"Failed to find {model_path}.\n Downloading from huggingface.co")
116
+ print(f"cacher folder is {cache_dir}, you can change it by custom_tmp_path in config.yaml")
117
+ if use_symlinks:
118
+ cache_dir_d = constants.HF_HUB_CACHE # use huggingface newer env variables `HF_HUB_CACHE`
119
+ if cache_dir_d is None:
120
+ import platform
121
+ if platform.system() == "Windows":
122
+ cache_dir_d = os.path.join(os.getenv("USERPROFILE"), ".cache", "huggingface", "hub")
123
+ else:
124
+ cache_dir_d = os.path.join(os.getenv("HOME"), ".cache", "huggingface", "hub")
125
+ try:
126
+ # test_link
127
+ Path(cache_dir_d).mkdir(parents=True, exist_ok=True)
128
+ Path(ckpts_dir).mkdir(parents=True, exist_ok=True)
129
+ (Path(cache_dir_d) / f"linktest_{filename}.txt").touch()
130
+ # symlink instead of link avoid `invalid cross-device link` error.
131
+ os.symlink(os.path.join(cache_dir_d, f"linktest_{filename}.txt"), os.path.join(ckpts_dir, f"linktest_{filename}.txt"))
132
+ print("Using symlinks to download models. \n",\
133
+ "Make sure you have enough space on your cache folder. \n",\
134
+ "And do not purge the cache folder after downloading.\n",\
135
+ "Otherwise, you will have to re-download the models every time you run the script.\n",\
136
+ "You can use USE_SYMLINKS: False in config.yaml to avoid this behavior.")
137
+ except:
138
+ print("Maybe not able to create symlink. Disable using symlinks.")
139
+ use_symlinks = False
140
+ cache_dir_d = os.path.join(cache_dir, "ckpts", pretrained_model_or_path)
141
+ finally: # always remove test link files
142
+ with suppress(FileNotFoundError):
143
+ os.remove(os.path.join(ckpts_dir, f"linktest_{filename}.txt"))
144
+ os.remove(os.path.join(cache_dir_d, f"linktest_{filename}.txt"))
145
+ else:
146
+ cache_dir_d = os.path.join(cache_dir, "ckpts", pretrained_model_or_path)
147
+
148
+ model_path = hf_hub_download(repo_id=pretrained_model_or_path,
149
+ cache_dir=cache_dir_d,
150
+ local_dir=local_dir,
151
+ subfolder=subfolder,
152
+ filename=filename,
153
+ local_dir_use_symlinks=use_symlinks,
154
+ resume_download=True,
155
+ etag_timeout=100,
156
+ repo_type=repo_type
157
+ )
158
+ if not use_symlinks:
159
+ try:
160
+ import shutil
161
+ shutil.rmtree(os.path.join(cache_dir, "ckpts"))
162
+ except Exception as e :
163
+ print(e)
164
+
165
+ print(f"model_path is {model_path}")
166
+
167
+ return model_path
168
+
169
+
170
+ def HWC3(x):
171
+ assert x.dtype == np.uint8
172
+ if x.ndim == 2:
173
+ x = x[:, :, None]
174
+ assert x.ndim == 3
175
+ H, W, C = x.shape
176
+ assert C == 1 or C == 3 or C == 4
177
+ if C == 3:
178
+ return x
179
+ if C == 1:
180
+ return np.concatenate([x, x, x], axis=2)
181
+ if C == 4:
182
+ color = x[:, :, 0:3].astype(np.float32)
183
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
184
+ y = color * alpha + 255.0 * (1.0 - alpha)
185
+ y = y.clip(0, 255).astype(np.uint8)
186
+ return y
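A small sketch of how resize_image_with_pad and its returned remove_pad closure are meant to be used together; the array is synthetic:

    import numpy as np

    img = np.zeros((480, 640, 3), dtype=np.uint8)               # synthetic RGB image
    padded, remove_pad = resize_image_with_pad(img, 512, "INTER_CUBIC")
    # short side scaled to 512, then each side padded up to a multiple of 64
    assert padded.shape[0] % 64 == 0 and padded.shape[1] % 64 == 0
    restored = remove_pad(padded)                                # crops back to the resized size
    assert restored.shape[:2] == (512, 683)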
custom_nodes/ComfyUI-tbox/src/densepose/__init__.py ADDED
@@ -0,0 +1,67 @@
1
+ import torchvision # Fix issue Unknown builtin op: torchvision::nms
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange
7
+ from PIL import Image
8
+
9
+ from common import HWC3, resize_image_with_pad, common_input_validate, custom_hf_download
10
+ from .densepose import DensePoseMaskedColormapResultsVisualizer, _extract_i_from_iuvarr, densepose_chart_predictor_output_to_result_with_confidences
11
+
12
+ N_PART_LABELS = 24
13
+ DENSEPOSE_MODEL_NAME = "LayerNorm/DensePose-TorchScript-with-hint-image"
14
+
15
+ class DenseposeDetector:
16
+ def __init__(self, model):
17
+ self.dense_pose_estimation = model
18
+ self.device = "cpu"
19
+ self.result_visualizer = DensePoseMaskedColormapResultsVisualizer(
20
+ alpha=1,
21
+ data_extractor=_extract_i_from_iuvarr,
22
+ segm_extractor=_extract_i_from_iuvarr,
23
+ val_scale = 255.0 / N_PART_LABELS
24
+ )
25
+
26
+ @classmethod
27
+ def from_pretrained(cls, pretrained_model_or_path=DENSEPOSE_MODEL_NAME, filename="densepose_r50_fpn_dl.torchscript"):
28
+ torchscript_model_path = custom_hf_download(pretrained_model_or_path, filename)
29
+ densepose = torch.jit.load(torchscript_model_path, map_location="cpu")
30
+ return cls(densepose)
31
+
32
+ def to(self, device):
33
+ self.dense_pose_estimation.to(device)
34
+ self.device = device
35
+ return self
36
+
37
+ def __call__(self, input_image, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", cmap="viridis", **kwargs):
38
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
39
+ input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
40
+ H, W = input_image.shape[:2]
41
+
42
+ hint_image_canvas = np.zeros([H, W], dtype=np.uint8)
43
+ hint_image_canvas = np.tile(hint_image_canvas[:, :, np.newaxis], [1, 1, 3])
44
+
45
+ input_image = rearrange(torch.from_numpy(input_image).to(self.device), 'h w c -> c h w')
46
+
47
+ pred_boxes, corase_segm, fine_segm, u, v = self.dense_pose_estimation(input_image)
48
+
49
+ extractor = densepose_chart_predictor_output_to_result_with_confidences
50
+ densepose_results = [extractor(pred_boxes[i:i+1], corase_segm[i:i+1], fine_segm[i:i+1], u[i:i+1], v[i:i+1]) for i in range(len(pred_boxes))]
51
+
52
+ if cmap=="viridis":
53
+ self.result_visualizer.mask_visualizer.cmap = cv2.COLORMAP_VIRIDIS
54
+ hint_image = self.result_visualizer.visualize(hint_image_canvas, densepose_results)
55
+ hint_image = cv2.cvtColor(hint_image, cv2.COLOR_BGR2RGB)
56
+ hint_image[:, :, 0][hint_image[:, :, 0] == 0] = 68
57
+ hint_image[:, :, 1][hint_image[:, :, 1] == 0] = 1
58
+ hint_image[:, :, 2][hint_image[:, :, 2] == 0] = 84
59
+ else:
60
+ self.result_visualizer.mask_visualizer.cmap = cv2.COLORMAP_PARULA
61
+ hint_image = self.result_visualizer.visualize(hint_image_canvas, densepose_results)
62
+ hint_image = cv2.cvtColor(hint_image, cv2.COLOR_BGR2RGB)
63
+
64
+ detected_map = remove_pad(HWC3(hint_image))
65
+ if output_type == "pil":
66
+ detected_map = Image.fromarray(detected_map)
67
+ return detected_map
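A usage sketch of the detector above; the TorchScript model is fetched through custom_hf_download on first use, and the image file name is hypothetical:

    import numpy as np
    from PIL import Image

    detector = DenseposeDetector.from_pretrained().to("cpu")
    img = np.array(Image.open("person.jpg").convert("RGB"))     # hypothetical input image
    pose_map = detector(img, detect_resolution=512, output_type="pil", cmap="viridis")
    pose_map.save("person_densepose.png")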
custom_nodes/ComfyUI-tbox/src/densepose/densepose.py ADDED
@@ -0,0 +1,347 @@
1
+ from typing import Tuple
2
+ import math
3
+ import numpy as np
4
+ from enum import IntEnum
5
+ from typing import List, Tuple, Union
6
+ import torch
7
+ from torch.nn import functional as F
8
+ import logging
9
+ import cv2
10
+
11
+ Image = np.ndarray
12
+ Boxes = torch.Tensor
13
+ ImageSizeType = Tuple[int, int]
14
+ _RawBoxType = Union[List[float], Tuple[float, ...], torch.Tensor, np.ndarray]
15
+ IntTupleBox = Tuple[int, int, int, int]
16
+
17
+ class BoxMode(IntEnum):
18
+ """
19
+ Enum of different ways to represent a box.
20
+ """
21
+
22
+ XYXY_ABS = 0
23
+ """
24
+ (x0, y0, x1, y1) in absolute floating points coordinates.
25
+ The coordinates in range [0, width or height].
26
+ """
27
+ XYWH_ABS = 1
28
+ """
29
+ (x0, y0, w, h) in absolute floating points coordinates.
30
+ """
31
+ XYXY_REL = 2
32
+ """
33
+ Not yet supported!
34
+ (x0, y0, x1, y1) in range [0, 1]. They are relative to the size of the image.
35
+ """
36
+ XYWH_REL = 3
37
+ """
38
+ Not yet supported!
39
+ (x0, y0, w, h) in range [0, 1]. They are relative to the size of the image.
40
+ """
41
+ XYWHA_ABS = 4
42
+ """
43
+ (xc, yc, w, h, a) in absolute floating points coordinates.
44
+ (xc, yc) is the center of the rotated box, and the angle a is in degrees ccw.
45
+ """
46
+
47
+ @staticmethod
48
+ def convert(box: _RawBoxType, from_mode: "BoxMode", to_mode: "BoxMode") -> _RawBoxType:
49
+ """
50
+ Args:
51
+ box: can be a k-tuple, k-list or an Nxk array/tensor, where k = 4 or 5
52
+ from_mode, to_mode (BoxMode)
53
+
54
+ Returns:
55
+ The converted box of the same type.
56
+ """
57
+ if from_mode == to_mode:
58
+ return box
59
+
60
+ original_type = type(box)
61
+ is_numpy = isinstance(box, np.ndarray)
62
+ single_box = isinstance(box, (list, tuple))
63
+ if single_box:
64
+ assert len(box) == 4 or len(box) == 5, (
65
+ "BoxMode.convert takes either a k-tuple/list or an Nxk array/tensor,"
66
+ " where k == 4 or 5"
67
+ )
68
+ arr = torch.tensor(box)[None, :]
69
+ else:
70
+ # avoid modifying the input box
71
+ if is_numpy:
72
+ arr = torch.from_numpy(np.asarray(box)).clone()
73
+ else:
74
+ arr = box.clone()
75
+
76
+ assert to_mode not in [BoxMode.XYXY_REL, BoxMode.XYWH_REL] and from_mode not in [
77
+ BoxMode.XYXY_REL,
78
+ BoxMode.XYWH_REL,
79
+ ], "Relative mode not yet supported!"
80
+
81
+ if from_mode == BoxMode.XYWHA_ABS and to_mode == BoxMode.XYXY_ABS:
82
+ assert (
83
+ arr.shape[-1] == 5
84
+ ), "The last dimension of input shape must be 5 for XYWHA format"
85
+ original_dtype = arr.dtype
86
+ arr = arr.double()
87
+
88
+ w = arr[:, 2]
89
+ h = arr[:, 3]
90
+ a = arr[:, 4]
91
+ c = torch.abs(torch.cos(a * math.pi / 180.0))
92
+ s = torch.abs(torch.sin(a * math.pi / 180.0))
93
+ # This basically computes the horizontal bounding rectangle of the rotated box
94
+ new_w = c * w + s * h
95
+ new_h = c * h + s * w
96
+
97
+ # convert center to top-left corner
98
+ arr[:, 0] -= new_w / 2.0
99
+ arr[:, 1] -= new_h / 2.0
100
+ # bottom-right corner
101
+ arr[:, 2] = arr[:, 0] + new_w
102
+ arr[:, 3] = arr[:, 1] + new_h
103
+
104
+ arr = arr[:, :4].to(dtype=original_dtype)
105
+ elif from_mode == BoxMode.XYWH_ABS and to_mode == BoxMode.XYWHA_ABS:
106
+ original_dtype = arr.dtype
107
+ arr = arr.double()
108
+ arr[:, 0] += arr[:, 2] / 2.0
109
+ arr[:, 1] += arr[:, 3] / 2.0
110
+ angles = torch.zeros((arr.shape[0], 1), dtype=arr.dtype)
111
+ arr = torch.cat((arr, angles), axis=1).to(dtype=original_dtype)
112
+ else:
113
+ if to_mode == BoxMode.XYXY_ABS and from_mode == BoxMode.XYWH_ABS:
114
+ arr[:, 2] += arr[:, 0]
115
+ arr[:, 3] += arr[:, 1]
116
+ elif from_mode == BoxMode.XYXY_ABS and to_mode == BoxMode.XYWH_ABS:
117
+ arr[:, 2] -= arr[:, 0]
118
+ arr[:, 3] -= arr[:, 1]
119
+ else:
120
+ raise NotImplementedError(
121
+ "Conversion from BoxMode {} to {} is not supported yet".format(
122
+ from_mode, to_mode
123
+ )
124
+ )
125
+
126
+ if single_box:
127
+ return original_type(arr.flatten().tolist())
128
+ if is_numpy:
129
+ return arr.numpy()
130
+ else:
131
+ return arr
132
+
133
+ class MatrixVisualizer:
134
+ """
135
+ Base visualizer for matrix data
136
+ """
137
+
138
+ def __init__(
139
+ self,
140
+ inplace=True,
141
+ cmap=cv2.COLORMAP_PARULA,
142
+ val_scale=1.0,
143
+ alpha=0.7,
144
+ interp_method_matrix=cv2.INTER_LINEAR,
145
+ interp_method_mask=cv2.INTER_NEAREST,
146
+ ):
147
+ self.inplace = inplace
148
+ self.cmap = cmap
149
+ self.val_scale = val_scale
150
+ self.alpha = alpha
151
+ self.interp_method_matrix = interp_method_matrix
152
+ self.interp_method_mask = interp_method_mask
153
+
154
+ def visualize(self, image_bgr, mask, matrix, bbox_xywh):
155
+ self._check_image(image_bgr)
156
+ self._check_mask_matrix(mask, matrix)
157
+ if self.inplace:
158
+ image_target_bgr = image_bgr
159
+ else:
160
+ image_target_bgr = image_bgr * 0
161
+ x, y, w, h = [int(v) for v in bbox_xywh]
162
+ if w <= 0 or h <= 0:
163
+ return image_bgr
164
+ mask, matrix = self._resize(mask, matrix, w, h)
165
+ mask_bg = np.tile((mask == 0)[:, :, np.newaxis], [1, 1, 3])
166
+ matrix_scaled = matrix.astype(np.float32) * self.val_scale
167
+ _EPSILON = 1e-6
168
+ if np.any(matrix_scaled > 255 + _EPSILON):
169
+ logger = logging.getLogger(__name__)
170
+ logger.warning(
171
+ f"Matrix has values > {255 + _EPSILON} after " f"scaling, clipping to [0..255]"
172
+ )
173
+ matrix_scaled_8u = matrix_scaled.clip(0, 255).astype(np.uint8)
174
+ matrix_vis = cv2.applyColorMap(matrix_scaled_8u, self.cmap)
175
+ matrix_vis[mask_bg] = image_target_bgr[y : y + h, x : x + w, :][mask_bg]
176
+ image_target_bgr[y : y + h, x : x + w, :] = (
177
+ image_target_bgr[y : y + h, x : x + w, :] * (1.0 - self.alpha) + matrix_vis * self.alpha
178
+ )
179
+ return image_target_bgr.astype(np.uint8)
180
+
181
+ def _resize(self, mask, matrix, w, h):
182
+ if (w != mask.shape[1]) or (h != mask.shape[0]):
183
+ mask = cv2.resize(mask, (w, h), self.interp_method_mask)
184
+ if (w != matrix.shape[1]) or (h != matrix.shape[0]):
185
+ matrix = cv2.resize(matrix, (w, h), self.interp_method_matrix)
186
+ return mask, matrix
187
+
188
+ def _check_image(self, image_rgb):
189
+ assert len(image_rgb.shape) == 3
190
+ assert image_rgb.shape[2] == 3
191
+ assert image_rgb.dtype == np.uint8
192
+
193
+ def _check_mask_matrix(self, mask, matrix):
194
+ assert len(matrix.shape) == 2
195
+ assert len(mask.shape) == 2
196
+ assert mask.dtype == np.uint8
197
+
198
+ class DensePoseResultsVisualizer:
199
+ def visualize(
200
+ self,
201
+ image_bgr: Image,
202
+ results,
203
+ ) -> Image:
204
+ context = self.create_visualization_context(image_bgr)
205
+ for i, result in enumerate(results):
206
+ boxes_xywh, labels, uv = result
207
+ iuv_array = torch.cat(
208
+ (labels[None].type(torch.float32), uv * 255.0)
209
+ ).type(torch.uint8)
210
+ self.visualize_iuv_arr(context, iuv_array.cpu().numpy(), boxes_xywh)
211
+ image_bgr = self.context_to_image_bgr(context)
212
+ return image_bgr
213
+
214
+ def create_visualization_context(self, image_bgr: Image):
215
+ return image_bgr
216
+
217
+ def visualize_iuv_arr(self, context, iuv_arr: np.ndarray, bbox_xywh) -> None:
218
+ pass
219
+
220
+ def context_to_image_bgr(self, context):
221
+ return context
222
+
223
+ def get_image_bgr_from_context(self, context):
224
+ return context
225
+
226
+ class DensePoseMaskedColormapResultsVisualizer(DensePoseResultsVisualizer):
227
+ def __init__(
228
+ self,
229
+ data_extractor,
230
+ segm_extractor,
231
+ inplace=True,
232
+ cmap=cv2.COLORMAP_PARULA,
233
+ alpha=0.7,
234
+ val_scale=1.0,
235
+ **kwargs,
236
+ ):
237
+ self.mask_visualizer = MatrixVisualizer(
238
+ inplace=inplace, cmap=cmap, val_scale=val_scale, alpha=alpha
239
+ )
240
+ self.data_extractor = data_extractor
241
+ self.segm_extractor = segm_extractor
242
+
243
+ def context_to_image_bgr(self, context):
244
+ return context
245
+
246
+ def visualize_iuv_arr(self, context, iuv_arr: np.ndarray, bbox_xywh) -> None:
247
+ image_bgr = self.get_image_bgr_from_context(context)
248
+ matrix = self.data_extractor(iuv_arr)
249
+ segm = self.segm_extractor(iuv_arr)
250
+ mask = np.zeros(matrix.shape, dtype=np.uint8)
251
+ mask[segm > 0] = 1
252
+ image_bgr = self.mask_visualizer.visualize(image_bgr, mask, matrix, bbox_xywh)
253
+
254
+
255
+ def _extract_i_from_iuvarr(iuv_arr):
256
+ return iuv_arr[0, :, :]
257
+
258
+
259
+ def _extract_u_from_iuvarr(iuv_arr):
260
+ return iuv_arr[1, :, :]
261
+
262
+
263
+ def _extract_v_from_iuvarr(iuv_arr):
264
+ return iuv_arr[2, :, :]
265
+
266
+ def make_int_box(box: torch.Tensor) -> IntTupleBox:
267
+ int_box = [0, 0, 0, 0]
268
+ int_box[0], int_box[1], int_box[2], int_box[3] = tuple(box.long().tolist())
269
+ return int_box[0], int_box[1], int_box[2], int_box[3]
270
+
271
+ def densepose_chart_predictor_output_to_result_with_confidences(
272
+ boxes: Boxes,
273
+ coarse_segm,
274
+ fine_segm,
275
+ u, v
276
+
277
+ ):
278
+ boxes_xyxy_abs = boxes.clone()
279
+ boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
280
+ box_xywh = make_int_box(boxes_xywh_abs[0])
281
+
282
+ labels = resample_fine_and_coarse_segm_tensors_to_bbox(fine_segm, coarse_segm, box_xywh).squeeze(0)
283
+ uv = resample_uv_tensors_to_bbox(u, v, labels, box_xywh)
284
+ confidences = []
285
+ return box_xywh, labels, uv
286
+
287
+ def resample_fine_and_coarse_segm_tensors_to_bbox(
288
+ fine_segm: torch.Tensor, coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox
289
+ ):
290
+ """
291
+ Resample fine and coarse segmentation tensors to the given
292
+ bounding box and derive labels for each pixel of the bounding box
293
+
294
+ Args:
295
+ fine_segm: float tensor of shape [1, C, Hout, Wout]
296
+ coarse_segm: float tensor of shape [1, K, Hout, Wout]
297
+ box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
298
+ corner coordinates, width (W) and height (H)
299
+ Return:
300
+ Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
301
+ """
302
+ x, y, w, h = box_xywh_abs
303
+ w = max(int(w), 1)
304
+ h = max(int(h), 1)
305
+ # coarse segmentation
306
+ coarse_segm_bbox = F.interpolate(
307
+ coarse_segm,
308
+ (h, w),
309
+ mode="bilinear",
310
+ align_corners=False,
311
+ ).argmax(dim=1)
312
+ # combined coarse and fine segmentation
313
+ labels = (
314
+ F.interpolate(fine_segm, (h, w), mode="bilinear", align_corners=False).argmax(dim=1)
315
+ * (coarse_segm_bbox > 0).long()
316
+ )
317
+ return labels
318
+
319
+ def resample_uv_tensors_to_bbox(
320
+ u: torch.Tensor,
321
+ v: torch.Tensor,
322
+ labels: torch.Tensor,
323
+ box_xywh_abs: IntTupleBox,
324
+ ) -> torch.Tensor:
325
+ """
326
+ Resamples U and V coordinate estimates for the given bounding box
327
+
328
+ Args:
329
+ u (tensor [1, C, H, W] of float): U coordinates
330
+ v (tensor [1, C, H, W] of float): V coordinates
331
+ labels (tensor [H, W] of long): labels obtained by resampling segmentation
332
+ outputs for the given bounding box
333
+ box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
334
+ Return:
335
+ Resampled U and V coordinates - a tensor [2, H, W] of float
336
+ """
337
+ x, y, w, h = box_xywh_abs
338
+ w = max(int(w), 1)
339
+ h = max(int(h), 1)
340
+ u_bbox = F.interpolate(u, (h, w), mode="bilinear", align_corners=False)
341
+ v_bbox = F.interpolate(v, (h, w), mode="bilinear", align_corners=False)
342
+ uv = torch.zeros([2, h, w], dtype=torch.float32, device=u.device)
343
+ for part_id in range(1, u_bbox.size(1)):
344
+ uv[0][labels == part_id] = u_bbox[0, part_id][labels == part_id]
345
+ uv[1][labels == part_id] = v_bbox[0, part_id][labels == part_id]
346
+ return uv
347
+
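For clarity, a small example of BoxMode.convert on a single XYWH box; values are illustrative:

    import torch

    boxes_xywh = torch.tensor([[10.0, 20.0, 100.0, 50.0]])      # x, y, width, height
    boxes_xyxy = BoxMode.convert(boxes_xywh, BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
    # -> tensor([[ 10.,  20., 110.,  70.]])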
custom_nodes/ComfyUI-tbox/src/dwpose/LICENSE ADDED
@@ -0,0 +1,108 @@
1
+ OPENPOSE: MULTIPERSON KEYPOINT DETECTION
2
+ SOFTWARE LICENSE AGREEMENT
3
+ ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH USE ONLY
4
+
5
+ BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO THE TERMS OF THIS LICENSE AGREEMENT. IF YOU DO NOT AGREE WITH THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE.
6
+
7
+ This is a license agreement ("Agreement") between your academic institution or non-profit organization or self (called "Licensee" or "You" in this Agreement) and Carnegie Mellon University (called "Licensor" in this Agreement). All rights not specifically granted to you in this Agreement are reserved for Licensor.
8
+
9
+ RESERVATION OF OWNERSHIP AND GRANT OF LICENSE:
10
+ Licensor retains exclusive ownership of any copy of the Software (as defined below) licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive,
11
+ non-transferable license to use the Software for noncommercial research purposes, without the right to sublicense, pursuant to the terms and conditions of this Agreement. As used in this Agreement, the term "Software" means (i) the actual copy of all or any portion of code for program routines made accessible to Licensee by Licensor pursuant to this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder or subsequently supplied by Licensor, including all or any file structures, programming instructions, user interfaces and screen formats and sequences as well as any and all documentation and instructions related to it, and (ii) all or any derivatives and/or modifications created or made by You to any of the items specified in (i).
12
+
13
+ CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to Licensor, and as such, Licensee agrees to receive all such materials in confidence and use the Software only in accordance with the terms of this Agreement. Licensee agrees to use reasonable effort to protect the Software from unauthorized use, reproduction, distribution, or publication.
14
+
15
+ COPYRIGHT: The Software is owned by Licensor and is protected by United
16
+ States copyright laws and applicable international treaties and/or conventions.
17
+
18
+ PERMITTED USES: The Software may be used for your own noncommercial internal research purposes. You understand and agree that Licensor is not obligated to implement any suggestions and/or feedback you might provide regarding the Software, but to the extent Licensor does so, you are not entitled to any compensation related thereto.
19
+
20
+ DERIVATIVES: You may create derivatives of or make modifications to the Software, however, You agree that all and any such derivatives and modifications will be owned by Licensor and become a part of the Software licensed to You under this Agreement. You may only use such derivatives and modifications for your own noncommercial internal research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement.
21
+
22
+ BACKUPS: If Licensee is an organization, it may make that number of copies of the Software necessary for internal noncommercial use at a single site within its organization provided that all information appearing in or on the original labels, including the copyright and trademark notices are copied onto the labels of the copies.
23
+
24
+ USES NOT PERMITTED: You may not distribute, copy or use the Software except as explicitly permitted herein. Licensee has not been granted any trademark license as part of this Agreement and may not use the name or mark “OpenPose", "Carnegie Mellon" or any renditions thereof without the prior written permission of Licensor.
25
+
26
+ You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part, or provide third parties access to prior or present versions (or any parts thereof) of the Software.
27
+
28
+ ASSIGNMENT: You may not assign this Agreement or your rights hereunder without the prior written consent of Licensor. Any attempted assignment without such consent shall be null and void.
29
+
30
+ TERM: The term of the license granted by this Agreement is from Licensee's acceptance of this Agreement by downloading the Software or by using the Software until terminated as provided below.
31
+
32
+ The Agreement automatically terminates without notice if you fail to comply with any provision of this Agreement. Licensee may terminate this Agreement by ceasing using the Software. Upon any termination of this Agreement, Licensee will delete any and all copies of the Software. You agree that all provisions which operate to protect the proprietary rights of Licensor shall remain in force should breach occur and that the obligation of confidentiality described in this Agreement is binding in perpetuity and, as such, survives the term of the Agreement.
33
+
34
+ FEE: Provided Licensee abides completely by the terms and conditions of this Agreement, there is no fee due to Licensor for Licensee's use of the Software in accordance with this Agreement.
35
+
36
+ DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND RELATED MATERIALS.
37
+
38
+ SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is provided as part of this Agreement.
39
+
40
+ EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent permitted under applicable law, Licensor shall not be liable for direct, indirect, special, incidental, or consequential damages or lost profits related to Licensee's use of and/or inability to use the Software, even if Licensor is advised of the possibility of such damage.
41
+
42
+ EXPORT REGULATION: Licensee agrees to comply with any and all applicable
43
+ U.S. export control laws, regulations, and/or other laws related to embargoes and sanction programs administered by the Office of Foreign Assets Control.
44
+
45
+ SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid, illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the validity, legality and enforceability of the remaining provisions shall not in any way be affected or impaired thereby.
46
+
47
+ NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or remedy under this Agreement shall be construed as a waiver of any future or other exercise of such right or remedy by Licensor.
48
+
49
+ GOVERNING LAW: This Agreement shall be construed and enforced in accordance with the laws of the Commonwealth of Pennsylvania without reference to conflict of laws principles. You consent to the personal jurisdiction of the courts of this County and waive their rights to venue outside of Allegheny County, Pennsylvania.
50
+
51
+ ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole and entire agreement between Licensee and Licensor as to the matter set forth herein and supersedes any previous agreements, understandings, and arrangements between the parties relating hereto.
52
+
53
+
54
+
55
+ ************************************************************************
56
+
57
+ THIRD-PARTY SOFTWARE NOTICES AND INFORMATION
58
+
59
+ This project incorporates material from the project(s) listed below (collectively, "Third Party Code"). This Third Party Code is licensed to you under their original license terms set forth below. We reserves all other rights not expressly granted, whether by implication, estoppel or otherwise.
60
+
61
+ 1. Caffe, version 1.0.0, (https://github.com/BVLC/caffe/)
62
+
63
+ COPYRIGHT
64
+
65
+ All contributions by the University of California:
66
+ Copyright (c) 2014-2017 The Regents of the University of California (Regents)
67
+ All rights reserved.
68
+
69
+ All other contributions:
70
+ Copyright (c) 2014-2017, the respective contributors
71
+ All rights reserved.
72
+
73
+ Caffe uses a shared copyright model: each contributor holds copyright over
74
+ their contributions to Caffe. The project versioning records all such
75
+ contribution and copyright details. If a contributor wants to further mark
76
+ their specific copyright on a particular contribution, they should indicate
77
+ their copyright solely in the commit message of the change when it is
78
+ committed.
79
+
80
+ LICENSE
81
+
82
+ Redistribution and use in source and binary forms, with or without
83
+ modification, are permitted provided that the following conditions are met:
84
+
85
+ 1. Redistributions of source code must retain the above copyright notice, this
86
+ list of conditions and the following disclaimer.
87
+ 2. Redistributions in binary form must reproduce the above copyright notice,
88
+ this list of conditions and the following disclaimer in the documentation
89
+ and/or other materials provided with the distribution.
90
+
91
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
92
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
93
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
94
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
95
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
96
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
97
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
98
+ ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
99
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
100
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
101
+
102
+ CONTRIBUTION AGREEMENT
103
+
104
+ By contributing to the BVLC/caffe repository through pull-request, comment,
105
+ or otherwise, the contributor releases their content to the
106
+ license and copyright terms herein.
107
+
108
+ ************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION**********
custom_nodes/ComfyUI-tbox/src/dwpose/__init__.py ADDED
@@ -0,0 +1,328 @@
1
+ # Openpose
2
+ # Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
3
+ # 2nd Edited by https://github.com/Hzzone/pytorch-openpose
4
+ # 3rd Edited by ControlNet
5
+ # 4th Edited by ControlNet (added face and correct hands)
6
+ # 5th Edited by ControlNet (Improved JSON serialization/deserialization, and lots of bug fixes)
7
+ # This preprocessor is licensed by CMU for non-commercial use only.
8
+
9
+ import os
10
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
11
+
12
+ import json
13
+ import torch
14
+ import numpy as np
15
+ from . import util
16
+ from .body import Body, BodyResult, Keypoint
17
+ from .hand import Hand
18
+ from .face import Face
19
+ from .types import PoseResult, HandResult, FaceResult, AnimalPoseResult
20
+ #from huggingface_hub import hf_hub_download
21
+ from .wholebody import Wholebody
22
+ import warnings
23
+ from common import HWC3, resize_image_with_pad, common_input_validate, custom_hf_download
24
+ import cv2
25
+ from PIL import Image
26
+ from .animalpose import AnimalPoseImage
27
+
28
+ from typing import Tuple, List, Callable, Union, Optional
29
+
30
+
31
+ def draw_animalposes(animals: list[list[Keypoint]], H: int, W: int) -> np.ndarray:
32
+ canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
33
+ for animal_pose in animals:
34
+ canvas = draw_animalpose(canvas, animal_pose)
35
+ return canvas
36
+
37
+
38
+ def draw_animalpose(canvas: np.ndarray, keypoints: list[Keypoint]) -> np.ndarray:
39
+ # order of the keypoints for AP10k and a standardized list of colors for limbs
40
+ keypointPairsList = [
41
+ (1, 2),
42
+ (2, 3),
43
+ (1, 3),
44
+ (3, 4),
45
+ (4, 9),
46
+ (9, 10),
47
+ (10, 11),
48
+ (4, 6),
49
+ (6, 7),
50
+ (7, 8),
51
+ (4, 5),
52
+ (5, 15),
53
+ (15, 16),
54
+ (16, 17),
55
+ (5, 12),
56
+ (12, 13),
57
+ (13, 14),
58
+ ]
59
+ colorsList = [
60
+ (255, 255, 255),
61
+ (100, 255, 100),
62
+ (150, 255, 255),
63
+ (100, 50, 255),
64
+ (50, 150, 200),
65
+ (0, 255, 255),
66
+ (0, 150, 0),
67
+ (0, 0, 255),
68
+ (0, 0, 150),
69
+ (255, 50, 255),
70
+ (255, 0, 255),
71
+ (255, 0, 0),
72
+ (150, 0, 0),
73
+ (255, 255, 100),
74
+ (0, 150, 0),
75
+ (255, 255, 0),
76
+ (150, 150, 150),
77
+ ] # 17 colors, one per keypoint pair
78
+
79
+ for ind, (i, j) in enumerate(keypointPairsList):
80
+ p1 = keypoints[i - 1]
81
+ p2 = keypoints[j - 1]
82
+
83
+ if p1 is not None and p2 is not None:
84
+ cv2.line(
85
+ canvas,
86
+ (int(p1.x), int(p1.y)),
87
+ (int(p2.x), int(p2.y)),
88
+ colorsList[ind],
89
+ 5,
90
+ )
91
+ return canvas
92
+
93
+
94
+ def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, draw_face=True):
95
+ """
96
+ Draw the detected poses on an empty canvas.
97
+
98
+ Args:
99
+ poses (List[PoseResult]): A list of PoseResult objects containing the detected poses.
100
+ H (int): The height of the canvas.
101
+ W (int): The width of the canvas.
102
+ draw_body (bool, optional): Whether to draw body keypoints. Defaults to True.
103
+ draw_hand (bool, optional): Whether to draw hand keypoints. Defaults to True.
104
+ draw_face (bool, optional): Whether to draw face keypoints. Defaults to True.
105
+
106
+ Returns:
107
+ numpy.ndarray: A 3D numpy array representing the canvas with the drawn poses.
108
+ """
109
+ canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
110
+
111
+ for pose in poses:
112
+ if draw_body:
113
+ canvas = util.draw_bodypose(canvas, pose.body.keypoints)
114
+
115
+ if draw_hand:
116
+ canvas = util.draw_handpose(canvas, pose.left_hand)
117
+ canvas = util.draw_handpose(canvas, pose.right_hand)
118
+
119
+ if draw_face:
120
+ canvas = util.draw_facepose(canvas, pose.face)
121
+
122
+ return canvas
123
+
124
+
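A minimal usage sketch for draw_poses, assuming the Keypoint/BodyResult/PoseResult NamedTuples from .types accept the fields used above and that keypoint coordinates are normalized to [0, 1] (as they are elsewhere in this package):

```python
# Hedged sketch: draw one synthetic body pose onto a blank 256x256 canvas.
# The normalized-coordinate convention is an assumption, not guaranteed by this file.
import numpy as np

H, W = 256, 256
body_kps = [Keypoint(x=0.30 + 0.02 * i, y=0.30 + 0.02 * i) for i in range(18)]  # 18 COCO-style points
pose = PoseResult(body=BodyResult(keypoints=body_kps),
                  left_hand=None, right_hand=None, face=None)
canvas = draw_poses([pose], H, W, draw_body=True, draw_hand=False, draw_face=False)
assert canvas.shape == (H, W, 3) and canvas.dtype == np.uint8
```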
125
+ def decode_json_as_poses(
126
+ pose_json: dict,
127
+ ) -> Tuple[List[PoseResult], List[AnimalPoseResult], int, int]:
128
+ """Decode the json_string complying with the openpose JSON output format
129
+ to poses that controlnet recognizes.
130
+ https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/doc/02_output.md
131
+
132
+ Args:
133
+ pose_json: The parsed OpenPose JSON output, as a dict.
134
+
135
+ Returns:
136
+ human_poses
137
+ animal_poses
138
+ canvas_height
139
+ canvas_width
140
+ """
141
+ height = pose_json["canvas_height"]
142
+ width = pose_json["canvas_width"]
143
+
144
+ def chunks(lst, n):
145
+ """Yield successive n-sized chunks from lst."""
146
+ for i in range(0, len(lst), n):
147
+ yield lst[i : i + n]
148
+
149
+ def decompress_keypoints(
150
+ numbers: Optional[List[float]],
151
+ ) -> Optional[List[Optional[Keypoint]]]:
152
+ if not numbers:
153
+ return None
154
+
155
+ assert len(numbers) % 3 == 0
156
+
157
+ def create_keypoint(x, y, c):
158
+ if c < 1.0:
159
+ return None
160
+ keypoint = Keypoint(x, y)
161
+ return keypoint
162
+
163
+ return [create_keypoint(x, y, c) for x, y, c in chunks(numbers, n=3)]
164
+
165
+ return (
166
+ [
167
+ PoseResult(
168
+ body=BodyResult(
169
+ keypoints=decompress_keypoints(pose.get("pose_keypoints_2d"))
170
+ ),
171
+ left_hand=decompress_keypoints(pose.get("hand_left_keypoints_2d")),
172
+ right_hand=decompress_keypoints(pose.get("hand_right_keypoints_2d")),
173
+ face=decompress_keypoints(pose.get("face_keypoints_2d")),
174
+ )
175
+ for pose in pose_json.get("people", [])
176
+ ],
177
+ [decompress_keypoints(pose) for pose in pose_json.get("animals", [])],
178
+ height,
179
+ width,
180
+ )
181
+
182
+
183
+ def encode_poses_as_dict(poses: List[PoseResult], canvas_height: int, canvas_width: int) -> dict:
184
+ """ Encode the pose as a dict following openpose JSON output format:
185
+ https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/doc/02_output.md
186
+ """
187
+ def compress_keypoints(keypoints: Union[List[Keypoint], None]) -> Union[List[float], None]:
188
+ if not keypoints:
189
+ return None
190
+
191
+ return [
192
+ value
193
+ for keypoint in keypoints
194
+ for value in (
195
+ [float(keypoint.x), float(keypoint.y), 1.0]
196
+ if keypoint is not None
197
+ else [0.0, 0.0, 0.0]
198
+ )
199
+ ]
200
+
201
+ return {
202
+ 'people': [
203
+ {
204
+ 'pose_keypoints_2d': compress_keypoints(pose.body.keypoints),
205
+ "face_keypoints_2d": compress_keypoints(pose.face),
206
+ "hand_left_keypoints_2d": compress_keypoints(pose.left_hand),
207
+ "hand_right_keypoints_2d":compress_keypoints(pose.right_hand),
208
+ }
209
+ for pose in poses
210
+ ],
211
+ 'canvas_height': canvas_height,
212
+ 'canvas_width': canvas_width,
213
+ }
214
+
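Because encode_poses_as_dict emits the flat [x, y, confidence] triplets that decode_json_as_poses expects, the two functions round-trip. A small sketch, assuming `poses` is a List[PoseResult] (e.g. from DwposeDetector.detect_poses or the synthetic pose sketched above):

```python
import json

# Hedged sketch: serialize poses to the OpenPose-style dict and read them back.
pose_dict = encode_poses_as_dict(poses, canvas_height=256, canvas_width=256)
json_string = json.dumps(pose_dict)                     # plain JSON, safe to store
people, animals, h, w = decode_json_as_poses(json.loads(json_string))
assert (h, w) == (256, 256)
assert len(people) == len(poses) and animals == []      # no "animals" key was written
```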
215
+ global_cached_dwpose = Wholebody()
216
+
217
+ class DwposeDetector:
218
+ """
219
+ A class for detecting human poses in images using the Dwpose model.
220
+
221
+ Attributes:
222
+ dw_pose_estimation (Wholebody): The underlying whole-body (body/hand/face) pose estimator.
223
+ """
224
+ def __init__(self, dw_pose_estimation):
225
+ self.dw_pose_estimation = dw_pose_estimation
226
+
227
+ @classmethod
228
+ def from_pretrained(cls, pretrained_model_or_path, pretrained_det_model_or_path=None, det_filename=None, pose_filename=None, torchscript_device="cuda"):
229
+ global global_cached_dwpose
230
+ pretrained_det_model_or_path = pretrained_det_model_or_path or pretrained_model_or_path
231
+ det_filename = det_filename or "yolox_l.onnx"
232
+ pose_filename = pose_filename or "dw-ll_ucoco_384.onnx"
233
+ det_model_path = custom_hf_download(pretrained_det_model_or_path, det_filename)
234
+ pose_model_path = custom_hf_download(pretrained_model_or_path, pose_filename)
235
+
236
+ print(f"\nDWPose: Using {det_filename} for bbox detection and {pose_filename} for pose estimation")
237
+ if global_cached_dwpose.det is None or global_cached_dwpose.det_filename != det_filename:
238
+ t = Wholebody(det_model_path, None, torchscript_device=torchscript_device)
239
+ t.pose = global_cached_dwpose.pose
240
+ t.pose_filename = global_cached_dwpose.pose_filename
241
+ global_cached_dwpose = t
242
+
243
+ if global_cached_dwpose.pose is None or global_cached_dwpose.pose_filename != pose_filename:
244
+ t = Wholebody(None, pose_model_path, torchscript_device=torchscript_device)
245
+ t.det = global_cached_dwpose.det
246
+ t.det_filename = global_cached_dwpose.det_filename
247
+ global_cached_dwpose = t
248
+ return cls(global_cached_dwpose)
249
+
250
+ def detect_poses(self, oriImg) -> List[PoseResult]:
251
+ with torch.no_grad():
252
+ keypoints_info = self.dw_pose_estimation(oriImg.copy())
253
+ return Wholebody.format_result(keypoints_info)
254
+
255
+ def __call__(self, input_image, detect_resolution=512, include_body=True, include_hand=False, include_face=False, hand_and_face=None, output_type="pil", image_and_json=False, upscale_method="INTER_CUBIC", **kwargs):
256
+ if hand_and_face is not None:
257
+ warnings.warn("hand_and_face is deprecated. Use include_hand and include_face instead.", DeprecationWarning)
258
+ include_hand = hand_and_face
259
+ include_face = hand_and_face
260
+
261
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
262
+ poses = self.detect_poses(input_image)
263
+
264
+ canvas = draw_poses(poses, input_image.shape[0], input_image.shape[1], draw_body=include_body, draw_hand=include_hand, draw_face=include_face)
265
+ canvas, remove_pad = resize_image_with_pad(canvas, detect_resolution, upscale_method)
266
+ detected_map = HWC3(remove_pad(canvas))
267
+
268
+ if output_type == "pil":
269
+ detected_map = Image.fromarray(detected_map)
270
+
271
+ if image_and_json:
272
+ return (detected_map, encode_poses_as_dict(poses, input_image.shape[0], input_image.shape[1]))
273
+
274
+ return detected_map
275
+
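A hedged end-to-end sketch of the detector; the model location below is a placeholder (anything custom_hf_download can resolve yolox_l.onnx and dw-ll_ucoco_384.onnx from, including a local directory, should work):

```python
import numpy as np

# Placeholder model location - substitute a real HF repo id or local folder.
detector = DwposeDetector.from_pretrained("path/or/repo/with/dwpose/models")

frame = np.zeros((512, 512, 3), dtype=np.uint8)           # stand-in for a real HWC uint8 image
pose_map, pose_json = detector(
    frame,
    detect_resolution=512,
    include_body=True, include_hand=True, include_face=True,
    output_type="np",                                      # skip the PIL conversion
    image_and_json=True,                                   # also return the OpenPose-style dict
)
# pose_map is the rendered skeleton image; pose_json matches encode_poses_as_dict's layout.
```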
276
+ global_cached_animalpose = AnimalPoseImage()
277
+ class AnimalposeDetector:
278
+ """
279
+ A class for detecting animal poses in images using the RTMPose AP10k model.
280
+
281
+ Attributes:
282
+ animal_pose_estimation (AnimalPoseImage): The underlying AP10k animal pose estimator.
283
+ """
284
+ def __init__(self, animal_pose_estimation):
285
+ self.animal_pose_estimation = animal_pose_estimation
286
+
287
+ @classmethod
288
+ def from_pretrained(cls, pretrained_model_or_path, pretrained_det_model_or_path=None, det_filename="yolox_l.onnx", pose_filename="dw-ll_ucoco_384.onnx", torchscript_device="cuda"):
289
+ global global_cached_animalpose
290
+ det_model_path = custom_hf_download(pretrained_det_model_or_path, det_filename)
291
+ pose_model_path = custom_hf_download(pretrained_model_or_path, pose_filename)
292
+
293
+ print(f"\nAnimalPose: Using {det_filename} for bbox detection and {pose_filename} for pose estimation")
294
+ if global_cached_animalpose.det is None or global_cached_animalpose.det_filename != det_filename:
295
+ t = AnimalPoseImage(det_model_path, None, torchscript_device=torchscript_device)
296
+ t.pose = global_cached_animalpose.pose
297
+ t.pose_filename = global_cached_animalpose.pose_filename
298
+ global_cached_animalpose = t
299
+
300
+ if global_cached_animalpose.pose is None or global_cached_animalpose.pose_filename != pose_filename:
301
+ t = AnimalPoseImage(None, pose_model_path, torchscript_device=torchscript_device)
302
+ t.det = global_cached_animalpose.det
303
+ t.det_filename = global_cached_animalpose.det_filename
304
+ global_cached_animalpose = t
305
+ return cls(global_cached_animalpose)
306
+
307
+ def __call__(self, input_image, detect_resolution=512, output_type="pil", image_and_json=False, upscale_method="INTER_CUBIC", **kwargs):
308
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
309
+ input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
310
+ result = self.animal_pose_estimation(input_image)
311
+ if result is None:
312
+ detected_map = np.zeros_like(input_image)
313
+ openpose_dict = {
314
+ 'version': 'ap10k',
315
+ 'animals': [],
316
+ 'canvas_height': input_image.shape[0],
317
+ 'canvas_width': input_image.shape[1]
318
+ }
319
+ else:
320
+ detected_map, openpose_dict = result
321
+ detected_map = remove_pad(detected_map)
322
+ if output_type == "pil":
323
+ detected_map = Image.fromarray(detected_map)
324
+
325
+ if image_and_json:
326
+ return (detected_map, openpose_dict)
327
+
328
+ return detected_map
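AnimalposeDetector follows the same pattern; a hedged sketch in which the repository locations and the AP10k pose filename are placeholders resolved through custom_hf_download:

```python
import numpy as np

# Placeholder locations for the detection and AP10k pose models (not verified names).
animal_detector = AnimalposeDetector.from_pretrained(
    "repo/with/ap10k/pose/model",
    "repo/with/yolox/detector",
    det_filename="yolox_l.onnx",
    pose_filename="rtmpose-m_ap10k_256.onnx",   # assumed AP10k RTMPose export
)
frame = np.zeros((512, 512, 3), dtype=np.uint8)
animal_map, ap10k_json = animal_detector(frame, detect_resolution=512,
                                         output_type="np", image_and_json=True)
# ap10k_json["animals"] stays empty here because the blank frame yields no detections.
```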
custom_nodes/ComfyUI-tbox/src/dwpose/animalpose.py ADDED
@@ -0,0 +1,273 @@
1
+ import numpy as np
2
+ import cv2
3
+ import os
4
+ import cv2
5
+ from .dw_onnx.cv_ox_det import inference_detector as inference_onnx_yolox
6
+ from .dw_onnx.cv_ox_yolo_nas import inference_detector as inference_onnx_yolo_nas
7
+ from .dw_onnx.cv_ox_pose import inference_pose as inference_onnx_pose
8
+
9
+ from .dw_torchscript.jit_det import inference_detector as inference_jit_yolox
10
+ from .dw_torchscript.jit_pose import inference_pose as inference_jit_pose
11
+ from typing import List, Optional
12
+ from .types import PoseResult, BodyResult, Keypoint
13
+ from timeit import default_timer
14
+ from .util import guess_onnx_input_shape_dtype, get_ort_providers, get_model_type, is_model_torchscript
15
+ import json
16
+ import torch
17
+
18
+ def drawBetweenKeypoints(pose_img, keypoints, indexes, color, scaleFactor):
19
+ ind0 = indexes[0] - 1
20
+ ind1 = indexes[1] - 1
21
+
22
+ point1 = (keypoints[ind0][0], keypoints[ind0][1])
23
+ point2 = (keypoints[ind1][0], keypoints[ind1][1])
24
+
25
+ thickness = int(5 // scaleFactor)
26
+
27
+
28
+ cv2.line(pose_img, (int(point1[0]), int(point1[1])), (int(point2[0]), int(point2[1])), color, thickness)
29
+
30
+
31
+ def drawBetweenKeypointsList(pose_img, keypoints, keypointPairsList, colorsList, scaleFactor):
32
+ for ind, keypointPair in enumerate(keypointPairsList):
33
+ drawBetweenKeypoints(pose_img, keypoints, keypointPair, colorsList[ind], scaleFactor)
34
+
35
+ def drawBetweenSetofKeypointLists(pose_img, keypoints_set, keypointPairsList, colorsList, scaleFactor):
36
+ for keypoints in keypoints_set:
37
+ drawBetweenKeypointsList(pose_img, keypoints, keypointPairsList, colorsList, scaleFactor)
38
+
39
+
40
+ def padImg(img, size, blackBorder=True):
41
+ left, right, top, bottom = 0, 0, 0, 0
42
+
43
+ # pad x
44
+ if img.shape[1] < size[1]:
45
+ sidePadding = int((size[1] - img.shape[1]) // 2)
46
+ left = sidePadding
47
+ right = sidePadding
48
+
49
+ # pad extra on right if padding needed is an odd number
50
+ if img.shape[1] % 2 == 1:
51
+ right += 1
52
+
53
+ # pad y
54
+ if img.shape[0] < size[0]:
55
+ topBottomPadding = int((size[0] - img.shape[0]) // 2)
56
+ top = topBottomPadding
57
+ bottom = topBottomPadding
58
+
59
+ # pad extra on bottom if padding needed is an odd number
60
+ if img.shape[0] % 2 == 1:
61
+ bottom += 1
62
+
63
+ if blackBorder:
64
+ paddedImg = cv2.copyMakeBorder(src=img, top=top, bottom=bottom, left=left, right=right, borderType=cv2.BORDER_CONSTANT, value=(0,0,0))
65
+ else:
66
+ paddedImg = cv2.copyMakeBorder(src=img, top=top, bottom=bottom, left=left, right=right, borderType=cv2.BORDER_REPLICATE)
67
+
68
+ return paddedImg
69
+
70
+ def smartCrop(img, size, center):
71
+
72
+ width = img.shape[1]
73
+ height = img.shape[0]
74
+ xSize = size[1]
75
+ ySize = size[0]
76
+ xCenter = center[0]
77
+ yCenter = center[1]
78
+
79
+ if img.shape[0] > size[0] or img.shape[1] > size[1]:
80
+
81
+
82
+ leftMargin = xCenter - xSize//2
83
+ rightMargin = xCenter + xSize//2
84
+ upMargin = yCenter - ySize//2
85
+ downMargin = yCenter + ySize//2
86
+
87
+
88
+ if(leftMargin < 0):
89
+ xCenter += (-leftMargin)
90
+ if(rightMargin > width):
91
+ xCenter -= (rightMargin - width)
92
+
93
+ if(upMargin < 0):
94
+ yCenter -= -upMargin
95
+ if(downMargin > height):
96
+ yCenter -= (downMargin - height)
97
+
98
+
99
+ img = cv2.getRectSubPix(img, size, (xCenter, yCenter))
100
+
101
+
102
+
103
+ return img
104
+
105
+
106
+
107
+ def calculateScaleFactor(img, size, poseSpanX, poseSpanY):
108
+
109
+ poseSpanX = max(poseSpanX, size[0])
110
+
111
+ scaleFactorX = 1
112
+
113
+
114
+ if poseSpanX > size[0]:
115
+ scaleFactorX = size[0] / poseSpanX
116
+
117
+ scaleFactorY = 1
118
+ if poseSpanY > size[1]:
119
+ scaleFactorY = size[1] / poseSpanY
120
+
121
+ scaleFactor = min(scaleFactorX, scaleFactorY)
122
+
123
+
124
+ return scaleFactor
125
+
126
+
127
+
128
+ def scaleImg(img, size, poseSpanX, poseSpanY, scaleFactor):
129
+ scaledImg = img
130
+
131
+ scaledImg = cv2.resize(img, (0, 0), fx=scaleFactor, fy=scaleFactor)
132
+
133
+ return scaledImg, scaleFactor
134
+
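A small self-contained check of the padding and scale helpers above (the crop helpers follow the same conventions):

```python
import numpy as np

# padImg grows the image symmetrically with black (or replicated) borders until it
# reaches the requested (height, width); images already large enough are untouched.
img = np.zeros((100, 50, 3), dtype=np.uint8)
padded = padImg(img, (128, 128))
assert padded.shape == (128, 128, 3)

scale = calculateScaleFactor(img, (128, 128), poseSpanX=256, poseSpanY=64)
# scale == 0.5: the horizontal pose span (256) is twice the target width,
# so everything is shrunk to fit the tighter axis.
```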
135
+ class AnimalPoseImage:
136
+ def __init__(self, det_model_path: Optional[str] = None, pose_model_path: Optional[str] = None, torchscript_device="cuda"):
137
+ self.det_filename = det_model_path and os.path.basename(det_model_path)
138
+ self.pose_filename = pose_model_path and os.path.basename(pose_model_path)
139
+ self.det, self.pose = None, None
140
+ # return type: None ort cv2 torchscript
141
+ self.det_model_type = get_model_type("AnimalPose",self.det_filename)
142
+ self.pose_model_type = get_model_type("AnimalPose",self.pose_filename)
143
+ # Always loads to CPU to avoid building OpenCV.
144
+ cv2_device = 'cpu'
145
+ cv2_backend = cv2.dnn.DNN_BACKEND_OPENCV if cv2_device == 'cpu' else cv2.dnn.DNN_BACKEND_CUDA
146
+ # You need to manually build OpenCV through cmake to work with your GPU.
147
+ cv2_providers = cv2.dnn.DNN_TARGET_CPU if cv2_device == 'cpu' else cv2.dnn.DNN_TARGET_CUDA
148
+ ort_providers = get_ort_providers()
149
+
150
+ if self.det_model_type is None:
151
+ pass
152
+ elif self.det_model_type == "ort":
153
+ try:
154
+ import onnxruntime as ort
155
+ self.det = ort.InferenceSession(det_model_path, providers=ort_providers)
156
+ except:
157
+ print(f"Failed to load onnxruntime with {self.det.get_providers()}.\nPlease change EP_list in the config.yaml and restart ComfyUI")
158
+ self.det = ort.InferenceSession(det_model_path, providers=["CPUExecutionProvider"])
159
+ elif self.det_model_type == "cv2":
160
+ try:
161
+ self.det = cv2.dnn.readNetFromONNX(det_model_path)
162
+ self.det.setPreferableBackend(cv2_backend)
163
+ self.det.setPreferableTarget(cv2_providers)
164
+ except:
165
+ print("TopK operators may not work on your OpenCV, try use onnxruntime with CPUExecutionProvider")
166
+ try:
167
+ import onnxruntime as ort
168
+ self.det = ort.InferenceSession(det_model_path, providers=["CPUExecutionProvider"])
169
+ except:
170
+ print(f"Failed to load {det_model_path}, you can use other models instead")
171
+ else:
172
+ self.det = torch.jit.load(det_model_path)
173
+ self.det.to(torchscript_device)
174
+
175
+ if self.pose_model_type is None:
176
+ pass
177
+ elif self.pose_model_type == "ort":
178
+ try:
179
+ import onnxruntime as ort
180
+ self.pose = ort.InferenceSession(pose_model_path, providers=ort_providers)
181
+ except:
182
+ print(f"Failed to load onnxruntime with {self.pose.get_providers()}.\nPlease change EP_list in the config.yaml and restart ComfyUI")
183
+ self.pose = ort.InferenceSession(pose_model_path, providers=["CPUExecutionProvider"])
184
+ elif self.pose_model_type == "cv2":
185
+ self.pose = cv2.dnn.readNetFromONNX(pose_model_path)
186
+ self.pose.setPreferableBackend(cv2_backend)
187
+ self.pose.setPreferableTarget(cv2_providers)
188
+ else:
189
+ self.pose = torch.jit.load(pose_model_path)
190
+ self.pose.to(torchscript_device)
191
+
192
+ if self.pose_filename is not None:
193
+ self.pose_input_size, _ = guess_onnx_input_shape_dtype(self.pose_filename)
194
+
195
+ def __call__(self, oriImg):
196
+ import torch.utils.benchmark.utils.timer as torch_timer
197
+ detect_classes = list(range(14, 23 + 1)) #https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/coco.yaml
198
+
199
+ if is_model_torchscript(self.det):
200
+ det_start = torch_timer.timer()
201
+ det_result = inference_jit_yolox(self.det, oriImg, detect_classes=detect_classes)
202
+ print(f"AnimalPose: Bbox {((torch_timer.timer() - det_start) * 1000):.2f}ms")
203
+ else:
204
+ det_start = default_timer()
205
+ det_onnx_dtype = np.float32 if "yolox" in self.det_filename else np.uint8
206
+ if "yolox" in self.det_filename:
207
+ det_result = inference_onnx_yolox(self.det, oriImg, detect_classes=detect_classes, dtype=det_onnx_dtype)
208
+ else:
209
+ #FP16 and INT8 YOLO NAS accept uint8 input
210
+ det_result = inference_onnx_yolo_nas(self.det, oriImg, detect_classes=detect_classes, dtype=det_onnx_dtype)
211
+ print(f"AnimalPose: Bbox {((default_timer() - det_start) * 1000):.2f}ms")
212
+ if (det_result is None) or (det_result.shape[0] == 0):
213
+ openpose_dict = {
214
+ 'version': 'ap10k',
215
+ 'animals': [],
216
+ 'canvas_height': oriImg.shape[0],
217
+ 'canvas_width': oriImg.shape[1]
218
+ }
219
+ return np.zeros_like(oriImg), openpose_dict
220
+
221
+ if is_model_torchscript(self.pose):
222
+ pose_start = torch_timer.timer()
223
+ keypoint_sets, scores = inference_jit_pose(self.pose, det_result, oriImg, self.pose_input_size)
224
+ print(f"AnimalPose: Pose {((torch_timer.timer() - pose_start) * 1000):.2f}ms on {det_result.shape[0]} animals\n")
225
+ else:
226
+ pose_start = default_timer()
227
+ _, pose_onnx_dtype = guess_onnx_input_shape_dtype(self.pose_filename)
228
+ keypoint_sets, scores = inference_onnx_pose(self.pose, det_result, oriImg, self.pose_input_size, dtype=pose_onnx_dtype)
229
+ print(f"AnimalPose: Pose {((default_timer() - pose_start) * 1000):.2f}ms on {det_result.shape[0]} animals\n")
230
+
231
+ animal_kps_scores = []
232
+ pose_img = np.zeros((oriImg.shape[0], oriImg.shape[1], 3), dtype = np.uint8)
233
+ for (idx, keypoints) in enumerate(keypoint_sets):
234
+ # don't use keypoints that go outside the frame in calculations for the center
235
+ interorKeypoints = keypoints[((keypoints[:,0] > 0) & (keypoints[:,0] < oriImg.shape[1])) & ((keypoints[:,1] > 0) & (keypoints[:,1] < oriImg.shape[0]))]
236
+
237
+ xVals = interorKeypoints[:,0]
238
+ yVals = interorKeypoints[:,1]
239
+
240
+ minX = np.amin(xVals)
241
+ minY = np.amin(yVals)
242
+ maxX = np.amax(xVals)
243
+ maxY = np.amax(yVals)
244
+
245
+ poseSpanX = maxX - minX
246
+ poseSpanY = maxY - minY
247
+
248
+ # find mean center
249
+
250
+ xSum = np.sum(xVals)
251
+ ySum = np.sum(yVals)
252
+
253
+ xCenter = xSum // xVals.shape[0]
254
+ yCenter = ySum // yVals.shape[0]
255
+ center_of_keypoints = (xCenter,yCenter)
256
+
257
+ # order of the keypoints for AP10k and a standardized list of colors for limbs
258
+ keypointPairsList = [(1,2), (2,3), (1,3), (3,4), (4,9), (9,10), (10,11), (4,6), (6,7), (7,8), (4,5), (5,15), (15,16), (16,17), (5,12), (12,13), (13,14)]
259
+ colorsList = [(255,255,255), (100,255,100), (150,255,255), (100,50,255), (50,150,200), (0,255,255), (0,150,0), (0,0,255), (0,0,150), (255,50,255), (255,0,255), (255,0,0), (150,0,0), (255,255,100), (0,150,0), (255,255,0), (150,150,150)] # 16 colors needed
260
+
261
+ drawBetweenKeypointsList(pose_img, keypoints, keypointPairsList, colorsList, scaleFactor=1.0)
262
+ score = scores[idx, ..., None]
263
+ score[score > 1.0] = 1.0
264
+ score[score < 0.0] = 0.0
265
+ animal_kps_scores.append(np.concatenate((keypoints, score), axis=-1))
266
+
267
+ openpose_dict = {
268
+ 'version': 'ap10k',
269
+ 'animals': [keypoints.tolist() for keypoints in animal_kps_scores],
270
+ 'canvas_height': oriImg.shape[0],
271
+ 'canvas_width': oriImg.shape[1]
272
+ }
273
+ return pose_img, openpose_dict
custom_nodes/ComfyUI-tbox/src/dwpose/body.py ADDED
@@ -0,0 +1,261 @@
1
+ import cv2
2
+ import numpy as np
3
+ import math
4
+ import time
5
+ from scipy.ndimage import gaussian_filter
6
+ import matplotlib.pyplot as plt
7
+ import matplotlib
8
+ import torch
9
+ from torchvision import transforms
10
+ from typing import NamedTuple, List, Union
11
+
12
+ from . import util
13
+ from .model import bodypose_model
14
+ from .types import Keypoint, BodyResult
15
+
16
+ class Body(object):
17
+ def __init__(self, model_path):
18
+ self.model = bodypose_model()
19
+ # if torch.cuda.is_available():
20
+ # self.model = self.model.cuda()
21
+ # print('cuda')
22
+ model_dict = util.transfer(self.model, torch.load(model_path))
23
+ self.model.load_state_dict(model_dict)
24
+ self.model.eval()
25
+
26
+ def __call__(self, oriImg):
27
+ # scale_search = [0.5, 1.0, 1.5, 2.0]
28
+ scale_search = [0.5]
29
+ boxsize = 368
30
+ stride = 8
31
+ padValue = 128
32
+ thre1 = 0.1
33
+ thre2 = 0.05
34
+ multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
35
+ heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
36
+ paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
37
+
38
+ for m in range(len(multiplier)):
39
+ scale = multiplier[m]
40
+ imageToTest = util.smart_resize_k(oriImg, fx=scale, fy=scale)
41
+ imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
42
+ im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
43
+ im = np.ascontiguousarray(im)
44
+
45
+ data = torch.from_numpy(im).float()
46
+ if torch.cuda.is_available():
47
+ data = data.cuda()
48
+ # data = data.permute([2, 0, 1]).unsqueeze(0).float()
49
+ with torch.no_grad():
50
+ data = data.to(self.cn_device)
51
+ Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
52
+ Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
53
+ Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()
54
+
55
+ # extract outputs, resize, and remove padding
56
+ # heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0)) # output 1 is heatmaps
57
+ heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0)) # output 1 is heatmaps
58
+ heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride)
59
+ heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
60
+ heatmap = util.smart_resize(heatmap, (oriImg.shape[0], oriImg.shape[1]))
61
+
62
+ # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs
63
+ paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0)) # output 0 is PAFs
64
+ paf = util.smart_resize_k(paf, fx=stride, fy=stride)
65
+ paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
66
+ paf = util.smart_resize(paf, (oriImg.shape[0], oriImg.shape[1]))
67
+
68
+ heatmap_avg += heatmap_avg + heatmap / len(multiplier)
69
+ paf_avg += + paf / len(multiplier)
70
+
71
+ all_peaks = []
72
+ peak_counter = 0
73
+
74
+ for part in range(18):
75
+ map_ori = heatmap_avg[:, :, part]
76
+ one_heatmap = gaussian_filter(map_ori, sigma=3)
77
+
78
+ map_left = np.zeros(one_heatmap.shape)
79
+ map_left[1:, :] = one_heatmap[:-1, :]
80
+ map_right = np.zeros(one_heatmap.shape)
81
+ map_right[:-1, :] = one_heatmap[1:, :]
82
+ map_up = np.zeros(one_heatmap.shape)
83
+ map_up[:, 1:] = one_heatmap[:, :-1]
84
+ map_down = np.zeros(one_heatmap.shape)
85
+ map_down[:, :-1] = one_heatmap[:, 1:]
86
+
87
+ peaks_binary = np.logical_and.reduce(
88
+ (one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1))
89
+ peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])) # note reverse
90
+ peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
91
+ peak_id = range(peak_counter, peak_counter + len(peaks))
92
+ peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))]
93
+
94
+ all_peaks.append(peaks_with_score_and_id)
95
+ peak_counter += len(peaks)
96
+
97
+ # find connection in the specified sequence, center 29 is in the position 15
98
+ limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
99
+ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
100
+ [1, 16], [16, 18], [3, 17], [6, 18]]
101
+ # the middle joints heatmap correpondence
102
+ mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \
103
+ [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \
104
+ [55, 56], [37, 38], [45, 46]]
105
+
106
+ connection_all = []
107
+ special_k = []
108
+ mid_num = 10
109
+
110
+ for k in range(len(mapIdx)):
111
+ score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
112
+ candA = all_peaks[limbSeq[k][0] - 1]
113
+ candB = all_peaks[limbSeq[k][1] - 1]
114
+ nA = len(candA)
115
+ nB = len(candB)
116
+ indexA, indexB = limbSeq[k]
117
+ if (nA != 0 and nB != 0):
118
+ connection_candidate = []
119
+ for i in range(nA):
120
+ for j in range(nB):
121
+ vec = np.subtract(candB[j][:2], candA[i][:2])
122
+ norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
123
+ norm = max(0.001, norm)
124
+ vec = np.divide(vec, norm)
125
+
126
+ startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \
127
+ np.linspace(candA[i][1], candB[j][1], num=mid_num)))
128
+
129
+ vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] \
130
+ for I in range(len(startend))])
131
+ vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] \
132
+ for I in range(len(startend))])
133
+
134
+ score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
135
+ score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min(
136
+ 0.5 * oriImg.shape[0] / norm - 1, 0)
137
+ criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts)
138
+ criterion2 = score_with_dist_prior > 0
139
+ if criterion1 and criterion2:
140
+ connection_candidate.append(
141
+ [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]])
142
+
143
+ connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
144
+ connection = np.zeros((0, 5))
145
+ for c in range(len(connection_candidate)):
146
+ i, j, s = connection_candidate[c][0:3]
147
+ if (i not in connection[:, 3] and j not in connection[:, 4]):
148
+ connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
149
+ if (len(connection) >= min(nA, nB)):
150
+ break
151
+
152
+ connection_all.append(connection)
153
+ else:
154
+ special_k.append(k)
155
+ connection_all.append([])
156
+
157
+ # last number in each row is the total parts number of that person
158
+ # the second last number in each row is the score of the overall configuration
159
+ subset = -1 * np.ones((0, 20))
160
+ candidate = np.array([item for sublist in all_peaks for item in sublist])
161
+
162
+ for k in range(len(mapIdx)):
163
+ if k not in special_k:
164
+ partAs = connection_all[k][:, 0]
165
+ partBs = connection_all[k][:, 1]
166
+ indexA, indexB = np.array(limbSeq[k]) - 1
167
+
168
+ for i in range(len(connection_all[k])): # = 1:size(temp,1)
169
+ found = 0
170
+ subset_idx = [-1, -1]
171
+ for j in range(len(subset)): # 1:size(subset,1):
172
+ if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
173
+ subset_idx[found] = j
174
+ found += 1
175
+
176
+ if found == 1:
177
+ j = subset_idx[0]
178
+ if subset[j][indexB] != partBs[i]:
179
+ subset[j][indexB] = partBs[i]
180
+ subset[j][-1] += 1
181
+ subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
182
+ elif found == 2: # if found 2 and disjoint, merge them
183
+ j1, j2 = subset_idx
184
+ membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
185
+ if len(np.nonzero(membership == 2)[0]) == 0: # merge
186
+ subset[j1][:-2] += (subset[j2][:-2] + 1)
187
+ subset[j1][-2:] += subset[j2][-2:]
188
+ subset[j1][-2] += connection_all[k][i][2]
189
+ subset = np.delete(subset, j2, 0)
190
+ else: # as like found == 1
191
+ subset[j1][indexB] = partBs[i]
192
+ subset[j1][-1] += 1
193
+ subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
194
+
195
+ # if find no partA in the subset, create a new subset
196
+ elif not found and k < 17:
197
+ row = -1 * np.ones(20)
198
+ row[indexA] = partAs[i]
199
+ row[indexB] = partBs[i]
200
+ row[-1] = 2
201
+ row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2]
202
+ subset = np.vstack([subset, row])
203
+ # delete some rows of subset which has few parts occur
204
+ deleteIdx = []
205
+ for i in range(len(subset)):
206
+ if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
207
+ deleteIdx.append(i)
208
+ subset = np.delete(subset, deleteIdx, axis=0)
209
+
210
+ # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
211
+ # candidate: x, y, score, id
212
+ return candidate, subset
213
+
214
+ @staticmethod
215
+ def format_body_result(candidate: np.ndarray, subset: np.ndarray) -> List[BodyResult]:
216
+ """
217
+ Format the body results from the candidate and subset arrays into a list of BodyResult objects.
218
+
219
+ Args:
220
+ candidate (np.ndarray): An array of candidates containing the x, y coordinates, score, and id
221
+ for each body part.
222
+ subset (np.ndarray): An array of subsets containing indices to the candidate array for each
223
+ person detected. The last two columns of each row hold the total score and total parts
224
+ of the person.
225
+
226
+ Returns:
227
+ List[BodyResult]: A list of BodyResult objects, where each object represents a person with
228
+ detected keypoints, total score, and total parts.
229
+ """
230
+ return [
231
+ BodyResult(
232
+ keypoints=[
233
+ Keypoint(
234
+ x=candidate[candidate_index][0],
235
+ y=candidate[candidate_index][1],
236
+ score=candidate[candidate_index][2],
237
+ id=candidate[candidate_index][3]
238
+ ) if candidate_index != -1 else None
239
+ for candidate_index in person[:18].astype(int)
240
+ ],
241
+ total_score=person[18],
242
+ total_parts=person[19]
243
+ )
244
+ for person in subset
245
+ ]
246
+
247
+
248
+ if __name__ == "__main__":
249
+ body_estimation = Body('../model/body_pose_model.pth')
250
+
251
+ test_image = '../images/ski.jpg'
252
+ oriImg = cv2.imread(test_image) # B,G,R order
253
+ candidate, subset = body_estimation(oriImg)
254
+ bodies = body_estimation.format_body_result(candidate, subset)
255
+
256
+ canvas = oriImg
257
+ for body in bodies:
258
+ canvas = util.draw_bodypose(canvas, body)
259
+
260
+ plt.imshow(canvas[:, :, [2, 1, 0]])
261
+ plt.show()
custom_nodes/ComfyUI-tbox/src/dwpose/dw_onnx/__init__.py ADDED
@@ -0,0 +1 @@
1
+ #Dummy file ensuring this package will be recognized
custom_nodes/ComfyUI-tbox/src/dwpose/dw_onnx/cv_ox_det.py ADDED
@@ -0,0 +1,129 @@
1
+ import cv2
2
+ import numpy as np
3
+
4
+ def nms(boxes, scores, nms_thr):
5
+ """Single class NMS implemented in Numpy."""
6
+ x1 = boxes[:, 0]
7
+ y1 = boxes[:, 1]
8
+ x2 = boxes[:, 2]
9
+ y2 = boxes[:, 3]
10
+
11
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
12
+ order = scores.argsort()[::-1]
13
+
14
+ keep = []
15
+ while order.size > 0:
16
+ i = order[0]
17
+ keep.append(i)
18
+ xx1 = np.maximum(x1[i], x1[order[1:]])
19
+ yy1 = np.maximum(y1[i], y1[order[1:]])
20
+ xx2 = np.minimum(x2[i], x2[order[1:]])
21
+ yy2 = np.minimum(y2[i], y2[order[1:]])
22
+
23
+ w = np.maximum(0.0, xx2 - xx1 + 1)
24
+ h = np.maximum(0.0, yy2 - yy1 + 1)
25
+ inter = w * h
26
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
27
+
28
+ inds = np.where(ovr <= nms_thr)[0]
29
+ order = order[inds + 1]
30
+
31
+ return keep
32
+
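A quick numeric check of the NMS above, assuming standard xyxy boxes:

```python
import numpy as np

boxes = np.array([[0, 0, 10, 10],
                  [1, 1, 11, 11],      # overlaps box 0 heavily (IoU ~ 0.70)
                  [50, 50, 60, 60]],   # disjoint from both
                 dtype=np.float32)
scores = np.array([0.9, 0.8, 0.7], dtype=np.float32)
assert nms(boxes, scores, nms_thr=0.45) == [0, 2]   # the duplicate of box 0 is suppressed
```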
33
+ def multiclass_nms(boxes, scores, nms_thr, score_thr):
34
+ """Multiclass NMS implemented in Numpy. Class-aware version."""
35
+ final_dets = []
36
+ num_classes = scores.shape[1]
37
+ for cls_ind in range(num_classes):
38
+ cls_scores = scores[:, cls_ind]
39
+ valid_score_mask = cls_scores > score_thr
40
+ if valid_score_mask.sum() == 0:
41
+ continue
42
+ else:
43
+ valid_scores = cls_scores[valid_score_mask]
44
+ valid_boxes = boxes[valid_score_mask]
45
+ keep = nms(valid_boxes, valid_scores, nms_thr)
46
+ if len(keep) > 0:
47
+ cls_inds = np.ones((len(keep), 1)) * cls_ind
48
+ dets = np.concatenate(
49
+ [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
50
+ )
51
+ final_dets.append(dets)
52
+ if len(final_dets) == 0:
53
+ return None
54
+ return np.concatenate(final_dets, 0)
55
+
56
+ def demo_postprocess(outputs, img_size, p6=False):
57
+ grids = []
58
+ expanded_strides = []
59
+ strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
60
+
61
+ hsizes = [img_size[0] // stride for stride in strides]
62
+ wsizes = [img_size[1] // stride for stride in strides]
63
+
64
+ for hsize, wsize, stride in zip(hsizes, wsizes, strides):
65
+ xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
66
+ grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
67
+ grids.append(grid)
68
+ shape = grid.shape[:2]
69
+ expanded_strides.append(np.full((*shape, 1), stride))
70
+
71
+ grids = np.concatenate(grids, 1)
72
+ expanded_strides = np.concatenate(expanded_strides, 1)
73
+ outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
74
+ outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
75
+
76
+ return outputs
77
+
78
+ def preprocess(img, input_size, swap=(2, 0, 1)):
79
+ if len(img.shape) == 3:
80
+ padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
81
+ else:
82
+ padded_img = np.ones(input_size, dtype=np.uint8) * 114
83
+
84
+ r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
85
+ resized_img = cv2.resize(
86
+ img,
87
+ (int(img.shape[1] * r), int(img.shape[0] * r)),
88
+ interpolation=cv2.INTER_LINEAR,
89
+ ).astype(np.uint8)
90
+ padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
91
+
92
+ padded_img = padded_img.transpose(swap)
93
+ padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
94
+ return padded_img, r
95
+
96
+ def inference_detector(session, oriImg, detect_classes=[0], dtype=np.float32):
97
+ input_shape = (640,640)
98
+ img, ratio = preprocess(oriImg, input_shape)
99
+
100
+ input = img[None, :, :, :]
101
+ input = input.astype(dtype)
102
+ if "InferenceSession" in type(session).__name__:
103
+ input_name = session.get_inputs()[0].name
104
+ output = session.run(None, {input_name: input})
105
+ else:
106
+ outNames = session.getUnconnectedOutLayersNames()
107
+ session.setInput(input)
108
+ output = session.forward(outNames)
109
+
110
+ predictions = demo_postprocess(output[0], input_shape)[0]
111
+
112
+ boxes = predictions[:, :4]
113
+ scores = predictions[:, 4:5] * predictions[:, 5:]
114
+
115
+ boxes_xyxy = np.ones_like(boxes)
116
+ boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
117
+ boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
118
+ boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
119
+ boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
120
+ boxes_xyxy /= ratio
121
+ dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
122
+ if dets is None:
123
+ return None
124
+ final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
125
+ isscore = final_scores>0.3
126
+ iscat = np.isin(final_cls_inds, detect_classes)
127
+ isbbox = [ i and j for (i, j) in zip(isscore, iscat)]
128
+ final_boxes = final_boxes[isbbox]
129
+ return final_boxes
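A hedged driver for the whole detector path above; the ONNX file name is a placeholder for a YOLOX export:

```python
import numpy as np
import onnxruntime as ort

# Placeholder model path; any YOLOX ONNX export with the standard output head should work.
session = ort.InferenceSession("yolox_l.onnx", providers=["CPUExecutionProvider"])
frame = np.zeros((480, 640, 3), dtype=np.uint8)           # stand-in frame

person_boxes = inference_detector(session, frame, detect_classes=[0], dtype=np.float32)
# Either None (nothing survives NMS) or an (N, 4) array of xyxy boxes that have
# already been rescaled back to the original frame coordinates.
```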
custom_nodes/ComfyUI-tbox/src/dwpose/dw_onnx/cv_ox_pose.py ADDED
@@ -0,0 +1,363 @@
1
+ from typing import List, Tuple
2
+
3
+ import cv2
4
+ import numpy as np
5
+
6
+ def preprocess(
7
+ img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
8
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
9
+ """Do preprocessing for DWPose model inference.
10
+
11
+ Args:
12
+ img (np.ndarray): Input image in (H, W, C) layout.
13
+ input_size (tuple): Input image size in shape (w, h).
14
+
15
+ Returns:
16
+ tuple:
17
+ - resized_img (np.ndarray): Preprocessed image.
18
+ - center (np.ndarray): Center of image.
19
+ - scale (np.ndarray): Scale of image.
20
+ """
21
+ # get shape of image
22
+ img_shape = img.shape[:2]
23
+ out_img, out_center, out_scale = [], [], []
24
+ if len(out_bbox) == 0:
25
+ out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
26
+ for i in range(len(out_bbox)):
27
+ x0 = out_bbox[i][0]
28
+ y0 = out_bbox[i][1]
29
+ x1 = out_bbox[i][2]
30
+ y1 = out_bbox[i][3]
31
+ bbox = np.array([x0, y0, x1, y1])
32
+
33
+ # get center and scale
34
+ center, scale = bbox_xyxy2cs(bbox, padding=1.25)
35
+
36
+ # do affine transformation
37
+ resized_img, scale = top_down_affine(input_size, scale, center, img)
38
+
39
+ # normalize image
40
+ mean = np.array([123.675, 116.28, 103.53])
41
+ std = np.array([58.395, 57.12, 57.375])
42
+ resized_img = (resized_img - mean) / std
43
+
44
+ out_img.append(resized_img)
45
+ out_center.append(center)
46
+ out_scale.append(scale)
47
+
48
+ return out_img, out_center, out_scale
49
+
50
+
51
+ def inference(sess, img, dtype=np.float32):
52
+ """Inference DWPose model. Processing all image segments at once to take advantage of GPU's parallelism ability if onnxruntime is installed
53
+
54
+ Args:
55
+ sess : ONNXRuntime session.
56
+ img : List of preprocessed image crops (one per detected bbox).
57
+
58
+ Returns:
59
+ outputs : Output of DWPose model.
60
+ """
61
+ all_out = []
62
+ # build input
63
+ input = np.stack(img, axis=0).transpose(0, 3, 1, 2)
64
+ input = input.astype(dtype)
65
+ if "InferenceSession" in type(sess).__name__:
66
+ input_name = sess.get_inputs()[0].name
67
+ all_outputs = sess.run(None, {input_name: input})
68
+ for batch_idx in range(len(all_outputs[0])):
69
+ outputs = [all_outputs[i][batch_idx:batch_idx+1,...] for i in range(len(all_outputs))]
70
+ all_out.append(outputs)
71
+ return all_out
72
+
73
+ #OpenCV doesn't support batch processing sadly
74
+ for i in range(len(img)):
75
+ input = img[i].transpose(2, 0, 1)
76
+ input = input[None, :, :, :]
77
+
78
+ outNames = sess.getUnconnectedOutLayersNames()
79
+ sess.setInput(input)
80
+ outputs = sess.forward(outNames)
81
+ all_out.append(outputs)
82
+
83
+ return all_out
84
+
85
+ def postprocess(outputs: List[np.ndarray],
86
+ model_input_size: Tuple[int, int],
87
+ center: Tuple[int, int],
88
+ scale: Tuple[int, int],
89
+ simcc_split_ratio: float = 2.0
90
+ ) -> Tuple[np.ndarray, np.ndarray]:
91
+ """Postprocess for DWPose model output.
92
+
93
+ Args:
94
+ outputs (np.ndarray): Output of RTMPose model.
95
+ model_input_size (tuple): RTMPose model Input image size.
96
+ center (tuple): Center of bbox in shape (x, y).
97
+ scale (tuple): Scale of bbox in shape (w, h).
98
+ simcc_split_ratio (float): Split ratio of simcc.
99
+
100
+ Returns:
101
+ tuple:
102
+ - keypoints (np.ndarray): Rescaled keypoints.
103
+ - scores (np.ndarray): Model predict scores.
104
+ """
105
+ all_key = []
106
+ all_score = []
107
+ for i in range(len(outputs)):
108
+ # use simcc to decode
109
+ simcc_x, simcc_y = outputs[i]
110
+ keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
111
+
112
+ # rescale keypoints
113
+ keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
114
+ all_key.append(keypoints[0])
115
+ all_score.append(scores[0])
116
+
117
+ return np.array(all_key), np.array(all_score)
118
+
119
+
120
+ def bbox_xyxy2cs(bbox: np.ndarray,
121
+ padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
122
+ """Transform the bbox format from (x,y,w,h) into (center, scale)
123
+
124
+ Args:
125
+ bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
126
+ as (left, top, right, bottom)
127
+ padding (float): BBox padding factor that will be multiplied to scale.
128
+ Default: 1.0
129
+
130
+ Returns:
131
+ tuple: A tuple containing center and scale.
132
+ - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
133
+ (n, 2)
134
+ - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
135
+ (n, 2)
136
+ """
137
+ # convert single bbox from (4, ) to (1, 4)
138
+ dim = bbox.ndim
139
+ if dim == 1:
140
+ bbox = bbox[None, :]
141
+
142
+ # get bbox center and scale
143
+ x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
144
+ center = np.hstack([x1 + x2, y1 + y2]) * 0.5
145
+ scale = np.hstack([x2 - x1, y2 - y1]) * padding
146
+
147
+ if dim == 1:
148
+ center = center[0]
149
+ scale = scale[0]
150
+
151
+ return center, scale
152
+
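A worked example of the conversion above:

```python
import numpy as np

center, scale = bbox_xyxy2cs(np.array([10., 20., 110., 220.]), padding=1.25)
# center == [ 60., 120.]  (midpoint of the box)
# scale  == [125., 250.]  (box width/height stretched by the 1.25 padding factor)
```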
153
+
154
+ def _fix_aspect_ratio(bbox_scale: np.ndarray,
155
+ aspect_ratio: float) -> np.ndarray:
156
+ """Extend the scale to match the given aspect ratio.
157
+
158
+ Args:
159
+ scale (np.ndarray): The image scale (w, h) in shape (2, )
160
+ aspect_ratio (float): The ratio of ``w/h``
161
+
162
+ Returns:
163
+ np.ndarray: The reshaped image scale in (2, )
164
+ """
165
+ w, h = np.hsplit(bbox_scale, [1])
166
+ bbox_scale = np.where(w > h * aspect_ratio,
167
+ np.hstack([w, w / aspect_ratio]),
168
+ np.hstack([h * aspect_ratio, h]))
169
+ return bbox_scale
170
+
171
+
172
+ def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
173
+ """Rotate a point by an angle.
174
+
175
+ Args:
176
+ pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
177
+ angle_rad (float): rotation angle in radian
178
+
179
+ Returns:
180
+ np.ndarray: Rotated point in shape (2, )
181
+ """
182
+ sn, cs = np.sin(angle_rad), np.cos(angle_rad)
183
+ rot_mat = np.array([[cs, -sn], [sn, cs]])
184
+ return rot_mat @ pt
185
+
186
+
187
+ def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
188
+ """To calculate the affine matrix, three pairs of points are required. This
189
+ function is used to get the 3rd point, given 2D points a & b.
190
+
191
+ The 3rd point is defined by rotating vector `a - b` by 90 degrees
192
+ anticlockwise, using b as the rotation center.
193
+
194
+ Args:
195
+ a (np.ndarray): The 1st point (x,y) in shape (2, )
196
+ b (np.ndarray): The 2nd point (x,y) in shape (2, )
197
+
198
+ Returns:
199
+ np.ndarray: The 3rd point.
200
+ """
201
+ direction = a - b
202
+ c = b + np.r_[-direction[1], direction[0]]
203
+ return c
204
+
205
+
206
+ def get_warp_matrix(center: np.ndarray,
207
+ scale: np.ndarray,
208
+ rot: float,
209
+ output_size: Tuple[int, int],
210
+ shift: Tuple[float, float] = (0., 0.),
211
+ inv: bool = False) -> np.ndarray:
212
+ """Calculate the affine transformation matrix that can warp the bbox area
213
+ in the input image to the output size.
214
+
215
+ Args:
216
+ center (np.ndarray[2, ]): Center of the bounding box (x, y).
217
+ scale (np.ndarray[2, ]): Scale of the bounding box
218
+ wrt [width, height].
219
+ rot (float): Rotation angle (degree).
220
+ output_size (np.ndarray[2, ] | list(2,)): Size of the
221
+ destination heatmaps.
222
+ shift (0-100%): Shift translation ratio wrt the width/height.
223
+ Default (0., 0.).
224
+ inv (bool): Option to inverse the affine transform direction.
225
+ (inv=False: src->dst or inv=True: dst->src)
226
+
227
+ Returns:
228
+ np.ndarray: A 2x3 transformation matrix
229
+ """
230
+ shift = np.array(shift)
231
+ src_w = scale[0]
232
+ dst_w = output_size[0]
233
+ dst_h = output_size[1]
234
+
235
+ # compute transformation matrix
236
+ rot_rad = np.deg2rad(rot)
237
+ src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
238
+ dst_dir = np.array([0., dst_w * -0.5])
239
+
240
+ # get four corners of the src rectangle in the original image
241
+ src = np.zeros((3, 2), dtype=np.float32)
242
+ src[0, :] = center + scale * shift
243
+ src[1, :] = center + src_dir + scale * shift
244
+ src[2, :] = _get_3rd_point(src[0, :], src[1, :])
245
+
246
+ # get four corners of the dst rectangle in the input image
247
+ dst = np.zeros((3, 2), dtype=np.float32)
248
+ dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
249
+ dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
250
+ dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
251
+
252
+ if inv:
253
+ warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
254
+ else:
255
+ warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
256
+
257
+ return warp_mat
258
+
259
+
260
+ def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict,
261
+ img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
262
+ """Get the bbox image as the model input by affine transform.
263
+
264
+ Args:
265
+ input_size (dict): The input size of the model.
266
+ bbox_scale (dict): The bbox scale of the img.
267
+ bbox_center (dict): The bbox center of the img.
268
+ img (np.ndarray): The original image.
269
+
270
+ Returns:
271
+ tuple: A tuple containing center and scale.
272
+ - np.ndarray[float32]: img after affine transform.
273
+ - np.ndarray[float32]: bbox scale after affine transform.
274
+ """
275
+ w, h = input_size
276
+ warp_size = (int(w), int(h))
277
+
278
+ # reshape bbox to fixed aspect ratio
279
+ bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
280
+
281
+ # get the affine matrix
282
+ center = bbox_center
283
+ scale = bbox_scale
284
+ rot = 0
285
+ warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
286
+
287
+ # do affine transform
288
+ img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
289
+
290
+ return img, bbox_scale
291
+
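Putting the helpers together, top_down_affine crops the padded bbox to the model's input resolution; a sketch with made-up numbers:

```python
import numpy as np

img = np.zeros((480, 640, 3), dtype=np.uint8)         # stand-in full frame
center = np.array([320., 240.])                       # bbox center from bbox_xyxy2cs
scale = np.array([200., 300.])                        # padded bbox (w, h)

crop, fixed_scale = top_down_affine((192, 256), scale, center, img)
assert crop.shape == (256, 192, 3)                    # (h, w, 3) matching the model input
# fixed_scale == [225., 300.]: the width was widened to keep the 192:256 aspect ratio.
```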
292
+
293
+ def get_simcc_maximum(simcc_x: np.ndarray,
294
+ simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
295
+ """Get maximum response location and value from simcc representations.
296
+
297
+ Note:
298
+ instance number: N
299
+ num_keypoints: K
300
+ heatmap height: H
301
+ heatmap width: W
302
+
303
+ Args:
304
+ simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
305
+ simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
306
+
307
+ Returns:
308
+ tuple:
309
+ - locs (np.ndarray): locations of maximum heatmap responses in shape
310
+ (K, 2) or (N, K, 2)
311
+ - vals (np.ndarray): values of maximum heatmap responses in shape
312
+ (K,) or (N, K)
313
+ """
314
+ N, K, Wx = simcc_x.shape
315
+ simcc_x = simcc_x.reshape(N * K, -1)
316
+ simcc_y = simcc_y.reshape(N * K, -1)
317
+
318
+ # get maximum value locations
319
+ x_locs = np.argmax(simcc_x, axis=1)
320
+ y_locs = np.argmax(simcc_y, axis=1)
321
+ locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
322
+ max_val_x = np.amax(simcc_x, axis=1)
323
+ max_val_y = np.amax(simcc_y, axis=1)
324
+
325
+ # get maximum value across x and y axis
326
+ mask = max_val_x > max_val_y
327
+ max_val_x[mask] = max_val_y[mask]
328
+ vals = max_val_x
329
+ locs[vals <= 0.] = -1
330
+
331
+ # reshape
332
+ locs = locs.reshape(N, K, 2)
333
+ vals = vals.reshape(N, K)
334
+
335
+ return locs, vals
336
+
337
+
338
+ def decode(simcc_x: np.ndarray, simcc_y: np.ndarray,
339
+ simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:
340
+ """Modulate simcc distribution with Gaussian.
341
+
342
+ Args:
343
+ simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
344
+ simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
345
+ simcc_split_ratio (int): The split ratio of simcc.
346
+
347
+ Returns:
348
+ tuple: A tuple containing keypoints and scores.
349
+ - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
350
+ - np.ndarray[float32]: scores in shape (K,) or (n, K)
351
+ """
352
+ keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
353
+ keypoints /= simcc_split_ratio
354
+
355
+ return keypoints, scores
356
+
357
+
358
+ def inference_pose(session, out_bbox, oriImg, model_input_size=(288, 384), dtype=np.float32):
359
+ resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
360
+ outputs = inference(session, resized_img, dtype)
361
+ keypoints, scores = postprocess(outputs, model_input_size, center, scale)
362
+
363
+ return keypoints, scores
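For orientation, here is a minimal usage sketch of the ONNX pose path above. It assumes `onnxruntime` sessions, locally downloaded DWPose/YOLOX ONNX weights (filenames are placeholders), and the `inference_detector` helper from the sibling `cv_ox_det.py` module; none of these names are guaranteed by this file.

```py
# Sketch only: wiring an ONNX person detector to inference_pose() above.
import cv2
import numpy as np
import onnxruntime as ort

from cv_ox_det import inference_detector   # assumed sibling module in dw_onnx
from cv_ox_pose import inference_pose

det_session = ort.InferenceSession("yolox_l.onnx", providers=["CPUExecutionProvider"])
pose_session = ort.InferenceSession("dw-ll_ucoco_384.onnx", providers=["CPUExecutionProvider"])

img = cv2.imread("person.jpg")                         # BGR, HxWx3
boxes = inference_detector(det_session, img)           # (N, 4) xyxy boxes or None
boxes = [] if boxes is None else boxes
keypoints, scores = inference_pose(pose_session, boxes, img)
print(keypoints.shape, scores.shape)                   # (N, K, 2), (N, K)
```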
custom_nodes/ComfyUI-tbox/src/dwpose/dw_onnx/cv_ox_yolo_nas.py ADDED
@@ -0,0 +1,60 @@
1
+ # Source: https://github.com/Hyuto/yolo-nas-onnx/tree/master/yolo-nas-py
2
+ # Inspired from: https://github.com/Deci-AI/super-gradients/blob/3.1.1/src/super_gradients/training/processing/processing.py
3
+
4
+ import numpy as np
5
+ import cv2
6
+
7
+ def preprocess(img, input_size, swap=(2, 0, 1)):
8
+ if len(img.shape) == 3:
9
+ padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
10
+ else:
11
+ padded_img = np.ones(input_size, dtype=np.uint8) * 114
12
+
13
+ r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
14
+ resized_img = cv2.resize(
15
+ img,
16
+ (int(img.shape[1] * r), int(img.shape[0] * r)),
17
+ interpolation=cv2.INTER_LINEAR,
18
+ ).astype(np.uint8)
19
+ padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
20
+
21
+ padded_img = padded_img.transpose(swap)
22
+ padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
23
+ return padded_img, r
24
+
25
+ def inference_detector(session, oriImg, detect_classes=[0], dtype=np.uint8):
26
+ """
27
+ This function is only compatible with onnx models exported from the new API with built-in NMS
28
+ ```py
29
+ from super_gradients.conversion.conversion_enums import ExportQuantizationMode
30
+ from super_gradients.common.object_names import Models
31
+ from super_gradients.training import models
32
+
33
+ model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco")
34
+
35
+ export_result = model.export(
36
+ "yolo_nas/yolo_nas_l_fp16.onnx",
37
+ quantization_mode=ExportQuantizationMode.FP16,
38
+ device="cuda"
39
+ )
40
+ ```
41
+ """
42
+ input_shape = (640,640)
43
+ img, ratio = preprocess(oriImg, input_shape)
44
+ input = img[None, :, :, :]
45
+ input = input.astype(dtype)
46
+ if "InferenceSession" in type(session).__name__:
47
+ input_name = session.get_inputs()[0].name
48
+ output = session.run(None, {input_name: input})
49
+ else:
50
+ outNames = session.getUnconnectedOutLayersNames()
51
+ session.setInput(input)
52
+ output = session.forward(outNames)
53
+ num_preds, pred_boxes, pred_scores, pred_classes = output
54
+ num_preds = num_preds[0,0]
55
+ if num_preds == 0:
56
+ return None
57
+ idxs = np.where((np.isin(pred_classes[0, :num_preds], detect_classes)) & (pred_scores[0, :num_preds] > 0.3))
58
+ if (len(idxs) == 0) or (idxs[0].size == 0):
59
+ return None
60
+ return pred_boxes[0, idxs].squeeze(axis=0) / ratio
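A hedged sketch of calling `inference_detector` above with an `onnxruntime` session. The model filename mirrors the export snippet in the docstring, and the input `dtype` must match the exported model's input type (FP16 here is an assumption for the FP16 export):

```py
# Sketch: run the NMS-embedded YOLO-NAS ONNX detector on one image.
import cv2
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("yolo_nas/yolo_nas_l_fp16.onnx",
                               providers=["CPUExecutionProvider"])
img = cv2.imread("person.jpg")                              # BGR image
boxes = inference_detector(session, img, detect_classes=[0], dtype=np.float16)
print(boxes)   # (N, 4) person boxes in original-image xyxy coordinates, or None
```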
custom_nodes/ComfyUI-tbox/src/dwpose/dw_torchscript/__init__.py ADDED
@@ -0,0 +1 @@
1
+ #Dummy file ensuring this package will be recognized
custom_nodes/ComfyUI-tbox/src/dwpose/dw_torchscript/jit_det.py ADDED
@@ -0,0 +1,125 @@
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+
5
+ def nms(boxes, scores, nms_thr):
6
+ """Single class NMS implemented in Numpy."""
7
+ x1 = boxes[:, 0]
8
+ y1 = boxes[:, 1]
9
+ x2 = boxes[:, 2]
10
+ y2 = boxes[:, 3]
11
+
12
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
13
+ order = scores.argsort()[::-1]
14
+
15
+ keep = []
16
+ while order.size > 0:
17
+ i = order[0]
18
+ keep.append(i)
19
+ xx1 = np.maximum(x1[i], x1[order[1:]])
20
+ yy1 = np.maximum(y1[i], y1[order[1:]])
21
+ xx2 = np.minimum(x2[i], x2[order[1:]])
22
+ yy2 = np.minimum(y2[i], y2[order[1:]])
23
+
24
+ w = np.maximum(0.0, xx2 - xx1 + 1)
25
+ h = np.maximum(0.0, yy2 - yy1 + 1)
26
+ inter = w * h
27
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
28
+
29
+ inds = np.where(ovr <= nms_thr)[0]
30
+ order = order[inds + 1]
31
+
32
+ return keep
33
+
34
+ def multiclass_nms(boxes, scores, nms_thr, score_thr):
35
+ """Multiclass NMS implemented in Numpy. Class-aware version."""
36
+ final_dets = []
37
+ num_classes = scores.shape[1]
38
+ for cls_ind in range(num_classes):
39
+ cls_scores = scores[:, cls_ind]
40
+ valid_score_mask = cls_scores > score_thr
41
+ if valid_score_mask.sum() == 0:
42
+ continue
43
+ else:
44
+ valid_scores = cls_scores[valid_score_mask]
45
+ valid_boxes = boxes[valid_score_mask]
46
+ keep = nms(valid_boxes, valid_scores, nms_thr)
47
+ if len(keep) > 0:
48
+ cls_inds = np.ones((len(keep), 1)) * cls_ind
49
+ dets = np.concatenate(
50
+ [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
51
+ )
52
+ final_dets.append(dets)
53
+ if len(final_dets) == 0:
54
+ return None
55
+ return np.concatenate(final_dets, 0)
56
+
57
+ def demo_postprocess(outputs, img_size, p6=False):
58
+ grids = []
59
+ expanded_strides = []
60
+ strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
61
+
62
+ hsizes = [img_size[0] // stride for stride in strides]
63
+ wsizes = [img_size[1] // stride for stride in strides]
64
+
65
+ for hsize, wsize, stride in zip(hsizes, wsizes, strides):
66
+ xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
67
+ grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
68
+ grids.append(grid)
69
+ shape = grid.shape[:2]
70
+ expanded_strides.append(np.full((*shape, 1), stride))
71
+
72
+ grids = np.concatenate(grids, 1)
73
+ expanded_strides = np.concatenate(expanded_strides, 1)
74
+ outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
75
+ outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
76
+
77
+ return outputs
78
+
79
+ def preprocess(img, input_size, swap=(2, 0, 1)):
80
+ if len(img.shape) == 3:
81
+ padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
82
+ else:
83
+ padded_img = np.ones(input_size, dtype=np.uint8) * 114
84
+
85
+ r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
86
+ resized_img = cv2.resize(
87
+ img,
88
+ (int(img.shape[1] * r), int(img.shape[0] * r)),
89
+ interpolation=cv2.INTER_LINEAR,
90
+ ).astype(np.uint8)
91
+ padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
92
+
93
+ padded_img = padded_img.transpose(swap)
94
+ padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
95
+ return padded_img, r
96
+
97
+ def inference_detector(model, oriImg, detect_classes=[0]):
98
+ input_shape = (640,640)
99
+ img, ratio = preprocess(oriImg, input_shape)
100
+
101
+ device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
102
+ input = img[None, :, :, :]
103
+ input = torch.from_numpy(input).to(device, dtype)
104
+
105
+ output = model(input).float().cpu().detach().numpy()
106
+ predictions = demo_postprocess(output[0], input_shape)
107
+
108
+ boxes = predictions[:, :4]
109
+ scores = predictions[:, 4:5] * predictions[:, 5:]
110
+
111
+ boxes_xyxy = np.ones_like(boxes)
112
+ boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
113
+ boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
114
+ boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
115
+ boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
116
+ boxes_xyxy /= ratio
117
+ dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
118
+ if dets is None:
119
+ return None
120
+ final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
121
+ isscore = final_scores>0.3
122
+ iscat = np.isin(final_cls_inds, detect_classes)
123
+ isbbox = [ i and j for (i, j) in zip(isscore, iscat)]
124
+ final_boxes = final_boxes[isbbox]
125
+ return final_boxes
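For completeness, a minimal sketch of driving the TorchScript detector above. `torch.jit.load` is the standard loader; the checkpoint filename is a placeholder:

```py
# Sketch: load a TorchScript YOLOX-style detector and get person boxes.
import cv2
import torch

model = torch.jit.load("yolox_l.torchscript.pt", map_location="cpu").eval()
img = cv2.imread("person.jpg")                 # BGR, HxWx3
with torch.no_grad():
    boxes = inference_detector(model, img, detect_classes=[0])
print(boxes)   # (N, 4) xyxy boxes, or None if nothing clears the 0.3 score threshold
```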
custom_nodes/ComfyUI-tbox/src/dwpose/dw_torchscript/jit_pose.py ADDED
@@ -0,0 +1,363 @@
1
+ from typing import List, Tuple
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+
7
+ def preprocess(
8
+ img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
9
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
10
+ """Do preprocessing for DWPose model inference.
11
+
12
+ Args:
13
+ img (np.ndarray): Input image in shape (H, W, C).
14
+ input_size (tuple): Input image size in shape (w, h).
15
+
16
+ Returns:
17
+ tuple:
18
+ - resized_img (np.ndarray): Preprocessed image.
19
+ - center (np.ndarray): Center of image.
20
+ - scale (np.ndarray): Scale of image.
21
+ """
22
+ # get shape of image
23
+ img_shape = img.shape[:2]
24
+ out_img, out_center, out_scale = [], [], []
25
+ if len(out_bbox) == 0:
26
+ out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
27
+ for i in range(len(out_bbox)):
28
+ x0 = out_bbox[i][0]
29
+ y0 = out_bbox[i][1]
30
+ x1 = out_bbox[i][2]
31
+ y1 = out_bbox[i][3]
32
+ bbox = np.array([x0, y0, x1, y1])
33
+
34
+ # get center and scale
35
+ center, scale = bbox_xyxy2cs(bbox, padding=1.25)
36
+
37
+ # do affine transformation
38
+ resized_img, scale = top_down_affine(input_size, scale, center, img)
39
+
40
+ # normalize image
41
+ mean = np.array([123.675, 116.28, 103.53])
42
+ std = np.array([58.395, 57.12, 57.375])
43
+ resized_img = (resized_img - mean) / std
44
+
45
+ out_img.append(resized_img)
46
+ out_center.append(center)
47
+ out_scale.append(scale)
48
+
49
+ return out_img, out_center, out_scale
50
+
51
+ def inference(model, img, bs=5):
52
+ """Inference DWPose model implemented in TorchScript.
53
+
54
+ Args:
55
+ model : TorchScript Model.
56
+ img : List of preprocessed image crops produced by ``preprocess``.
57
+
58
+ Returns:
59
+ outputs : Output of DWPose model.
60
+ """
61
+ all_out = []
62
+ # build input
63
+ orig_img_count = len(img)
64
+ #Pad zeros to fit batch size
65
+ for _ in range(bs - (orig_img_count % bs)):
66
+ img.append(np.zeros_like(img[0]))
67
+ input = np.stack(img, axis=0).transpose(0, 3, 1, 2)
68
+ device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
69
+ input = torch.from_numpy(input).to(device, dtype)
70
+
71
+ out1, out2 = [], []
72
+ for i in range(input.shape[0] // bs):
73
+ curr_batch_output = model(input[i*bs:(i+1)*bs])
74
+ out1.append(curr_batch_output[0].float())
75
+ out2.append(curr_batch_output[1].float())
76
+ out1, out2 = torch.cat(out1, dim=0)[:orig_img_count], torch.cat(out2, dim=0)[:orig_img_count]
77
+ out1, out2 = out1.float().cpu().detach().numpy(), out2.float().cpu().detach().numpy()
78
+ all_outputs = out1, out2
79
+
80
+ for batch_idx in range(len(all_outputs[0])):
81
+ outputs = [all_outputs[i][batch_idx:batch_idx+1,...] for i in range(len(all_outputs))]
82
+ all_out.append(outputs)
83
+ return all_out
84
+ def postprocess(outputs: List[np.ndarray],
85
+ model_input_size: Tuple[int, int],
86
+ center: Tuple[int, int],
87
+ scale: Tuple[int, int],
88
+ simcc_split_ratio: float = 2.0
89
+ ) -> Tuple[np.ndarray, np.ndarray]:
90
+ """Postprocess for DWPose model output.
91
+
92
+ Args:
93
+ outputs (np.ndarray): Output of RTMPose model.
94
+ model_input_size (tuple): RTMPose model Input image size.
95
+ center (tuple): Center of bbox in shape (x, y).
96
+ scale (tuple): Scale of bbox in shape (w, h).
97
+ simcc_split_ratio (float): Split ratio of simcc.
98
+
99
+ Returns:
100
+ tuple:
101
+ - keypoints (np.ndarray): Rescaled keypoints.
102
+ - scores (np.ndarray): Model predict scores.
103
+ """
104
+ all_key = []
105
+ all_score = []
106
+ for i in range(len(outputs)):
107
+ # use simcc to decode
108
+ simcc_x, simcc_y = outputs[i]
109
+ keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
110
+
111
+ # rescale keypoints
112
+ keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
113
+ all_key.append(keypoints[0])
114
+ all_score.append(scores[0])
115
+
116
+ return np.array(all_key), np.array(all_score)
117
+
118
+
119
+ def bbox_xyxy2cs(bbox: np.ndarray,
120
+ padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
121
+ """Transform the bbox format from (x,y,w,h) into (center, scale)
122
+
123
+ Args:
124
+ bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
125
+ as (left, top, right, bottom)
126
+ padding (float): BBox padding factor that will be multilied to scale.
127
+ Default: 1.0
128
+
129
+ Returns:
130
+ tuple: A tuple containing center and scale.
131
+ - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
132
+ (n, 2)
133
+ - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
134
+ (n, 2)
135
+ """
136
+ # convert single bbox from (4, ) to (1, 4)
137
+ dim = bbox.ndim
138
+ if dim == 1:
139
+ bbox = bbox[None, :]
140
+
141
+ # get bbox center and scale
142
+ x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
143
+ center = np.hstack([x1 + x2, y1 + y2]) * 0.5
144
+ scale = np.hstack([x2 - x1, y2 - y1]) * padding
145
+
146
+ if dim == 1:
147
+ center = center[0]
148
+ scale = scale[0]
149
+
150
+ return center, scale
151
+
152
+
153
+ def _fix_aspect_ratio(bbox_scale: np.ndarray,
154
+ aspect_ratio: float) -> np.ndarray:
155
+ """Extend the scale to match the given aspect ratio.
156
+
157
+ Args:
158
+ scale (np.ndarray): The image scale (w, h) in shape (2, )
159
+ aspect_ratio (float): The ratio of ``w/h``
160
+
161
+ Returns:
162
+ np.ndarray: The reshaped image scale in (2, )
163
+ """
164
+ w, h = np.hsplit(bbox_scale, [1])
165
+ bbox_scale = np.where(w > h * aspect_ratio,
166
+ np.hstack([w, w / aspect_ratio]),
167
+ np.hstack([h * aspect_ratio, h]))
168
+ return bbox_scale
169
+
170
+
171
+ def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
172
+ """Rotate a point by an angle.
173
+
174
+ Args:
175
+ pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
176
+ angle_rad (float): rotation angle in radian
177
+
178
+ Returns:
179
+ np.ndarray: Rotated point in shape (2, )
180
+ """
181
+ sn, cs = np.sin(angle_rad), np.cos(angle_rad)
182
+ rot_mat = np.array([[cs, -sn], [sn, cs]])
183
+ return rot_mat @ pt
184
+
185
+
186
+ def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
187
+ """To calculate the affine matrix, three pairs of points are required. This
188
+ function is used to get the 3rd point, given 2D points a & b.
189
+
190
+ The 3rd point is defined by rotating vector `a - b` by 90 degrees
191
+ anticlockwise, using b as the rotation center.
192
+
193
+ Args:
194
+ a (np.ndarray): The 1st point (x,y) in shape (2, )
195
+ b (np.ndarray): The 2nd point (x,y) in shape (2, )
196
+
197
+ Returns:
198
+ np.ndarray: The 3rd point.
199
+ """
200
+ direction = a - b
201
+ c = b + np.r_[-direction[1], direction[0]]
202
+ return c
203
+
204
+
205
+ def get_warp_matrix(center: np.ndarray,
206
+ scale: np.ndarray,
207
+ rot: float,
208
+ output_size: Tuple[int, int],
209
+ shift: Tuple[float, float] = (0., 0.),
210
+ inv: bool = False) -> np.ndarray:
211
+ """Calculate the affine transformation matrix that can warp the bbox area
212
+ in the input image to the output size.
213
+
214
+ Args:
215
+ center (np.ndarray[2, ]): Center of the bounding box (x, y).
216
+ scale (np.ndarray[2, ]): Scale of the bounding box
217
+ wrt [width, height].
218
+ rot (float): Rotation angle (degree).
219
+ output_size (np.ndarray[2, ] | list(2,)): Size of the
220
+ destination heatmaps.
221
+ shift (0-100%): Shift translation ratio wrt the width/height.
222
+ Default (0., 0.).
223
+ inv (bool): Option to inverse the affine transform direction.
224
+ (inv=False: src->dst or inv=True: dst->src)
225
+
226
+ Returns:
227
+ np.ndarray: A 2x3 transformation matrix
228
+ """
229
+ shift = np.array(shift)
230
+ src_w = scale[0]
231
+ dst_w = output_size[0]
232
+ dst_h = output_size[1]
233
+
234
+ # compute transformation matrix
235
+ rot_rad = np.deg2rad(rot)
236
+ src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
237
+ dst_dir = np.array([0., dst_w * -0.5])
238
+
239
+ # get four corners of the src rectangle in the original image
240
+ src = np.zeros((3, 2), dtype=np.float32)
241
+ src[0, :] = center + scale * shift
242
+ src[1, :] = center + src_dir + scale * shift
243
+ src[2, :] = _get_3rd_point(src[0, :], src[1, :])
244
+
245
+ # get four corners of the dst rectangle in the input image
246
+ dst = np.zeros((3, 2), dtype=np.float32)
247
+ dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
248
+ dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
249
+ dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
250
+
251
+ if inv:
252
+ warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
253
+ else:
254
+ warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
255
+
256
+ return warp_mat
257
+
258
+
259
+ def top_down_affine(input_size: Tuple[int, int], bbox_scale: np.ndarray, bbox_center: np.ndarray,
260
+ img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
261
+ """Get the bbox image as the model input by affine transform.
262
+
263
+ Args:
264
+ input_size (tuple): The model input size (w, h).
265
+ bbox_scale (np.ndarray): The bbox scale (w, h) of the img.
266
+ bbox_center (np.ndarray): The bbox center (x, y) of the img.
267
+ img (np.ndarray): The original image.
268
+
269
+ Returns:
270
+ tuple: A tuple containing center and scale.
271
+ - np.ndarray[float32]: img after affine transform.
272
+ - np.ndarray[float32]: bbox scale after affine transform.
273
+ """
274
+ w, h = input_size
275
+ warp_size = (int(w), int(h))
276
+
277
+ # reshape bbox to fixed aspect ratio
278
+ bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
279
+
280
+ # get the affine matrix
281
+ center = bbox_center
282
+ scale = bbox_scale
283
+ rot = 0
284
+ warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
285
+
286
+ # do affine transform
287
+ img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
288
+
289
+ return img, bbox_scale
290
+
291
+
292
+ def get_simcc_maximum(simcc_x: np.ndarray,
293
+ simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
294
+ """Get maximum response location and value from simcc representations.
295
+
296
+ Note:
297
+ instance number: N
298
+ num_keypoints: K
299
+ heatmap height: H
300
+ heatmap width: W
301
+
302
+ Args:
303
+ simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
304
+ simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
305
+
306
+ Returns:
307
+ tuple:
308
+ - locs (np.ndarray): locations of maximum heatmap responses in shape
309
+ (K, 2) or (N, K, 2)
310
+ - vals (np.ndarray): values of maximum heatmap responses in shape
311
+ (K,) or (N, K)
312
+ """
313
+ N, K, Wx = simcc_x.shape
314
+ simcc_x = simcc_x.reshape(N * K, -1)
315
+ simcc_y = simcc_y.reshape(N * K, -1)
316
+
317
+ # get maximum value locations
318
+ x_locs = np.argmax(simcc_x, axis=1)
319
+ y_locs = np.argmax(simcc_y, axis=1)
320
+ locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
321
+ max_val_x = np.amax(simcc_x, axis=1)
322
+ max_val_y = np.amax(simcc_y, axis=1)
323
+
324
+ # get maximum value across x and y axis
325
+ mask = max_val_x > max_val_y
326
+ max_val_x[mask] = max_val_y[mask]
327
+ vals = max_val_x
328
+ locs[vals <= 0.] = -1
329
+
330
+ # reshape
331
+ locs = locs.reshape(N, K, 2)
332
+ vals = vals.reshape(N, K)
333
+
334
+ return locs, vals
335
+
336
+
337
+ def decode(simcc_x: np.ndarray, simcc_y: np.ndarray,
338
+ simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:
339
+ """Modulate simcc distribution with Gaussian.
340
+
341
+ Args:
342
+ simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
343
+ simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
344
+ simcc_split_ratio (int): The split ratio of simcc.
345
+
346
+ Returns:
347
+ tuple: A tuple containing center and scale.
348
+ - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
349
+ - np.ndarray[float32]: scores in shape (K,) or (n, K)
350
+ """
351
+ keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
352
+ keypoints /= simcc_split_ratio
353
+
354
+ return keypoints, scores
355
+
356
+ def inference_pose(model, out_bbox, oriImg, model_input_size=(288, 384)):
357
+ resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
358
+ #outputs = inference(session, resized_img, dtype)
359
+ outputs = inference(model, resized_img)
360
+
361
+ keypoints, scores = postprocess(outputs, model_input_size, center, scale)
362
+
363
+ return keypoints, scores
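And a hedged end-to-end sketch for the TorchScript path, pairing `jit_det.inference_detector` with `inference_pose` above. Checkpoint filenames are placeholders, and the batch size baked into the pose checkpoint is assumed to match the default `bs=5`:

```py
# Sketch: TorchScript detector + DWPose pose estimator on a single frame.
import cv2
import torch

from jit_det import inference_detector        # detector from the sibling module
from jit_pose import inference_pose

det = torch.jit.load("yolox_l.torchscript.pt", map_location="cpu").eval()
pose = torch.jit.load("dw-ll_ucoco_384_bs5.torchscript.pt", map_location="cpu").eval()

img = cv2.imread("person.jpg")                # BGR, HxWx3
boxes = inference_detector(det, img)
boxes = [] if boxes is None else boxes
keypoints, scores = inference_pose(pose, boxes, img)   # (N, K, 2), (N, K)
```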
custom_nodes/ComfyUI-tbox/src/dwpose/face.py ADDED
@@ -0,0 +1,362 @@
1
+ import logging
2
+ import numpy as np
3
+ from torchvision.transforms import ToTensor, ToPILImage
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import cv2
7
+
8
+ from . import util
9
+ from torch.nn import Conv2d, Module, ReLU, MaxPool2d, init
10
+
11
+
12
+ class FaceNet(Module):
13
+ """Model the cascading heatmaps. """
14
+ def __init__(self):
15
+ super(FaceNet, self).__init__()
16
+ # cnn to make feature map
17
+ self.relu = ReLU()
18
+ self.max_pooling_2d = MaxPool2d(kernel_size=2, stride=2)
19
+ self.conv1_1 = Conv2d(in_channels=3, out_channels=64,
20
+ kernel_size=3, stride=1, padding=1)
21
+ self.conv1_2 = Conv2d(
22
+ in_channels=64, out_channels=64, kernel_size=3, stride=1,
23
+ padding=1)
24
+ self.conv2_1 = Conv2d(
25
+ in_channels=64, out_channels=128, kernel_size=3, stride=1,
26
+ padding=1)
27
+ self.conv2_2 = Conv2d(
28
+ in_channels=128, out_channels=128, kernel_size=3, stride=1,
29
+ padding=1)
30
+ self.conv3_1 = Conv2d(
31
+ in_channels=128, out_channels=256, kernel_size=3, stride=1,
32
+ padding=1)
33
+ self.conv3_2 = Conv2d(
34
+ in_channels=256, out_channels=256, kernel_size=3, stride=1,
35
+ padding=1)
36
+ self.conv3_3 = Conv2d(
37
+ in_channels=256, out_channels=256, kernel_size=3, stride=1,
38
+ padding=1)
39
+ self.conv3_4 = Conv2d(
40
+ in_channels=256, out_channels=256, kernel_size=3, stride=1,
41
+ padding=1)
42
+ self.conv4_1 = Conv2d(
43
+ in_channels=256, out_channels=512, kernel_size=3, stride=1,
44
+ padding=1)
45
+ self.conv4_2 = Conv2d(
46
+ in_channels=512, out_channels=512, kernel_size=3, stride=1,
47
+ padding=1)
48
+ self.conv4_3 = Conv2d(
49
+ in_channels=512, out_channels=512, kernel_size=3, stride=1,
50
+ padding=1)
51
+ self.conv4_4 = Conv2d(
52
+ in_channels=512, out_channels=512, kernel_size=3, stride=1,
53
+ padding=1)
54
+ self.conv5_1 = Conv2d(
55
+ in_channels=512, out_channels=512, kernel_size=3, stride=1,
56
+ padding=1)
57
+ self.conv5_2 = Conv2d(
58
+ in_channels=512, out_channels=512, kernel_size=3, stride=1,
59
+ padding=1)
60
+ self.conv5_3_CPM = Conv2d(
61
+ in_channels=512, out_channels=128, kernel_size=3, stride=1,
62
+ padding=1)
63
+
64
+ # stage1
65
+ self.conv6_1_CPM = Conv2d(
66
+ in_channels=128, out_channels=512, kernel_size=1, stride=1,
67
+ padding=0)
68
+ self.conv6_2_CPM = Conv2d(
69
+ in_channels=512, out_channels=71, kernel_size=1, stride=1,
70
+ padding=0)
71
+
72
+ # stage2
73
+ self.Mconv1_stage2 = Conv2d(
74
+ in_channels=199, out_channels=128, kernel_size=7, stride=1,
75
+ padding=3)
76
+ self.Mconv2_stage2 = Conv2d(
77
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
78
+ padding=3)
79
+ self.Mconv3_stage2 = Conv2d(
80
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
81
+ padding=3)
82
+ self.Mconv4_stage2 = Conv2d(
83
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
84
+ padding=3)
85
+ self.Mconv5_stage2 = Conv2d(
86
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
87
+ padding=3)
88
+ self.Mconv6_stage2 = Conv2d(
89
+ in_channels=128, out_channels=128, kernel_size=1, stride=1,
90
+ padding=0)
91
+ self.Mconv7_stage2 = Conv2d(
92
+ in_channels=128, out_channels=71, kernel_size=1, stride=1,
93
+ padding=0)
94
+
95
+ # stage3
96
+ self.Mconv1_stage3 = Conv2d(
97
+ in_channels=199, out_channels=128, kernel_size=7, stride=1,
98
+ padding=3)
99
+ self.Mconv2_stage3 = Conv2d(
100
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
101
+ padding=3)
102
+ self.Mconv3_stage3 = Conv2d(
103
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
104
+ padding=3)
105
+ self.Mconv4_stage3 = Conv2d(
106
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
107
+ padding=3)
108
+ self.Mconv5_stage3 = Conv2d(
109
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
110
+ padding=3)
111
+ self.Mconv6_stage3 = Conv2d(
112
+ in_channels=128, out_channels=128, kernel_size=1, stride=1,
113
+ padding=0)
114
+ self.Mconv7_stage3 = Conv2d(
115
+ in_channels=128, out_channels=71, kernel_size=1, stride=1,
116
+ padding=0)
117
+
118
+ # stage4
119
+ self.Mconv1_stage4 = Conv2d(
120
+ in_channels=199, out_channels=128, kernel_size=7, stride=1,
121
+ padding=3)
122
+ self.Mconv2_stage4 = Conv2d(
123
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
124
+ padding=3)
125
+ self.Mconv3_stage4 = Conv2d(
126
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
127
+ padding=3)
128
+ self.Mconv4_stage4 = Conv2d(
129
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
130
+ padding=3)
131
+ self.Mconv5_stage4 = Conv2d(
132
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
133
+ padding=3)
134
+ self.Mconv6_stage4 = Conv2d(
135
+ in_channels=128, out_channels=128, kernel_size=1, stride=1,
136
+ padding=0)
137
+ self.Mconv7_stage4 = Conv2d(
138
+ in_channels=128, out_channels=71, kernel_size=1, stride=1,
139
+ padding=0)
140
+
141
+ # stage5
142
+ self.Mconv1_stage5 = Conv2d(
143
+ in_channels=199, out_channels=128, kernel_size=7, stride=1,
144
+ padding=3)
145
+ self.Mconv2_stage5 = Conv2d(
146
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
147
+ padding=3)
148
+ self.Mconv3_stage5 = Conv2d(
149
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
150
+ padding=3)
151
+ self.Mconv4_stage5 = Conv2d(
152
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
153
+ padding=3)
154
+ self.Mconv5_stage5 = Conv2d(
155
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
156
+ padding=3)
157
+ self.Mconv6_stage5 = Conv2d(
158
+ in_channels=128, out_channels=128, kernel_size=1, stride=1,
159
+ padding=0)
160
+ self.Mconv7_stage5 = Conv2d(
161
+ in_channels=128, out_channels=71, kernel_size=1, stride=1,
162
+ padding=0)
163
+
164
+ # stage6
165
+ self.Mconv1_stage6 = Conv2d(
166
+ in_channels=199, out_channels=128, kernel_size=7, stride=1,
167
+ padding=3)
168
+ self.Mconv2_stage6 = Conv2d(
169
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
170
+ padding=3)
171
+ self.Mconv3_stage6 = Conv2d(
172
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
173
+ padding=3)
174
+ self.Mconv4_stage6 = Conv2d(
175
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
176
+ padding=3)
177
+ self.Mconv5_stage6 = Conv2d(
178
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
179
+ padding=3)
180
+ self.Mconv6_stage6 = Conv2d(
181
+ in_channels=128, out_channels=128, kernel_size=1, stride=1,
182
+ padding=0)
183
+ self.Mconv7_stage6 = Conv2d(
184
+ in_channels=128, out_channels=71, kernel_size=1, stride=1,
185
+ padding=0)
186
+
187
+ for m in self.modules():
188
+ if isinstance(m, Conv2d):
189
+ init.constant_(m.bias, 0)
190
+
191
+ def forward(self, x):
192
+ """Return a list of heatmaps."""
193
+ heatmaps = []
194
+
195
+ h = self.relu(self.conv1_1(x))
196
+ h = self.relu(self.conv1_2(h))
197
+ h = self.max_pooling_2d(h)
198
+ h = self.relu(self.conv2_1(h))
199
+ h = self.relu(self.conv2_2(h))
200
+ h = self.max_pooling_2d(h)
201
+ h = self.relu(self.conv3_1(h))
202
+ h = self.relu(self.conv3_2(h))
203
+ h = self.relu(self.conv3_3(h))
204
+ h = self.relu(self.conv3_4(h))
205
+ h = self.max_pooling_2d(h)
206
+ h = self.relu(self.conv4_1(h))
207
+ h = self.relu(self.conv4_2(h))
208
+ h = self.relu(self.conv4_3(h))
209
+ h = self.relu(self.conv4_4(h))
210
+ h = self.relu(self.conv5_1(h))
211
+ h = self.relu(self.conv5_2(h))
212
+ h = self.relu(self.conv5_3_CPM(h))
213
+ feature_map = h
214
+
215
+ # stage1
216
+ h = self.relu(self.conv6_1_CPM(h))
217
+ h = self.conv6_2_CPM(h)
218
+ heatmaps.append(h)
219
+
220
+ # stage2
221
+ h = torch.cat([h, feature_map], dim=1) # channel concat
222
+ h = self.relu(self.Mconv1_stage2(h))
223
+ h = self.relu(self.Mconv2_stage2(h))
224
+ h = self.relu(self.Mconv3_stage2(h))
225
+ h = self.relu(self.Mconv4_stage2(h))
226
+ h = self.relu(self.Mconv5_stage2(h))
227
+ h = self.relu(self.Mconv6_stage2(h))
228
+ h = self.Mconv7_stage2(h)
229
+ heatmaps.append(h)
230
+
231
+ # stage3
232
+ h = torch.cat([h, feature_map], dim=1) # channel concat
233
+ h = self.relu(self.Mconv1_stage3(h))
234
+ h = self.relu(self.Mconv2_stage3(h))
235
+ h = self.relu(self.Mconv3_stage3(h))
236
+ h = self.relu(self.Mconv4_stage3(h))
237
+ h = self.relu(self.Mconv5_stage3(h))
238
+ h = self.relu(self.Mconv6_stage3(h))
239
+ h = self.Mconv7_stage3(h)
240
+ heatmaps.append(h)
241
+
242
+ # stage4
243
+ h = torch.cat([h, feature_map], dim=1) # channel concat
244
+ h = self.relu(self.Mconv1_stage4(h))
245
+ h = self.relu(self.Mconv2_stage4(h))
246
+ h = self.relu(self.Mconv3_stage4(h))
247
+ h = self.relu(self.Mconv4_stage4(h))
248
+ h = self.relu(self.Mconv5_stage4(h))
249
+ h = self.relu(self.Mconv6_stage4(h))
250
+ h = self.Mconv7_stage4(h)
251
+ heatmaps.append(h)
252
+
253
+ # stage5
254
+ h = torch.cat([h, feature_map], dim=1) # channel concat
255
+ h = self.relu(self.Mconv1_stage5(h))
256
+ h = self.relu(self.Mconv2_stage5(h))
257
+ h = self.relu(self.Mconv3_stage5(h))
258
+ h = self.relu(self.Mconv4_stage5(h))
259
+ h = self.relu(self.Mconv5_stage5(h))
260
+ h = self.relu(self.Mconv6_stage5(h))
261
+ h = self.Mconv7_stage5(h)
262
+ heatmaps.append(h)
263
+
264
+ # stage6
265
+ h = torch.cat([h, feature_map], dim=1) # channel concat
266
+ h = self.relu(self.Mconv1_stage6(h))
267
+ h = self.relu(self.Mconv2_stage6(h))
268
+ h = self.relu(self.Mconv3_stage6(h))
269
+ h = self.relu(self.Mconv4_stage6(h))
270
+ h = self.relu(self.Mconv5_stage6(h))
271
+ h = self.relu(self.Mconv6_stage6(h))
272
+ h = self.Mconv7_stage6(h)
273
+ heatmaps.append(h)
274
+
275
+ return heatmaps
276
+
277
+
278
+ LOG = logging.getLogger(__name__)
279
+ TOTEN = ToTensor()
280
+ TOPIL = ToPILImage()
281
+
282
+
283
+ params = {
284
+ 'gaussian_sigma': 2.5,
285
+ 'inference_img_size': 736, # 368, 736, 1312
286
+ 'heatmap_peak_thresh': 0.1,
287
+ 'crop_scale': 1.5,
288
+ 'line_indices': [
289
+ [0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6],
290
+ [6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12], [12, 13],
291
+ [13, 14], [14, 15], [15, 16],
292
+ [17, 18], [18, 19], [19, 20], [20, 21],
293
+ [22, 23], [23, 24], [24, 25], [25, 26],
294
+ [27, 28], [28, 29], [29, 30],
295
+ [31, 32], [32, 33], [33, 34], [34, 35],
296
+ [36, 37], [37, 38], [38, 39], [39, 40], [40, 41], [41, 36],
297
+ [42, 43], [43, 44], [44, 45], [45, 46], [46, 47], [47, 42],
298
+ [48, 49], [49, 50], [50, 51], [51, 52], [52, 53], [53, 54],
299
+ [54, 55], [55, 56], [56, 57], [57, 58], [58, 59], [59, 48],
300
+ [60, 61], [61, 62], [62, 63], [63, 64], [64, 65], [65, 66],
301
+ [66, 67], [67, 60]
302
+ ],
303
+ }
304
+
305
+
306
+ class Face(object):
307
+ """
308
+ The OpenPose face landmark detector model.
309
+
310
+ Args:
311
+ inference_size: set the size of the inference image size, suggested:
312
+ 368, 736, 1312, default 736
313
+ gaussian_sigma: blur the heatmaps, default 2.5
314
+ heatmap_peak_thresh: return landmark if over threshold, default 0.1
315
+
316
+ """
317
+ def __init__(self, face_model_path,
318
+ inference_size=None,
319
+ gaussian_sigma=None,
320
+ heatmap_peak_thresh=None):
321
+ self.inference_size = inference_size or params["inference_img_size"]
322
+ self.sigma = gaussian_sigma or params['gaussian_sigma']
323
+ self.threshold = heatmap_peak_thresh or params["heatmap_peak_thresh"]
324
+ self.model = FaceNet()
325
+ self.model.load_state_dict(torch.load(face_model_path))
326
+ # if torch.cuda.is_available():
327
+ # self.model = self.model.cuda()
328
+ # print('cuda')
329
+ self.model.eval()
330
+
331
+ def __call__(self, face_img):
332
+ H, W, C = face_img.shape
333
+
334
+ w_size = 384
335
+ x_data = torch.from_numpy(util.smart_resize(face_img, (w_size, w_size))).permute([2, 0, 1]) / 256.0 - 0.5
336
+
337
+ x_data = x_data.to(self.cn_device)  # cn_device is assigned by the caller before inference; it is not set in __init__
338
+
339
+ with torch.no_grad():
340
+ hs = self.model(x_data[None, ...])
341
+ heatmaps = F.interpolate(
342
+ hs[-1],
343
+ (H, W),
344
+ mode='bilinear', align_corners=True).cpu().numpy()[0]
345
+ return heatmaps
346
+
347
+ def compute_peaks_from_heatmaps(self, heatmaps):
348
+ all_peaks = []
349
+ for part in range(heatmaps.shape[0]):
350
+ map_ori = heatmaps[part].copy()
351
+ binary = np.ascontiguousarray(map_ori > 0.05, dtype=np.uint8)
352
+
353
+ if np.sum(binary) == 0:
354
+ continue
355
+
356
+ positions = np.where(binary > 0.5)
357
+ intensities = map_ori[positions]
358
+ mi = np.argmax(intensities)
359
+ y, x = positions[0][mi], positions[1][mi]
360
+ all_peaks.append([x, y])
361
+
362
+ return np.array(all_peaks)
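Because `Face.__call__` reads `self.cn_device`, which is never set in `__init__`, callers are expected to assign it before running inference. A minimal sketch, with a placeholder weight path:

```py
# Sketch: run the OpenPose face landmark model on a face crop.
import cv2
import torch

face_estimator = Face("facenet.pth")             # placeholder path to the face weights
face_estimator.cn_device = torch.device("cpu")   # assigned externally; __init__ does not set it
face_estimator.model.to(face_estimator.cn_device)

face_crop = cv2.imread("face.jpg")               # HxWx3 crop around a single face
heatmaps = face_estimator(face_crop)             # (71, H, W) heatmaps from the last stage
peaks = face_estimator.compute_peaks_from_heatmaps(heatmaps)   # (M, 2) landmark (x, y) peaks
```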
custom_nodes/ComfyUI-tbox/src/dwpose/hand.py ADDED
@@ -0,0 +1,94 @@
1
+ import cv2
2
+ import json
3
+ import numpy as np
4
+ import math
5
+ import time
6
+ from scipy.ndimage import gaussian_filter  # scipy.ndimage.filters is deprecated and removed in recent SciPy
7
+ import matplotlib.pyplot as plt
8
+ import matplotlib
9
+ import torch
10
+ from skimage.measure import label
11
+
12
+ from .model import handpose_model
13
+ from . import util
14
+
15
+ class Hand(object):
16
+ def __init__(self, model_path):
17
+ self.model = handpose_model()
18
+ # if torch.cuda.is_available():
19
+ # self.model = self.model.cuda()
20
+ # print('cuda')
21
+ model_dict = util.transfer(self.model, torch.load(model_path))
22
+ self.model.load_state_dict(model_dict)
23
+ self.model.eval()
24
+
25
+ def __call__(self, oriImgRaw):
26
+ scale_search = [0.5, 1.0, 1.5, 2.0]
27
+ # scale_search = [0.5]
28
+ boxsize = 368
29
+ stride = 8
30
+ padValue = 128
31
+ thre = 0.05
32
+ multiplier = [x * boxsize for x in scale_search]
33
+
34
+ wsize = 128
35
+ heatmap_avg = np.zeros((wsize, wsize, 22))
36
+
37
+ Hr, Wr, Cr = oriImgRaw.shape
38
+
39
+ oriImg = cv2.GaussianBlur(oriImgRaw, (0, 0), 0.8)
40
+
41
+ for m in range(len(multiplier)):
42
+ scale = multiplier[m]
43
+ imageToTest = util.smart_resize(oriImg, (scale, scale))
44
+
45
+ imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
46
+ im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
47
+ im = np.ascontiguousarray(im)
48
+
49
+ data = torch.from_numpy(im).float()
50
+ if torch.cuda.is_available():
51
+ data = data.cuda()
52
+
53
+ with torch.no_grad():
54
+ data = data.to(self.cn_device)  # cn_device is assigned externally by the caller, not in __init__
55
+ output = self.model(data).cpu().numpy()
56
+
57
+ # extract outputs, resize, and remove padding
58
+ heatmap = np.transpose(np.squeeze(output), (1, 2, 0)) # output 1 is heatmaps
59
+ heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride)
60
+ heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
61
+ heatmap = util.smart_resize(heatmap, (wsize, wsize))
62
+
63
+ heatmap_avg += heatmap / len(multiplier)
64
+
65
+ all_peaks = []
66
+ for part in range(21):
67
+ map_ori = heatmap_avg[:, :, part]
68
+ one_heatmap = gaussian_filter(map_ori, sigma=3)
69
+ binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8)
70
+
71
+ if np.sum(binary) == 0:
72
+ all_peaks.append([0, 0])
73
+ continue
74
+ label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim)
75
+ max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1
76
+ label_img[label_img != max_index] = 0
77
+ map_ori[label_img == 0] = 0
78
+
79
+ y, x = util.npmax(map_ori)
80
+ y = int(float(y) * float(Hr) / float(wsize))
81
+ x = int(float(x) * float(Wr) / float(wsize))
82
+ all_peaks.append([x, y])
83
+ return np.array(all_peaks)
84
+
85
+ if __name__ == "__main__":
86
+ hand_estimation = Hand('../model/hand_pose_model.pth')
87
+
88
+ # test_image = '../images/hand.jpg'
89
+ test_image = '../images/hand.jpg'
90
+ oriImg = cv2.imread(test_image) # B,G,R order
91
+ peaks = hand_estimation(oriImg)
92
+ canvas = util.draw_handpose(oriImg, peaks, True)
93
+ cv2.imshow('', canvas)
94
+ cv2.waitKey(0)
custom_nodes/ComfyUI-tbox/src/dwpose/model.py ADDED
@@ -0,0 +1,218 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from collections import OrderedDict
5
+
6
+
7
+ def make_layers(block, no_relu_layers):
8
+ layers = []
9
+ for layer_name, v in block.items():
10
+ if 'pool' in layer_name:
11
+ layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1],
12
+ padding=v[2])
13
+ layers.append((layer_name, layer))
14
+ else:
15
+ conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1],
16
+ kernel_size=v[2], stride=v[3],
17
+ padding=v[4])
18
+ layers.append((layer_name, conv2d))
19
+ if layer_name not in no_relu_layers:
20
+ layers.append(('relu_'+layer_name, nn.ReLU(inplace=True)))
21
+
22
+ return nn.Sequential(OrderedDict(layers))
23
+
24
+ class bodypose_model(nn.Module):
25
+ def __init__(self):
26
+ super(bodypose_model, self).__init__()
27
+
28
+ # these layers have no relu layer
29
+ no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\
30
+ 'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\
31
+ 'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\
32
+ 'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L2']
33
+ blocks = {}
34
+ block0 = OrderedDict([
35
+ ('conv1_1', [3, 64, 3, 1, 1]),
36
+ ('conv1_2', [64, 64, 3, 1, 1]),
37
+ ('pool1_stage1', [2, 2, 0]),
38
+ ('conv2_1', [64, 128, 3, 1, 1]),
39
+ ('conv2_2', [128, 128, 3, 1, 1]),
40
+ ('pool2_stage1', [2, 2, 0]),
41
+ ('conv3_1', [128, 256, 3, 1, 1]),
42
+ ('conv3_2', [256, 256, 3, 1, 1]),
43
+ ('conv3_3', [256, 256, 3, 1, 1]),
44
+ ('conv3_4', [256, 256, 3, 1, 1]),
45
+ ('pool3_stage1', [2, 2, 0]),
46
+ ('conv4_1', [256, 512, 3, 1, 1]),
47
+ ('conv4_2', [512, 512, 3, 1, 1]),
48
+ ('conv4_3_CPM', [512, 256, 3, 1, 1]),
49
+ ('conv4_4_CPM', [256, 128, 3, 1, 1])
50
+ ])
51
+
52
+
53
+ # Stage 1
54
+ block1_1 = OrderedDict([
55
+ ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]),
56
+ ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]),
57
+ ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]),
58
+ ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]),
59
+ ('conv5_5_CPM_L1', [512, 38, 1, 1, 0])
60
+ ])
61
+
62
+ block1_2 = OrderedDict([
63
+ ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]),
64
+ ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]),
65
+ ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]),
66
+ ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]),
67
+ ('conv5_5_CPM_L2', [512, 19, 1, 1, 0])
68
+ ])
69
+ blocks['block1_1'] = block1_1
70
+ blocks['block1_2'] = block1_2
71
+
72
+ self.model0 = make_layers(block0, no_relu_layers)
73
+
74
+ # Stages 2 - 6
75
+ for i in range(2, 7):
76
+ blocks['block%d_1' % i] = OrderedDict([
77
+ ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]),
78
+ ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]),
79
+ ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]),
80
+ ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]),
81
+ ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]),
82
+ ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]),
83
+ ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0])
84
+ ])
85
+
86
+ blocks['block%d_2' % i] = OrderedDict([
87
+ ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]),
88
+ ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]),
89
+ ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]),
90
+ ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]),
91
+ ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]),
92
+ ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]),
93
+ ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0])
94
+ ])
95
+
96
+ for k in blocks.keys():
97
+ blocks[k] = make_layers(blocks[k], no_relu_layers)
98
+
99
+ self.model1_1 = blocks['block1_1']
100
+ self.model2_1 = blocks['block2_1']
101
+ self.model3_1 = blocks['block3_1']
102
+ self.model4_1 = blocks['block4_1']
103
+ self.model5_1 = blocks['block5_1']
104
+ self.model6_1 = blocks['block6_1']
105
+
106
+ self.model1_2 = blocks['block1_2']
107
+ self.model2_2 = blocks['block2_2']
108
+ self.model3_2 = blocks['block3_2']
109
+ self.model4_2 = blocks['block4_2']
110
+ self.model5_2 = blocks['block5_2']
111
+ self.model6_2 = blocks['block6_2']
112
+
113
+
114
+ def forward(self, x):
115
+
116
+ out1 = self.model0(x)
117
+
118
+ out1_1 = self.model1_1(out1)
119
+ out1_2 = self.model1_2(out1)
120
+ out2 = torch.cat([out1_1, out1_2, out1], 1)
121
+
122
+ out2_1 = self.model2_1(out2)
123
+ out2_2 = self.model2_2(out2)
124
+ out3 = torch.cat([out2_1, out2_2, out1], 1)
125
+
126
+ out3_1 = self.model3_1(out3)
127
+ out3_2 = self.model3_2(out3)
128
+ out4 = torch.cat([out3_1, out3_2, out1], 1)
129
+
130
+ out4_1 = self.model4_1(out4)
131
+ out4_2 = self.model4_2(out4)
132
+ out5 = torch.cat([out4_1, out4_2, out1], 1)
133
+
134
+ out5_1 = self.model5_1(out5)
135
+ out5_2 = self.model5_2(out5)
136
+ out6 = torch.cat([out5_1, out5_2, out1], 1)
137
+
138
+ out6_1 = self.model6_1(out6)
139
+ out6_2 = self.model6_2(out6)
140
+
141
+ return out6_1, out6_2
142
+
143
+ class handpose_model(nn.Module):
144
+ def __init__(self):
145
+ super(handpose_model, self).__init__()
146
+
147
+ # these layers have no relu layer
148
+ no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\
149
+ 'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6']
150
+ # stage 1
151
+ block1_0 = OrderedDict([
152
+ ('conv1_1', [3, 64, 3, 1, 1]),
153
+ ('conv1_2', [64, 64, 3, 1, 1]),
154
+ ('pool1_stage1', [2, 2, 0]),
155
+ ('conv2_1', [64, 128, 3, 1, 1]),
156
+ ('conv2_2', [128, 128, 3, 1, 1]),
157
+ ('pool2_stage1', [2, 2, 0]),
158
+ ('conv3_1', [128, 256, 3, 1, 1]),
159
+ ('conv3_2', [256, 256, 3, 1, 1]),
160
+ ('conv3_3', [256, 256, 3, 1, 1]),
161
+ ('conv3_4', [256, 256, 3, 1, 1]),
162
+ ('pool3_stage1', [2, 2, 0]),
163
+ ('conv4_1', [256, 512, 3, 1, 1]),
164
+ ('conv4_2', [512, 512, 3, 1, 1]),
165
+ ('conv4_3', [512, 512, 3, 1, 1]),
166
+ ('conv4_4', [512, 512, 3, 1, 1]),
167
+ ('conv5_1', [512, 512, 3, 1, 1]),
168
+ ('conv5_2', [512, 512, 3, 1, 1]),
169
+ ('conv5_3_CPM', [512, 128, 3, 1, 1])
170
+ ])
171
+
172
+ block1_1 = OrderedDict([
173
+ ('conv6_1_CPM', [128, 512, 1, 1, 0]),
174
+ ('conv6_2_CPM', [512, 22, 1, 1, 0])
175
+ ])
176
+
177
+ blocks = {}
178
+ blocks['block1_0'] = block1_0
179
+ blocks['block1_1'] = block1_1
180
+
181
+ # stage 2-6
182
+ for i in range(2, 7):
183
+ blocks['block%d' % i] = OrderedDict([
184
+ ('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]),
185
+ ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]),
186
+ ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]),
187
+ ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]),
188
+ ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]),
189
+ ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]),
190
+ ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0])
191
+ ])
192
+
193
+ for k in blocks.keys():
194
+ blocks[k] = make_layers(blocks[k], no_relu_layers)
195
+
196
+ self.model1_0 = blocks['block1_0']
197
+ self.model1_1 = blocks['block1_1']
198
+ self.model2 = blocks['block2']
199
+ self.model3 = blocks['block3']
200
+ self.model4 = blocks['block4']
201
+ self.model5 = blocks['block5']
202
+ self.model6 = blocks['block6']
203
+
204
+ def forward(self, x):
205
+ out1_0 = self.model1_0(x)
206
+ out1_1 = self.model1_1(out1_0)
207
+ concat_stage2 = torch.cat([out1_1, out1_0], 1)
208
+ out_stage2 = self.model2(concat_stage2)
209
+ concat_stage3 = torch.cat([out_stage2, out1_0], 1)
210
+ out_stage3 = self.model3(concat_stage3)
211
+ concat_stage4 = torch.cat([out_stage3, out1_0], 1)
212
+ out_stage4 = self.model4(concat_stage4)
213
+ concat_stage5 = torch.cat([out_stage4, out1_0], 1)
214
+ out_stage5 = self.model5(concat_stage5)
215
+ concat_stage6 = torch.cat([out_stage5, out1_0], 1)
216
+ out_stage6 = self.model6(concat_stage6)
217
+ return out_stage6
218
+
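`make_layers` converts an `OrderedDict` of layer configs (`[in, out, kernel, stride, pad]` for convs, `[kernel, stride, pad]` for entries whose name contains "pool") into an `nn.Sequential`, appending a ReLU after every conv whose name is not in `no_relu_layers`. A tiny illustrative sketch with made-up layer names:

```py
# Sketch: build a small block with make_layers; 'head_out' keeps a linear (no-ReLU) output.
from collections import OrderedDict
import torch

block = OrderedDict([
    ('conv_a', [3, 16, 3, 1, 1]),    # Conv2d(3, 16, kernel=3, stride=1, pad=1) + ReLU
    ('pool_a', [2, 2, 0]),           # MaxPool2d(kernel=2, stride=2, pad=0)
    ('head_out', [16, 8, 1, 1, 0]),  # Conv2d(16, 8, kernel=1), no ReLU appended
])
net = make_layers(block, no_relu_layers=['head_out'])
print(net(torch.randn(1, 3, 64, 64)).shape)      # torch.Size([1, 8, 32, 32])
```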
custom_nodes/ComfyUI-tbox/src/dwpose/types.py ADDED
@@ -0,0 +1,30 @@
1
+ from typing import NamedTuple, List, Optional
2
+
3
+ class Keypoint(NamedTuple):
4
+ x: float
5
+ y: float
6
+ score: float = 1.0
7
+ id: int = -1
8
+
9
+
10
+ class BodyResult(NamedTuple):
11
+ # Note: Using `Optional` instead of the `|` operator, as the latter is a Python
12
+ # 3.10 feature.
13
+ # Annotator code should be Python 3.8 Compatible, as controlnet repo uses
14
+ # Python 3.8 environment.
15
+ # https://github.com/lllyasviel/ControlNet/blob/d3284fcd0972c510635a4f5abe2eeb71dc0de524/environment.yaml#L6
16
+ keypoints: List[Optional[Keypoint]]
17
+ total_score: float = 0.0
18
+ total_parts: int = 0
19
+
20
+
21
+ HandResult = List[Keypoint]
22
+ FaceResult = List[Keypoint]
23
+ AnimalPoseResult = List[Keypoint]
24
+
25
+
26
+ class PoseResult(NamedTuple):
27
+ body: BodyResult
28
+ left_hand: Optional[HandResult]
29
+ right_hand: Optional[HandResult]
30
+ face: Optional[FaceResult]
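A short illustrative sketch of how detector output typically maps onto these types (coordinate values are made up, and the classes are assumed to be in scope from this module):

```py
# Sketch: wrap raw (x, y, score) triples into the NamedTuple-based pose types.
keypoints = [
    Keypoint(x=0.52, y=0.31, score=0.98, id=0),   # e.g. nose
    None,                                         # an undetected keypoint stays None
    Keypoint(x=0.48, y=0.29, score=0.91, id=2),
]
body = BodyResult(keypoints=keypoints, total_score=1.89, total_parts=2)
pose = PoseResult(body=body, left_hand=None, right_hand=None, face=None)
print(pose.body.total_parts, pose.body.keypoints[0].x)   # 2 0.52
```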