File size: 3,382 Bytes
2c5aba6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a634e56
2c5aba6
 
 
 
 
c0be566
2c5aba6
 
 
 
 
 
c0be566
5392e1d
c0be566
 
a634e56
2c5aba6
 
 
 
 
a634e56
2c5aba6
 
 
 
 
 
 
 
 
 
 
 
a634e56
2c5aba6
6b47b8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c682534
f6c73a1
2c5aba6
 
 
 
232b6f0
2c5aba6
 
 
bbd2de4
 
 
62fda84
2c5aba6
 
02ca40c
f6c73a1
 
 
 
 
 
c682534
f6c73a1
 
2c5aba6
02ca40c
2c5aba6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
import torch.nn as nn
import cv2
import gradio as gr
import numpy as np
from PIL import Image
import transformers
from transformers import RobertaModel, RobertaTokenizer
import timm
import pandas as pd 
import matplotlib.pyplot as plt
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

from model import Model
from output import visualize_output


# Run on the first CUDA GPU when one is present, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Pretrained backbones. The ViT-B/16 is created with num_classes=0 and
# global_pool='' (no classifier head, no pooling), presumably so the project
# Model receives unpooled patch features; RoBERTa + tokenizer encode the query.
vit = timm.create_model(
    'vit_base_patch16_224',
    pretrained=True,
    num_classes=0,
    global_pool='',
).to(device)
tokenizer = RobertaTokenizer.from_pretrained(
    'roberta-base', truncation=True, do_lower_case=True
)
roberta = RobertaModel.from_pretrained("roberta-base")

# Combine both encoders in the project's detection model and switch it to
# inference mode. NOTE(review): roberta is not moved to `device` here;
# presumably Model registers it as a submodule so `.to(device)` moves it too
# — TODO confirm in model.py.
model = Model(vit, roberta, tokenizer, device).to(device)
model.eval()

# Restore the trained weights. The checkpoint is mapped to CPU first;
# load_state_dict then copies the tensors into the parameters already on
# `device`. torch.load unpickles the file, so it must be a trusted local file.
state = torch.load('saved_model', map_location='cpu')
model.load_state_dict(state['val_model_dict'])

# Image preprocessing: take the ViT's default data config, then force the
# non-augmenting pipeline and bilinear resizing before building the transform.
config = resolve_data_config({}, model=vit)
config.update(no_aug=True, interpolation='bilinear')
transform = create_transform(**config)

# Inference function
def query_image(input_img, query, binarize, eval_threshold):
    """Detect the objects described by ``query`` in ``input_img`` and render them.

    Args:
        input_img: Image array from the ``gr.Image()`` input, or ``None`` when
            the form is submitted without an image.
        query: Free-text query, passed unchanged to ``model`` together with the
            preprocessed image.
        binarize: Checkbox value, forwarded to ``visualize_output``.
        eval_threshold: Slider value, forwarded to ``visualize_output``.

    Returns:
        Whatever ``visualize_output`` returns; displayed by the interface's
        "image" output.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable message
            instead of a traceback from ``Image.fromarray(None)``.
    """
    if input_img is None:
        raise gr.Error("Please upload an image.")

    # Let Pillow infer the mode from the array, then normalize to 3-channel RGB
    # so grayscale or RGBA arrays are handled too (and the deprecated `mode`
    # argument of fromarray is not needed).
    pil_image = Image.fromarray(input_img).convert("RGB")

    # Preprocess and add a leading batch dimension of size 1.
    batch = torch.unsqueeze(transform(pil_image), 0).to(device)

    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        output = model(batch, query)

    return visualize_output(batch, output, binarize, eval_threshold)

# Gradio interface
# HTML text rendered above the demo. This is an ordinary (non-raw) string, so
# each literal "\n\n" below becomes two real newlines in the description.
# The "row"/"column" div classes are styled by the `css` argument passed to
# gr.Interface.
# NOTE(review): the "bachelor thesis" link still points to
# https://www.google.com/, and the two-column list of training sentences looks
# like placeholder content ("Tea", "Milk", "Coffee" repeated; only
# "Find a {class}." looks genuine). TODO replace it with the real training
# templates and thesis URL.
description = """
Gradio demo for an object detection architecture, introduced in my <a href="https://www.google.com/">bachelor thesis</a> (link will be added). 
\n\n
You can use this architecture to detect objects using textual queries. To use it, simply upload an image and enter any query you want.
It can be a single word or in the form of a sentence. The model is trained to recognize only 80 categories from the COCO Detection 2017 dataset. 
Refer to <a href="https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/">this</a> website 
or the original <a href="https://arxiv.org/pdf/1405.0312.pdf">COCO</a> paper to see the full list of categories.
\n\n
Best results are obtained using one of these sentences, which were used during training:
<div class="row">
    <div class="column">
        <ul>
            <li>Find a {class}.</li>
            <li>Tea</li>
            <li>Milk</li>
        </ul>
    </div>
    <div class="column">
        <ul>
            <li>Coffee</li>
            <li>Tea</li>
            <li>Milk</li>
        </ul>
    </div>
</div>
"""
# Example rows offered in the UI; each row is, in order,
# [image path, query, binarize, eval_threshold], matching the inputs below.
_EXAMPLES = [
    ["examples/img1.jpeg", "Find a person.", True, 0.45],
    ["examples/img2.jpeg", "Detect a person with skis.", True, 0.25],
    ["examples/img3.jpeg", "There should be a cat in this picture, where?", True, 0.25],
    ["examples/img4.jpeg", "Can you find an elephant?", True, 0.25],
]

# Lays out the "row"/"column" divs of the description side by side.
_EXTRA_CSS = """
    .row {
      display: flex;
    }

    .column {
      flex: 30%;
    }
    """

# The input widgets feed query_image positionally:
# (input_img, query, binarize, eval_threshold).
demo = gr.Interface(
    fn=query_image,
    inputs=[
        gr.Image(),
        "text",
        "checkbox",
        gr.Slider(0, 1, value=0.25),
    ],
    outputs="image",
    title="Object Detection Using Textual Queries",
    description=description,
    examples=_EXAMPLES,
    cache_examples=False,
    allow_flagging="never",
    css=_EXTRA_CSS,
)
demo.launch()