import time
from threading import Thread
from llava_llama3.serve.cli import chat_llava
from llava_llama3.model.builder import load_pretrained_model
import gradio as gr
import torch
from PIL import Image

import spaces

# Model configuration
model_id = "TheFinAI/FinLLaVA"
device = "cuda:0"
load_8bit = False
load_4bit = False

# Load the pretrained model
tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
    model_id, 
    None, 
    'llava_llama3', 
    load_8bit, 
    load_4bit, 
    device=device
)


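# Hugging Face ZeroGPU: the @spaces.GPU decorator requests a GPU allocation
# for the duration of each call to the decorated function.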
@spaces.GPU
def bot_streaming(message, history):
    print(message)
    image = None
    
    # Check if there's an image in the current message
    if message["files"]:
        # message["files"][-1] could be a dictionary or a string
        if isinstance(message["files"][-1], dict):
            image = message["files"][-1]["path"]
        else:
            image = message["files"][-1]
    else:
        # If no image in the current message, look in the history for the last image
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]
    
    # Error handling if no image is found
    if image is None:
        raise gr.Error("You need to upload an image for LLaVA to work.")
    
    # Load the image
    image = Image.open(image)
    
    # Generate the prompt for the model
    prompt = message['text']
    
    # Call the chat_llava function to generate the output
    output = chat_llava(
        args=None,
        image_file=image,
        text=prompt,
        tokenizer=tokenizer,
        model=llava_model,
        image_processor=image_processor,
        context_len=context_len
    )
    
    # Stream the output to the UI. If chat_llava returns a complete string,
    # iterating over it yields one chunk at a time, which Gradio renders as
    # incremental (streaming) output.
    buffer = ""
    for new_text in output:
        buffer += new_text
        yield buffer


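# --- Gradio UI ---
# Multimodal textbox (text + image upload) wired to the streaming handler above.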
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="LLaVA Llama-3-8B",
        examples=[
            {"text": "What is on the flower?", "files": ["./bee.jpg"]},
            {"text": "How to make this pastry?", "files": ["./baklava.png"]},
        ],
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

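# Enable the request queue so generator (streaming) outputs are delivered
# incrementally; api_open=False keeps the queue's API endpoints closed.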
demo.queue(api_open=False)
demo.launch(show_api=False, share=False)