from dataclasses import dataclass, field
from typing import Generic, Optional, TypeVar

import cv2
import imgutils
import numpy as np
from imgutils.generic.yolo import (
    _image_preprocess,
    _rtdetr_postprocess,
    _yolo_postprocess,
    rgb_encode,
)
from PIL import Image, ImageDraw

T = TypeVar("T", int, float)

# Short detector name -> model repository id.
REPO_IDS = {
    "head": "deepghs/anime_head_detection",
    "face": "deepghs/anime_face_detection",
    "eye": "deepghs/anime_eye_detection",
}


@dataclass
class DetectorOutput(Generic[T]):
    """
    Container for the results of one detection call.

    Attributes:
        bboxes: One bounding box per detection, as [x0, y0, x1, y1].
        masks: One "L"-mode mask per detection (255 inside the box, 0 outside).
        confidences: One confidence score per detection.
        previews: One preview image per detection (the input image with
            everything outside the box blacked out), or None when nothing
            was detected.
    """

    bboxes: list[list[T]] = field(default_factory=list)
    masks: list[Image.Image] = field(default_factory=list)
    confidences: list[float] = field(default_factory=list)
    previews: Optional[list[Image.Image]] = None


class AnimeDetector:
    """
    A class used to perform object detection on anime images.
    Please refer to the `imgutils` documentation for more information on
    the available models.
    """

    def __init__(self, repo_id: str, model_name: str, hf_token: Optional[str] = None):
        """
        Open the requested model through `imgutils`.

        Args:
            repo_id (str): Repository id of the model collection
                (see `REPO_IDS` for known values).
            model_name (str): Name of the model inside the repository.
            hf_token (Optional[str]): Token forwarded to `imgutils` when
                opening the repository. Defaults to None.
        """
        model_manager = imgutils.generic.yolo._open_models_for_repo_id(
            repo_id, hf_token=hf_token
        )
        model, max_infer_size, labels = model_manager._open_model(model_name)

        self.model = model
        self.max_infer_size = max_infer_size
        self.labels = labels
        self.model_type = model_manager._get_model_type(model_name)

    def __call__(
        self,
        image: Image.Image,
        conf_threshold: float = 0.3,
        iou_threshold: float = 0.7,
        allow_dynamic: bool = False,
    ) -> DetectorOutput[float]:
        """
        Perform object detection on the given image.

        Args:
            image (Image.Image): The input image on which to perform detection.
            conf_threshold (float, optional): Confidence threshold for detection.
                Defaults to 0.3.
            iou_threshold (float, optional): Intersection over Union (IoU)
                threshold for detection. Defaults to 0.7.
            allow_dynamic (bool, optional): Whether to allow dynamic resizing
                of the image. Defaults to False.

        Returns:
            DetectorOutput[float]: The detection results: bounding boxes,
            masks, confidences and one preview image per detection. When
            nothing is detected, an empty `DetectorOutput` is returned
            (empty lists, `previews` is None).

        Raises:
            ValueError: If the model type is unknown.
        """
        # Preprocessing
        new_image, old_size, new_size = _image_preprocess(
            image, self.max_infer_size, allow_dynamic=allow_dynamic
        )
        data = rgb_encode(new_image)[None, ...]

        # Start detection
        (output,) = self.model.run(["output0"], {"images": data})

        # Postprocessing
        if self.model_type == "yolo":
            output = _yolo_postprocess(
                output=output[0],
                conf_threshold=conf_threshold,
                iou_threshold=iou_threshold,
                old_size=old_size,
                new_size=new_size,
                labels=self.labels,
            )
        elif self.model_type == "rtdetr":
            output = _rtdetr_postprocess(
                output=output[0],
                conf_threshold=conf_threshold,
                iou_threshold=iou_threshold,
                old_size=old_size,
                new_size=new_size,
                labels=self.labels,
            )
        else:
            raise ValueError(
                f"Unknown object detection model type - {self.model_type!r}."
            )  # pragma: no cover

        if len(output) == 0:
            return DetectorOutput()

        # Each detection is indexed as: [0] -> box [x0, y0, x1, y1], [2] -> score.
        bboxes = [x[0] for x in output]
        masks = create_mask_from_bbox(bboxes, image.size)
        confidences = [x[2] for x in output]
        previews = _render_previews(image, masks)

        return DetectorOutput(
            bboxes=bboxes, masks=masks, confidences=confidences, previews=previews
        )


def _render_previews(
    image: Image.Image, masks: list[Image.Image]
) -> list[Image.Image]:
    """Return one copy of `image` per mask with pixels outside the mask zeroed."""
    # The mask is expanded to three channels below, so the image must have
    # three channels as well; otherwise cv2.bitwise_and rejects the operands.
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    # The image array is loop-invariant: build it once.
    np_image = np.array(rgb_image)

    previews = []
    for mask in masks:
        np_mask = cv2.cvtColor(np.array(mask), cv2.COLOR_GRAY2BGR)
        previews.append(Image.fromarray(cv2.bitwise_and(np_image, np_mask)))
    return previews


def create_mask_from_bbox(
    bboxes: list[list[float]], shape: tuple[int, int]
) -> list[Image.Image]:
    """
    Creates a list of binary masks from bounding boxes.

    Args:
        bboxes (list[list[float]]): A list of bounding boxes, where each
            bounding box is represented by four values
            [x_min, y_min, x_max, y_max].
        shape (tuple[int, int]): The size of each mask as (width, height),
            i.e. the same convention as `PIL.Image.Image.size`.

    Returns:
        list[Image.Image]: A list of "L"-mode PIL images, one per bounding
        box, filled with 255 inside the box and 0 elsewhere.
    """
    masks = []
    for bbox in bboxes:
        mask = Image.new("L", shape, 0)
        ImageDraw.Draw(mask).rectangle(bbox, fill=255)
        masks.append(mask)
    return masks


def create_bbox_from_mask(
    masks: list[Image.Image], shape: tuple[int, int]
) -> list[list[int]]:
    """
    Create bounding boxes from a list of mask images.

    Args:
        masks (list[Image.Image]): A list of PIL Image objects representing
            the masks.
        shape (tuple[int, int]): The (width, height) each mask is resized to
            before its bounding box is computed.

    Returns:
        list[list[int]]: A list of bounding boxes, where each bounding box is
        represented as a list of four integers [left, upper, right, lower].
        Masks for which `getbbox()` returns None are skipped, so the result
        may be shorter than `masks`.
    """
    bboxes = []
    for mask in masks:
        bbox = mask.resize(shape).getbbox()
        if bbox is not None:
            bboxes.append(list(bbox))
    return bboxes