|
import torch |
|
import os |
|
import streamlit as st |
|
import matplotlib.pyplot as plt |
|
import pandas as pd |
|
import numpy as np |
|
import plotly.express as px |
|
import pickle |
|
import random |
|
|
|
from PIL import Image |
|
from transformers import YolosFeatureExtractor, YolosForObjectDetection |
|
from torchvision.transforms import ToTensor, ToPILImage |
|
from annotated_text import annotated_text |
|
|
|
|
|
# Configure the Streamlit page to use the wide layout.
st.set_page_config(layout="wide")
|
|
|
@st.cache_resource(ttl=3600, show_spinner=False)
def load_model(feature_extractor_url, model_url):
    """Load the YOLOS feature extractor and object-detection model.

    The result is cached with ``st.cache_resource`` (one-hour TTL, no
    spinner). Streamlit's caching guide recommends ``cache_resource`` for
    global resources such as ML models, because ``cache_data`` serialises
    the return value and hands back a fresh copy on every cache hit.

    Args:
        feature_extractor_url: Identifier passed to
            ``YolosFeatureExtractor.from_pretrained``.
        model_url: Identifier passed to
            ``YolosForObjectDetection.from_pretrained``.

    Returns:
        A ``(feature_extractor, model)`` tuple.
    """
    feature_extractor_ = YolosFeatureExtractor.from_pretrained(feature_extractor_url)
    model_ = YolosForObjectDetection.from_pretrained(model_url)
    return feature_extractor_, model_
|
|
|
|
|
def rgb_to_hex(rgb):
    """Build an HTML-style ``#rrggbb`` string from an RGB triple.

    Each of the first three components of ``rgb`` is multiplied by 255 and
    truncated with ``int`` before being rendered as two lowercase hex digits.
    """
    red, green, blue = (int(rgb[channel] * 255) for channel in range(3))
    return f"#{red:02x}{green:02x}{blue:02x}"
|
|
|
|
|
def fix_channels(t):
    """Convert an image tensor to a PIL image, normalising odd channel layouts.

    * 2-D tensor: stacked three times along a new leading axis.
    * Leading dimension of 4: only the first three entries are kept.
    * Leading dimension of 1: that single entry is stacked three times.
    * Anything else: handed to ``ToPILImage`` unchanged.
    """
    to_pil = ToPILImage()

    if len(t.shape) == 2:
        return to_pil(torch.stack([t, t, t]))

    leading = t.shape[0]
    if leading == 4:
        return to_pil(t[:3])
    if leading == 1:
        plane = t[0]
        return to_pil(torch.stack([plane, plane, plane]))

    return to_pil(t)
|
|
|
|
|
# Base palette of RGB triples (float components); plot_results cycles through
# it for bounding boxes and rgb_to_hex scales the components by 255.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]
|
|
|
def idx_to_text(i):
    """Return the category name for class index ``i``, or ``False`` if unknown.

    Args:
        i: A class index. It may be a plain integer or any scalar object
            exposing ``.item()`` (the caller in this file passes the result
            of ``argmax()``).

    Returns:
        The matching value from ``dict_cats_final``, or ``False`` when the
        index is not one of its keys.
    """
    # Unwrap scalar containers so the dict lookup hashes a plain number;
    # plain ints (which have no .item()) are accepted as-is.
    key = i.item() if hasattr(i, "item") else i
    # Direct dict lookup instead of building a list of keys on every call.
    return dict_cats_final.get(key, False)
