from string import Template
from threading import Thread

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BatchEncoding, TextIteratorStreamer


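# Load the LiteChat-Preview tokenizer and model; `device_map="auto"` places the
# weights automatically on whatever hardware is available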
tokenizer = AutoTokenizer.from_pretrained(
    "PY007/LiteChat-Preview",
)
model = AutoModelForCausalLM.from_pretrained(
    "PY007/LiteChat-Preview",
    trust_remote_code=True,
    device_map="auto",
)
model.eval()


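# Generation budget: the system prompt, the truncated history, and the newly
# generated tokens must all fit within the model's context window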
max_context_length = model.config.max_position_embeddings
max_new_tokens = 1024


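# Per-turn template: each exchange is rendered as an "### Instruction" / "### Response" pair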
prompt_template = Template("""\
### Instruction: $human
### Response: $bot\
""")


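# Tokenize the system prompt once up front so its length can be reserved when
# truncating the conversation history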
system_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
system_prompt_tokens = tokenizer([f"{system_prompt}\n\n"], return_tensors="pt")
max_sys_tokens = system_prompt_tokens['input_ids'].size(-1)


def bot(history):
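    """Stream the model's reply for the latest user message.

    `history` is a list of [user_message, bot_message] pairs; the final pair has
    `None` as its bot_message, which this generator fills in incrementally.
    """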
    history = history or []

    # Inject prompt formatting into the history
    prompt_history = []
    for human, assistant in history:
        if assistant is not None:
            assistant = assistant.replace("<br>", "\n")
            assistant = assistant.rstrip()
        prompt_history.append(
            prompt_template.substitute(
                human=human, bot=assistant if assistant is not None else "")
        )

    msg_tokens = tokenizer(
        "\n\n".join(prompt_history).strip(),
        return_tensors="pt",
        add_special_tokens=False  # Use <BOS> from the system prompt
    )

    # Keep only the most recent messages that fit in the context window, reserving room
    # for the system prompt and the tokens to be generated; the negative value is used
    # as a slice offset to keep the tail of the conversation
    max_tokens = -max_context_length + max_new_tokens + max_sys_tokens
    # Prepend the system prompt and move the batch to the model's device
    inputs = BatchEncoding({
        k: torch.concat([system_prompt_tokens[k], msg_tokens[k][:, max_tokens:]], dim=-1)
        for k in msg_tokens
    }).to(model.device)
    # Drop `token_type_ids` because LLaMA-style models in `transformers` do not accept it
    if inputs.get("token_type_ids", None) is not None:
        inputs.pop("token_type_ids")

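    # Run `generate` in a background thread and stream tokens back through the
    # streamer so this generator can yield partial text to the UI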
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        temperature=0.7,
    )
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    partial_text = ""
    for new_text in streamer:
        # Process out the prompt separator
        new_text = new_text.replace("<br>", "\n")
        if "###" in new_text:
            new_text = new_text.split("###")[0]
            partial_text += new_text.strip()
            history[-1][1] = partial_text
            break
        else:
            # Filter empty trailing new lines
            if new_text == "\n":
                new_text = new_text.strip()
            partial_text += new_text
            history[-1][1] = partial_text
        yield history
    return partial_text


def user(user_message, history):
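    """Append the new user message to the history and clear the textbox."""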
    return "", history + [[user_message, None]]


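# Gradio UI: a chat history panel, a message textbox, and Send / Stop / Clear controls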
with gr.Blocks() as demo:
    gr.Markdown("# LiteChat by StatNLP")
    gr.Markdown("The model is currently running on free-tier CPU, which has limited speed.")
    gr.Markdown("Paper and code will be released soon.")

    chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)
    state = gr.State([])
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Send a message",
                placeholder="Send a message",
                show_label=False
            ).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Send")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear History")

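    # Submitting the textbox or clicking Send first records the user message,
    # then streams the bot response; Stop cancels any in-flight generation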
    submit_event = msg.submit(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
        fn=bot, inputs=[chatbot], outputs=[chatbot], queue=True)
    submit_click_event = submit.click(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
        fn=bot, inputs=[chatbot], outputs=[chatbot], queue=True)

    stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_event, submit_click_event], queue=False)
    clear.click(lambda: None, None, [chatbot], queue=True)

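# Queue incoming requests so concurrent users are handled in order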
demo.queue(max_size=32)
demo.launch()