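# Gradio chat demo for PY007/LiteChat-Preview: formats the chat history with an
# Alpaca-style instruction template and streams the model's replies to the UI
# with TextIteratorStreamer.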
import os
import gc
from string import Template
from threading import Thread
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BatchEncoding, TextIteratorStreamer
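# Load the tokenizer and the model weights in fp16; device_map="auto" places the
# weights on whatever device(s) are available.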
tokenizer = AutoTokenizer.from_pretrained(
    "PY007/LiteChat-Preview",
)
model = AutoModelForCausalLM.from_pretrained(
    "PY007/LiteChat-Preview",
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16
)
model.eval()
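# Generation budget: cap the number of new tokens and record the model's context
# window so the chat history can later be truncated to fit.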
max_context_length = model.config.max_position_embeddings
max_new_tokens = 1024
prompt_template = Template("""\
### Instruction: $human
### Response: $bot\
""")
system_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
system_prompt_tokens = tokenizer([f"{system_prompt}\n\n"], return_tensors="pt")
max_sys_tokens = system_prompt_tokens['input_ids'].size(-1)
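# bot(): format the chat history into the prompt template, truncate it to fit the
# context window, and stream the model's reply back into the last history entry.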
def bot(history):
    history = history or []
    # Inject prompt formatting into the history
    prompt_history = []
    for human, bot in history:
        if bot is not None:
            bot = bot.replace("<br>", "\n")
            bot = bot.rstrip()
        prompt_history.append(
            prompt_template.substitute(
                human=human, bot=bot if bot is not None else "")
        )
    msg_tokens = tokenizer(
        "\n\n".join(prompt_history).strip(),
        return_tensors="pt",
        add_special_tokens=False  # Use <BOS> from the system prompt
    )
    # Take only the most recent context up to the max context length and prepend the
    # system prompt with the messages (the negative offset keeps only the latest tokens)
    max_tokens = -max_context_length + max_new_tokens + max_sys_tokens
    inputs = BatchEncoding({
        k: torch.concat([system_prompt_tokens[k], msg_tokens[k][:, max_tokens:]], dim=-1)
        for k in msg_tokens
    }).to('cuda')
    # Same construction without the CUDA move, for CPU-only hardware:
    # inputs = BatchEncoding({
    #     k: torch.concat([system_prompt_tokens[k], msg_tokens[k][:, max_tokens:]], dim=-1)
    #     for k in msg_tokens
    # })
    # Remove `token_type_ids` b/c it's not yet supported for LLaMA `transformers` models
    if inputs.get("token_type_ids", None) is not None:
        inputs.pop("token_type_ids")
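    # Run generation on a background thread; the streamer yields decoded text
    # chunks as soon as they are produced so the UI can update incrementally.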
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        temperature=0.7,
    )
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
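    # Accumulate streamed text into the last history entry; stop early if the
    # model starts emitting a new "###" turn marker.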
partial_text = ""
for new_text in streamer:
# Process out the prompt separator
new_text = new_text.replace("<br>", "\n")
if "###" in new_text:
new_text = new_text.split("###")[0]
partial_text += new_text.strip()
history[-1][1] = partial_text
break
else:
# Filter empty trailing new lines
if new_text == "\n":
new_text = new_text.strip()
partial_text += new_text
history[-1][1] = partial_text
yield history
return partial_text
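# user(): append the new message to the chat history and clear the textbox.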
def user(user_message, history):
    return "", history + [[user_message, None]]
with gr.Blocks() as demo:
    gr.Markdown("# LiteChat by StatNLP")
    gr.Markdown("The model is currently running on free-tier CPU, which has limited speed.")
    gr.Markdown("Paper and code will be released soon.")
    chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)
    state = gr.State([])
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Send a message",
                placeholder="Send a message",
                show_label=False
            ).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Send")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear History")
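    # Wire events: pressing Enter or clicking Send records the user turn, then
    # streams the bot reply; Stop cancels in-flight generation; Clear resets the pane.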
    submit_event = msg.submit(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
        fn=bot, inputs=[chatbot], outputs=[chatbot], queue=True)
    submit_click_event = submit.click(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
        fn=bot, inputs=[chatbot], outputs=[chatbot], queue=True)
    stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_event, submit_click_event], queue=False)
    clear.click(lambda: None, None, [chatbot], queue=True)
demo.queue(max_size=32)
demo.launch()