|
|
|
|
|
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (centre-x, centre-y, width, height) to corner form.

    ``x`` is split along dimension 1 into its four components; the result
    stacks (x_min, y_min, x_max, y_max) back along dimension 1.
    """
    center_x, center_y, width, height = x.unbind(1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = (
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    )
    return torch.stack(corners, dim=1)
|
|
|
def rescale_bboxes(out_bbox, size):
    """Convert centre-format boxes to corner format scaled by the image size.

    Args:
        out_bbox: Boxes as (centre-x, centre-y, width, height) rows
            (presumably relative to the image size; confirm against the
            model output).
        size: ``(width, height)`` pair of the target image.

    Returns:
        A float tensor of (x_min, y_min, x_max, y_max) rows, with x values
        multiplied by the width and y values by the height.
    """
    width, height = size
    scale = torch.tensor([width, height, width, height], dtype=torch.float32)
    return box_cxcywh_to_xyxy(out_bbox) * scale
|
|
|
def plot_results(pil_img, prob, boxes):
    """Draw labelled bounding boxes on the image and display it in the app.

    The annotated figure is written to ``results_od.png`` and then shown
    with ``st.image``.

    Args:
        pil_img: Image to draw on (anything ``imshow`` accepts).
        prob: Iterable of per-detection score vectors; the argmax of each
            one selects the class label.
        boxes: Tensor of ``(xmin, ymin, xmax, ymax)`` rows aligned with
            ``prob``.

    Returns:
        A list of hex colour strings, one per box actually drawn, in
        drawing order.
    """
    fig, ax = plt.subplots(figsize=(16, 10))
    ax.imshow(pil_img)

    # Repeat the palette so zip() does not run out of colours early.
    palette = COLORS * 100
    colors_used = []

    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), palette):
        # NOTE(review): the argmax is a column position in ``prob``; it only
        # equals the category id while ``prob`` has one column per category.
        label = idx_to_text(p.argmax())
        if label is False:
            # Unknown class index: skip the box entirely.
            continue

        colors_used.append(rgb_to_hex(c))
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))
        ax.text(xmin, ymin, f"{label}", fontsize=10,
                bbox=dict(facecolor=c, alpha=0.8))

    ax.axis('off')

    fig.savefig("results_od.png", bbox_inches="tight")
    # pyplot keeps every figure alive until it is closed; close it here so
    # repeated runs of the app do not accumulate open figures.
    plt.close(fig)

    st.image("results_od.png")

    return colors_used
|
|
|
|
|
def return_probas(outputs, threshold):
    """Compute per-detection category scores and a keep-mask.

    The logits are soft-maxed over the last dimension, index 0 of the first
    dimension is taken (presumably the batch), and the final class column is
    dropped. Only the columns listed in ``dict_cats_final`` are retained.

    Args:
        outputs: Object exposing a ``logits`` tensor.
        threshold: A detection is kept when its best score is strictly
            greater than this value.

    Returns:
        ``(probas, keep)``: the filtered score matrix and a boolean mask
        with one entry per row of ``probas``.
    """
    class_scores = outputs.logits.softmax(-1)[0, :, :-1]
    selected_columns = list(dict_cats_final.keys())
    probas = class_scores[:, selected_columns]
    keep = probas.max(-1).values > threshold
    return probas, keep
|
|
|
|
|
def visualize_probas(probas, threshold, colors):
    """Show a bar chart of the mean score per detected item.

    Args:
        probas: 2-D score tensor, one row per detection; the row-wise
            maximum and its column position are used.
        threshold: Rows whose best score is not strictly above this value
            are discarded.
        colors: One colour per row that survives the threshold, in row
            order (assigned as a new column, so the lengths must match).
    """
    # Compute the row-wise max once; it yields both values and positions.
    best = probas.max(-1)
    label_df = pd.DataFrame({
        "label": best.indices.detach().numpy(),
        "proba": best.values.detach().numpy(),
    })

    # NOTE(review): column positions are translated through the full ``cats``
    # list, which matches ``probas`` only while it has one column per entry
    # of ``cats`` -- confirm if category filtering is ever enabled.
    cats_dict = dict(zip(np.arange(0, len(cats)), cats))
    label_df["label"] = label_df["label"].map(cats_dict)

    top_label_df = label_df.loc[label_df["proba"] > threshold].round(2)
    top_label_df["colors"] = colors
    top_label_df = top_label_df.sort_values(by=["proba"], ascending=False)

    def most_frequent(series):
        # First mode of the series (ties resolved by pandas' mode ordering).
        return series.mode().iloc[0]

    # One row per item: mean score plus the most frequent box colour.
    # NOTE(review): the aggregated colours are not applied to the chart.
    top_label_df_agg = top_label_df.groupby("label").agg(
        {"proba": "mean", "colors": most_frequent}
    )
    top_label_df_agg = top_label_df_agg.reset_index().sort_values(
        by=["proba"], ascending=False
    )
    top_label_df_agg.columns = ["Item", "Score", "Colors"]

    fig = px.bar(top_label_df_agg, y='Item', x='Score',
                 color="Item", title="Probability scores")
    st.plotly_chart(fig, use_container_width=True)
|
|
|
|
|
# Category names; the position of each name in this list is used as its
# class index elsewhere in the file (46 entries).
cats = [
    'shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan',
    'jacket', 'vest', 'pants', 'shorts', 'skirt', 'coat', 'dress', 'jumpsuit',
    'cape',
    'glasses', 'hat', 'headband, head covering, hair accessory', 'tie',
    'glove', 'watch', 'belt', 'leg warmer', 'tights, stockings', 'sock',
    'shoe', 'bag, wallet', 'scarf', 'umbrella',
    'hood', 'collar', 'lapel', 'epaulette', 'sleeve', 'pocket', 'neckline',
    'buckle', 'zipper',
    'applique', 'bead', 'bow', 'flower', 'fringe', 'ribbon', 'rivet',
    'ruffle', 'sequin', 'tassel',
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---- Page header and general introduction to object detection ----
st.markdown("# Object Detection")

st.markdown("### What is Object Detection ?")

# Short definition shown in an info box.
st.info("""Object Detection is a computer vision task in which the goal is to **detect** and **locate objects** of interest in an image or video.
The task involves identifying the position and boundaries of objects (or **bounding boxes**) in an image, and classifying the objects into different categories.""")

st.markdown("Here is an example of Object Detection for Traffic Analysis.")

# Embedded example video.
st.video(data='https://www.youtube.com/watch?v=PVCGDoTZHaI')

# Spacer.
st.markdown(" ")

# Bullet list of typical application domains.
st.markdown("""Common applications of Object Detection include:
- **Autonomous Vehicles** :car: : Object detection is crucial for self-driving cars to track pedestrians, cyclists, other vehicles, and obstacles on the road.
- **Retail** 🏬 : Implementing smart shelves and checkout systems that use object detection to track inventory and monitor stock levels.
- **Healthcare** 👨⚕️: Detecting and tracking anomalies in medical images, such as tumors or abnormalities, for diagnostic purposes or prevention.
- **Manufacturing** 🏭: Quality control on production lines by detecting defects or irregularities in manufactured products. Ensuring workplace safety by monitoring the movement of workers and equipment.
- **Fashion and E-commerce** 🛍️ : Improving virtual try-on experiences by accurately detecting and placing virtual clothing items on users.
""")

# Spacer followed by a horizontal rule closing the introduction.
st.markdown(" ")
st.divider()
|
|
|
# ---- Fashion use case: description and a preview strip of sample images ----
st.markdown("## Fashion Object Detection 👗")

st.info("""In this use case, we are going to identify and locate different articles of clothings, as well as finer details such as a collar or pocket using an object detection AI model.
The images used were taken from **Dior's 2020 Fall Women Fashion Show**.""")

