Spaces:
Sleeping
Sleeping
from transformers import BlipForConditionalGeneration, BlipProcessor | |
import time | |
import gradio as gr | |
def get_image_captioning_tab():
    """Build the "Image Captioning" Gradio tab backed by two BLIP checkpoints.

    Both models and their processors are loaded eagerly, when this function
    is called:

    * ``Salesforce/blip-image-captioning-base``
    * ``noamrot/FuseCap_Image_Captioning``

    The tab lets the user upload an image and optionally enter a text prompt
    that is passed to the processor together with the image. The user picks
    one of the two models and clicks a button to generate a caption. The
    caption and the seconds spent on preprocessing, generation and decoding
    are shown in the output column.

    NOTE(review): this creates a ``gr.TabItem``. Presumably the caller invokes
    it inside a ``gr.Blocks()`` / ``gr.Tabs()`` context; confirm with the
    caller.

    Returns:
        The ``gr.TabItem`` containing this tab's components.
    """
    salesforce_model_name = "Salesforce/blip-image-captioning-base"
    salesforce_model = BlipForConditionalGeneration.from_pretrained(salesforce_model_name)
    salesforce_processor = BlipProcessor.from_pretrained(salesforce_model_name)

    noamrot_model_name = "noamrot/FuseCap_Image_Captioning"
    noamrot_model = BlipForConditionalGeneration.from_pretrained(noamrot_model_name)
    noamrot_processor = BlipProcessor.from_pretrained(noamrot_model_name)

    # Dropdown value -> (model, processor). The dropdown choices are derived
    # from these keys so the two can never drift apart.
    model_map = {
        salesforce_model_name: (salesforce_model, salesforce_processor),
        noamrot_model_name: (noamrot_model, noamrot_processor),
    }

    def gradio_process(model_name, image, text):
        """Generate a caption for *image* with the model named *model_name*.

        Args:
            model_name: A key of ``model_map`` (the dropdown selection).
            image: The uploaded image as a PIL image (the input component uses
                ``type="pil"``). ``None`` is rejected with a user-facing error.
            text: Optional prompt passed to the processor with the image. An
                empty or whitespace-only value is treated as "no prompt".

        Returns:
            ``[caption, seconds_elapsed]``, in the same order as the click
            handler's ``outputs`` list.

        Raises:
            gr.Error: If no image was provided or *model_name* is not a known
                model. Gradio displays the message to the user.
        """
        if image is None:
            raise gr.Error("Please upload an image first.")
        try:
            model, processor = model_map[model_name]
        except KeyError:
            raise gr.Error(f"Unknown model selected: {model_name!r}") from None

        # Keep a real prompt verbatim; normalise blank input to "no prompt".
        prompt = text if text and text.strip() else None

        # perf_counter is monotonic and high-resolution, suited to durations.
        start = time.perf_counter()
        inputs = processor(images=image, text=prompt, return_tensors="pt")
        out = model.generate(**inputs)
        result = processor.decode(out[0], skip_special_tokens=True)
        time_spent = time.perf_counter() - start
        return [result, time_spent]

    with gr.TabItem("Image Captioning") as image_captioning_tab:
        gr.Markdown("# Image Captioning")
        with gr.Row():
            with gr.Column():
                # Input components
                input_image = gr.Image(label="Upload Image", type="pil")
                input_text = gr.Textbox(
                    label="Caption",
                    placeholder="Optional prompt to condition the caption",
                )
                # A default selection avoids looking up None when the user
                # clicks the button without choosing a model.
                model_selector = gr.Dropdown(
                    choices=list(model_map),
                    value=salesforce_model_name,
                    label="Select Model",
                )
                # Process button
                process_btn = gr.Button("Generate caption")
            with gr.Column():
                # Output components
                elapsed_result = gr.Textbox(label="Seconds elapsed", lines=1)
                output_text = gr.Textbox(label="Generated caption")

        # Connect the input components to the processing function; the input
        # order must match gradio_process's parameter order.
        process_btn.click(
            fn=gradio_process,
            inputs=[
                model_selector,
                input_image,
                input_text,
            ],
            outputs=[output_text, elapsed_result],
        )

    return image_captioning_tab