import torch
import gradio as gr
from PIL import Image
from transformers import RobertaModel, RobertaTokenizer
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

# Local modules: the detection model and the output visualizer.
from model import Model
from output import visualize_output

# Run on the first GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the backbones: a ViT image encoder (num_classes=0 and an empty
# global_pool make it return the per-patch token sequence) and a RoBERTa
# text encoder with its tokenizer.
vit = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0, global_pool='').to(device)
tokenizer = RobertaTokenizer.from_pretrained('roberta-base', truncation=True, do_lower_case=True)
roberta = RobertaModel.from_pretrained("roberta-base")
model = Model(vit, roberta, tokenizer, device).to(device)

# Load the fine-tuned weights (read onto CPU first, then copied into the
# model's device tensors by load_state_dict) and switch to inference mode.
state = torch.load('saved_model', map_location=torch.device('cpu'))
model.load_state_dict(state['val_model_dict'])
model.eval()
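
# Optional sanity check (hypothetical, not part of the original pipeline):
#   n_params = sum(p.numel() for p in model.parameters())
#   print(f"Loaded model with {n_params:,} parameters on {device}")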

# Recreate the preprocessing the ViT backbone expects (resize, crop,
# normalization), with augmentation disabled for inference.
config = resolve_data_config({}, model=vit)
config['no_aug'] = True
config['interpolation'] = 'bilinear'
transform = create_transform(**config)
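
# For reference: with vit_base_patch16_224, resolve_data_config typically yields
# input_size=(3, 224, 224), crop_pct=0.9, and mean/std of 0.5 (timm ViT defaults);
# create_transform wraps this into a PIL-image -> normalized CHW-tensor callable.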


def query_image(input_img, query, binarize, eval_threshold):
    """Run the model on one image/query pair and return the visualized detection."""
    pil_image = Image.fromarray(input_img, "RGB")
    img = transform(pil_image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(img, query)

    return visualize_output(img, output, binarize, eval_threshold)
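
# Hypothetical direct call for local testing (assumes examples/img1.jpeg exists):
#   import numpy as np
#   arr = np.array(Image.open("examples/img1.jpeg").convert("RGB"))
#   out = query_image(arr, "Find a person.", binarize=True, eval_threshold=0.45)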


description = """
Gradio demo for an object detection architecture, introduced in my <a href="https://www.google.com/">bachelor thesis</a> (link will be added).
\n\n
You can use this architecture to detect objects using textual queries. To use it, simply upload an image and enter any query you want.
The query can be a single word or a full sentence. The model is trained to recognize only the 80 categories from the COCO Detection 2017 dataset.
Refer to <a href="https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/">this</a> website
or the original <a href="https://arxiv.org/pdf/1405.0312.pdf">COCO</a> paper to see the full list of categories.
\n\n
Best results are obtained using one of the sentence templates seen during training, e.g.:
<div class="row">
<div class="column">
<ul>
<li>Find a {class}.</li>
</ul>
</div>
</div>
"""
demo = gr.Interface(
    query_image,
    inputs=[gr.Image(), "text", "checkbox", gr.Slider(0, 1, value=0.25)],
    outputs="image",
    title="Object Detection Using Textual Queries",
    description=description,
    examples=[
        ["examples/img1.jpeg", "Find a person.", True, 0.45],
        ["examples/img2.jpeg", "Detect a person with skis.", True, 0.25],
        ["examples/img3.jpeg", "There should be a cat in this picture, where?", True, 0.25],
        ["examples/img4.jpeg", "Can you find an elephant?", True, 0.25],
    ],
    cache_examples=False,
    allow_flagging="never",
    css="""
    .row {
        display: flex;
    }
    .column {
        flex: 30%;
    }
    """,
)

demo.launch()