"""Gradio demo: object detection driven by free-text queries.

Loads a ViT image backbone and a RoBERTa text encoder, combines them in the
project's `Model`, restores trained weights from a checkpoint, and serves a
Gradio interface where the user supplies an image plus a textual query.
"""

import torch
import torch.nn as nn
import cv2
import gradio as gr
import numpy as np
from PIL import Image
import transformers
from transformers import RobertaModel, RobertaTokenizer
import timm
import pandas as pd
import matplotlib.pyplot as plt
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

from model import Model
from get_output import visualize_output

# Path to the trained checkpoint. NOTE(review): the original script used an
# undefined name `target_dir` here, which raised NameError at startup — set
# this to the real checkpoint location before deploying.
target_dir = "checkpoint.pt"  # TODO: confirm actual checkpoint path

# Use GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Initialize used pretrained models.
# global_pool='' + num_classes=0 turn the ViT into a pure patch-feature
# extractor (no classification head, no pooling).
vit = timm.create_model(
    'vit_base_patch16_224', pretrained=True, num_classes=0, global_pool=''
).to(device)
tokenizer = RobertaTokenizer.from_pretrained(
    'roberta-base', truncation=True, do_lower_case=True
)
roberta = RobertaModel.from_pretrained("roberta-base")

model = Model(vit, roberta, tokenizer, device).to(device)

# Restore trained weights, then switch to inference mode.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
state = torch.load(target_dir, map_location=device)
model.load_state_dict(state['val_model_dict'])
model.eval()

# Transform for input image — mirror the ViT's training-time preprocessing
# but disable augmentation and force bilinear resampling for inference.
config = resolve_data_config({}, model=vit)
config['no_aug'] = True
config['interpolation'] = 'bilinear'
transform = create_transform(**config)


def query_image(input_img, query, binarize, eval_threshold):
    """Run the model on one image/query pair and return a visualization.

    Args:
        input_img: HxWx3 uint8 RGB array supplied by the Gradio Image input.
        query: free-text description of the object to find.
        binarize: if True, threshold the output map (checkbox input).
        eval_threshold: threshold value in [0, 1] (slider input).

    Returns:
        The image rendered by `visualize_output` with detections overlaid.
    """
    PIL_image = Image.fromarray(input_img, "RGB")
    img = transform(PIL_image)
    img = torch.unsqueeze(img, 0).to(device)
    with torch.no_grad():
        output = model(img, query)
    img = visualize_output(img, output, binarize, eval_threshold)
    return img


description = """ Gradio demo for an object detection architecture, introduced in my bachelor thesis. \n\nLorem ipsum .... *"image of a shoe"*. Refer to the CLIP paper to see the full list of text templates used to augment the training data. 
"""

demo = gr.Interface(
    query_image,
    inputs=[gr.Image(), "text", "checkbox", gr.Slider(0, 1, value=0.25)],
    outputs="image",
    title="Object Detection Using Textual Queries",
    description=description,
    examples=[
        ["examples/img1.jpeg", "Find a person.", True, 0.25],
    ],
    allow_flagging="never",
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch(debug=True)