Spaces:
Running
Running
# Code inspired by the Ultralytics example with Gradio
import gradio as gradio | |
import PIL.Image as Image | |
import os | |
import shutil | |
from ultralytics import YOLO | |
from huggingface_hub import hf_hub_download | |
# Local folder that hf_hub_download writes model weights into; it is created
# up front (and left alone if it already exists) so downloads have a target.
MODEL_DIR = "cached_models"
os.makedirs(name=MODEL_DIR, exist_ok=True)
# List of models available in the gradio ui | |
AVAILABLE_MODELS = { | |
"YOLOv8m Speech Bubble (kitsumed)": { | |
"repo_id": "kitsumed/yolov8m_seg-speech-bubble", | |
# Filename, include sub-directory if model not at root (models/v1/model.pt) | |
"filename": "model.pt" | |
}, | |
# Add more models here | |
} | |
# Cache for currently loaded model | |
current_model = None | |
current_model_name = None | |
def load_model(model_name): | |
global current_model, current_model_name | |
if model_name == current_model_name: | |
return current_model | |
# Load the repo info related to the selected model from the available models dictionary | |
info = AVAILABLE_MODELS.get(model_name) | |
model_path = hf_hub_download( | |
repo_id=info["repo_id"], | |
filename=info["filename"], | |
# Where to cache the downloaded file, files already cached will directly be reused | |
local_dir=MODEL_DIR | |
) | |
current_model = YOLO(model_path) | |
current_model_name = model_name | |
return current_model | |
def predict_image(img, conf_threshold, iou_threshold, model_name): | |
model = load_model(model_name) | |
results = model.predict( | |
source=img, | |
conf=conf_threshold, | |
iou=iou_threshold, | |
show_labels=True, | |
show_conf=True, | |
imgsz=640, | |
) | |
for r in results: | |
im_array = r.plot() | |
im = Image.fromarray(im_array[..., ::-1]) | |
return im | |
iface = gradio.Interface( | |
fn=predict_image, | |
inputs=[ | |
gradio.Image(type="pil", label="Upload Image"), | |
gradio.Slider(minimum=0, maximum=1, value=0.20, label="Confidence threshold"), | |
gradio.Slider(minimum=0, maximum=1, value=0.40, label="IoU threshold"), | |
gradio.Dropdown(choices=list(AVAILABLE_MODELS.keys()), label="Select Model", value=list(AVAILABLE_MODELS.keys())[0]) | |
], | |
outputs=gradio.Image(type="pil", label="Result"), | |
title="Try out kitsumed YOLO models", | |
description="Select a model from kitsumed on Hugging Face and upload an image to perform predictions.", | |
) | |
if __name__ == "__main__": | |
iface.launch() | |