|
import torch
import numpy as np
import gradio as gr
import timm
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from transformers import RobertaModel, RobertaTokenizer

from model import Model
from get_output import visualize_output

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Image encoder: ViT-B/16 with the classification head removed (num_classes=0,
# empty global_pool), so it returns the full patch-token sequence.
vit = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0, global_pool='').to(device)

# Text encoder: RoBERTa-base and its tokenizer.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base', truncation=True, do_lower_case=True)
roberta = RobertaModel.from_pretrained("roberta-base")

model = Model(vit, roberta, tokenizer, device).to(device)

# Placeholder checkpoint path -- adjust to wherever the trained weights live.
# map_location keeps CUDA-saved checkpoints loadable on CPU-only machines.
target_dir = "checkpoint.pt"
state = torch.load(target_dir, map_location=device)
model.load_state_dict(state['val_model_dict'])
model.eval()

# Build the evaluation transform from the ViT's pretraining data config.
config = resolve_data_config({}, model=vit)
config['no_aug'] = True
config['interpolation'] = 'bilinear'
transform = create_transform(**config)
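# Sanity-check sketch (assuming the resolved config targets the model's 224x224
# input): the transform maps any PIL image to a normalized CHW float tensor, e.g.
#   transform(Image.new("RGB", (640, 480))).shape  # -> torch.Size([3, 224, 224])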
|
|
|
|
|
def query_image(input_img, query, binarize, eval_threshold):
    """Run one text query against an image and return the visualized detections."""
    # Gradio delivers the image as an RGB numpy array; convert and preprocess it.
    PIL_image = Image.fromarray(input_img, "RGB")
    img = transform(PIL_image)
    img = torch.unsqueeze(img, 0).to(device)

    with torch.no_grad():
        output = model(img, query)

    img = visualize_output(img, output, binarize, eval_threshold)
    return img
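# Standalone usage sketch, outside the Gradio UI (the path is an assumption,
# reused from the examples list below):
#   arr = np.array(Image.open("examples/img1.jpeg").convert("RGB"))
#   out = query_image(arr, "Find a person.", binarize=False, eval_threshold=0.25)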
|
|
|
|
|
description = """ |
|
Gradio demo for an object detection architecture, |
|
introduced in <a href="https://arxiv.org/abs/2205.06230">my bachelor thesis</a>. |
|
\n\nLorem ipsum .... |
|
*"image of a shoe"*. Refer to the <a href="https://arxiv.org/abs/2103.00020">CLIP</a> paper to see the full list of text templates used to augment the training data. |
|
""" |
|
demo = gr.Interface(
    query_image,
    # Inputs: query image, free-text query, binarize toggle, detection threshold.
    inputs=[gr.Image(), "text", "checkbox", gr.Slider(0, 1, value=0.25)],
    outputs="image",
    title="Object Detection Using Textual Queries",
    description=description,
    examples=[
        ["examples/img1.jpeg", "Find a person.", True, 0.25],
    ],
    allow_flagging="never",
    cache_examples=False,
)
|
demo.launch(debug=True) |
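# If localhost is not reachable (e.g. when running from a notebook), Gradio can
# serve a temporary public URL instead: demo.launch(debug=True, share=True)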
|
|
|
|