st.markdown(" ")

# os.listdir returns entries in arbitrary order, so sort them to keep the
# preview stable between runs and platforms. An entry named "results" is
# skipped.
images_dior = [
    os.path.join("data/dior_show/images", url)
    for url in sorted(os.listdir("data/dior_show/images"))
    if url != "results"
]

# Show at most four images side by side (zip stops at the shorter input).
columns_img = st.columns(4)
for img, col in zip(images_dior, columns_img):
    with col:
        st.image(img)

st.markdown(" ")
|
|
|
|
|
# ---- "About the model" section: description and the list of categories ----
st.markdown("### About the model 📚")
st.markdown("""The object detection model was trained specifically to **detect clothing items** on images. <br>
It is able to detect <b>46</b> different types of clothing items.""", unsafe_allow_html=True)

# NOTE(review): this palette is not referenced again in the visible code
# (every annotation below uses "#afa") -- confirm before removing.
colors = ["#8ef", "#faa", "#afa", "#fea", "#8ef","#afa"]*7 + ["#8ef", "#faa", "#afa", "#fea"]

# One (text, label, background colour) tuple per category, rendered as
# highlighted chips with an empty label.
cats_annotated = [(g,"","#afa") for g in cats]
annotated_text([cats_annotated])

st.markdown("Credits: https://huggingface.co/valentinafeve/yolos-fashionpedia")

