# Source: Hugging Face Space "yolo-inferences", file app.py (2.4 kB)
# Author: kitsumed — commit b24456a ("Update app.py", verified)
# Code inspired by the Ultralytics example app built with Gradio
import gradio as gradio
import PIL.Image as Image
import os
import shutil
from ultralytics import YOLO
from huggingface_hub import hf_hub_download
# Local folder that downloaded model weights are written into. It is created
# up front at import time; an already-existing folder is not an error.
MODEL_DIR = "cached_models"
os.makedirs(name=MODEL_DIR, exist_ok=True)
# List of models available in the gradio ui
AVAILABLE_MODELS = {
"YOLOv8m Speech Bubble (kitsumed)": {
"repo_id": "kitsumed/yolov8m_seg-speech-bubble",
# Filename, include sub-directory if model not at root (models/v1/model.pt)
"filename": "model.pt"
},
# Add more models here
}
# Single-slot in-memory cache: the model object most recently loaded and the
# AVAILABLE_MODELS key it was loaded for (both None until the first load).
current_model = current_model_name = None
def load_model(model_name):
    """Return the YOLO model registered under ``model_name``, downloading it if needed.

    Only the most recently selected model is kept in memory: selecting a
    different model replaces the cached one.

    Args:
        model_name: A key of ``AVAILABLE_MODELS``.

    Returns:
        The loaded ultralytics ``YOLO`` model.

    Raises:
        ValueError: If ``model_name`` is not a key of ``AVAILABLE_MODELS``.
    """
    global current_model, current_model_name
    # Also require a loaded model. Before the first load both globals are None,
    # so a None model_name would otherwise "hit" the empty cache and return None.
    if current_model is not None and model_name == current_model_name:
        return current_model

    info = AVAILABLE_MODELS.get(model_name)
    if info is None:
        raise ValueError(f"Unknown model: {model_name!r}")

    model_path = hf_hub_download(
        repo_id=info["repo_id"],
        filename=info["filename"],
        # One sub-directory per repository. The file is placed at
        # local_dir/filename, so two repos that both ship e.g. "model.pt" would
        # otherwise overwrite each other. Already-downloaded files are reused.
        local_dir=os.path.join(MODEL_DIR, info["repo_id"]),
    )
    current_model = YOLO(model_path)
    # Record the name only after the load succeeded.
    current_model_name = model_name
    return current_model
def predict_image(img, conf_threshold, iou_threshold, model_name):
    """Run the selected YOLO model on an uploaded image and return an annotated copy.

    Args:
        img: PIL image from the Gradio upload component, or None when the user
            submitted without uploading anything.
        conf_threshold: Confidence threshold, passed to ``predict()`` as ``conf``.
        iou_threshold: IoU threshold, passed to ``predict()`` as ``iou``.
        model_name: Key of ``AVAILABLE_MODELS`` selecting the model to run.

    Returns:
        A PIL image with the predictions (labels and confidences) drawn on it.

    Raises:
        gradio.Error: If no image was provided or the model returned no result,
            so that the message is shown in the UI.
    """
    # Without this guard ultralytics substitutes its own sample images when
    # source is None, and the user would see predictions on an unrelated picture.
    if img is None:
        raise gradio.Error("Please upload an image first.")

    model = load_model(model_name)
    results = model.predict(
        source=img,
        conf=conf_threshold,
        iou=iou_threshold,
        show_labels=True,
        show_conf=True,
        imgsz=640,
    )
    if not results:
        raise gradio.Error("The model returned no result for this image.")

    # A single input image yields a single result; draw its predictions.
    annotated = results[0].plot()
    # plot() returns channels in BGR order; reverse the last axis to get RGB for PIL.
    return Image.fromarray(annotated[..., ::-1])
# Registry keys become the dropdown choices; the first one is preselected.
_model_choices = list(AVAILABLE_MODELS)

# The inputs are passed positionally to predict_image, so their order must
# match its parameters: (img, conf_threshold, iou_threshold, model_name).
_inputs = [
    gradio.Image(type="pil", label="Upload Image"),
    gradio.Slider(minimum=0, maximum=1, value=0.20, label="Confidence threshold"),
    gradio.Slider(minimum=0, maximum=1, value=0.40, label="IoU threshold"),
    gradio.Dropdown(
        choices=_model_choices,
        label="Select Model",
        value=_model_choices[0],
    ),
]

iface = gradio.Interface(
    fn=predict_image,
    inputs=_inputs,
    outputs=gradio.Image(type="pil", label="Result"),
    title="Try out kitsumed YOLO models",
    description="Select a model from kitsumed on Hugging Face and upload an image to perform predictions.",
)

if __name__ == "__main__":
    # Start the Gradio server only when the file is run directly.
    iface.launch()