# Empty markdown calls used as vertical spacing.
st.markdown("")
st.markdown("")
|
|
|
|
|
|
|
|
|
|
|
# ---- Image selection ----
st.markdown("### Select an image 🖼️")

fashion_images_path = r"data/dior_show/images"

# os.listdir returns entries in arbitrary order; sort them so the dropdown
# is stable between runs.
list_images = sorted(os.listdir(fashion_images_path))
image_name = st.selectbox("Select the image you wish to run the model on", list_images)

# selectbox returns None when there is nothing to choose from. In that case
# image_ stays None so the "Run the model" step can report the problem,
# instead of os.path.join failing on a None component.
image_ = None
if image_name is not None:
    image_ = os.path.join(fashion_images_path, image_name)
    st.image(image_, width=300)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Vertical spacing before the next section.
st.markdown(" ")
st.markdown(" ")

# Index -> category name, keyed by the integers produced by np.arange.
dict_cats = dict(zip(np.arange(len(cats)), cats))

# All categories are selected; dict_cats_final keeps only the selected ones.
selected_options = cats
dict_cats_final = {
    cat_idx: cat_name
    for cat_idx, cat_name in dict_cats.items()
    if cat_name in selected_options
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---- Threshold selection ----
st.markdown("### Define a threshold for predictions 🔎")
st.markdown("""This section allows you to control how confident you want your model to be with its predictions. <br>
Objects that are given a lower score than the chosen threshold will be ignored in the final results.""", unsafe_allow_html=True)

st.markdown(" Below is an example of probability scores given by object detection models for each element detected.")

# Static illustration of bounding boxes with their scores.
st.image("images/probability_od.png", caption="Example with bounding boxes and probability scores given by object detection models")

st.markdown(" ")

# The visible heading is rendered here because the slider's own label is
# collapsed below.
st.markdown("**Select a threshold** ")

# Slider from 0.5 to 1.0 in steps of 0.05, defaulting to 0.75.
threshold = st.slider('**Select a threshold**', min_value=0.5, step=0.05, max_value=1.0, value=0.75, label_visibility="collapsed")

# Echo the chosen value back to the user.
st.write("You've selected a threshold at", threshold)
st.markdown(" ")
|
|
|
|
|
|
|
# ---- Run the model: load stored outputs for the chosen image and show them ----
# Directory holding one pickled model output per image.
pickle_file_path = r"data/dior_show/results"

run_model = st.button("**Run the model**", type="primary")

if run_model:
    if image_ is not None and selected_options is not None and threshold is not None:
        with st.spinner('Wait for it...'):

            # Read the pixels inside a context manager so the file handle is
            # released; fix_channels returns a new, independent PIL image.
            with Image.open(image_) as opened_image:
                image = fix_channels(ToTensor()(opened_image))

            FEATURE_EXTRACTOR_PATH = "hustvl/yolos-small"
            MODEL_PATH = "valentinafeve/yolos-fashionpedia"

            # The results file is named after the first five characters of
            # the image file name.
            image_name = image_name[:5]
            path_load_pickle = os.path.join(pickle_file_path, f"{image_name}_results.pkl")

            # NOTE(security): pickle.load can execute arbitrary code. This is
            # only acceptable because the file comes from the project's own
            # data directory; never point it at user-supplied files.
            try:
                with open(path_load_pickle, 'rb') as pickle_file:
                    outputs = pickle.load(pickle_file)
            except FileNotFoundError:
                # Report a missing results file instead of a raw traceback.
                st.error(f"No stored results were found for this image ({path_load_pickle}).")
                st.stop()

            probas, keep = return_probas(outputs, threshold)

            st.markdown("#### See the results ☑️")

            col1, col2 = st.columns(2)
            with col1:
                # Left column: the image with the kept boxes drawn on it.
                bboxes_scaled = rescale_bboxes(outputs.pred_boxes[0, keep].cpu(), image.size)
                colors_used = plot_results(image, probas[keep], bboxes_scaled)

            with col2:
                # Right column: score chart, or a hint when nothing was kept.
                if not any(keep.tolist()):
                    st.error("""No objects were detected on the image.
Decrease your threshold or choose differents items to detect.""")
                else:
                    visualize_probas(probas, threshold, colors_used)

    else:
        st.error("You must select an **image**, **elements to detect** and a **threshold** to run the model !")
|
|
|
|
|
|