|
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer |
|
import torch |
|
import gradio as gr |
|
from threading import Thread |
|
|
|
# Hugging Face Hub ids of the two checkpoints shown side by side in the demo.
base_model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
new_model_id = "MohamedRashad/Arabic-Orpo-Llama-3-8B-Instruct"
|
|
|
|
|
# One tokenizer, loaded from the base checkpoint, is used for both models.
# NOTE(review): this assumes the fine-tuned checkpoint keeps the base
# vocabulary and chat template - confirm.
tokenizer = AutoTokenizer.from_pretrained(base_model_id)

# Both checkpoints are loaded in bfloat16, placed by `device_map="auto"`, and
# switched to eval mode because the script only runs generation.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()
new_model = AutoModelForCausalLM.from_pretrained(
    new_model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()
|
# Token ids handed to `generate(eos_token_id=...)`: generation stops on either
# the tokenizer's EOS token or the "<|eot_id|>" marker.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
|
|
|
|
|
def _build_messages(system_prompt, history, input_text):
    """Return the chat-template message list for one model's conversation.

    The list is: the system prompt, every past (user, assistant) pair from
    `history` in order, then the new user turn.
    """
    messages = [{"role": "system", "content": system_prompt}]
    for user, assistant in history:
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": input_text})
    return messages


def _stream_reply(model, messages):
    """Run `model.generate` on a worker thread and yield decoded text chunks."""
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device).long()

    # skip_special_tokens keeps terminator tokens (EOS / end-of-turn) out of
    # the text shown in the chat window.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generation_kwargs = dict(
        input_ids=input_ids,
        # A single unpadded prompt: every position is a real token. Passing the
        # mask explicitly is needed because pad_token_id equals the EOS id, so
        # generate() cannot infer it.
        attention_mask=torch.ones_like(input_ids),
        streamer=streamer,
        max_new_tokens=2048,
        eos_token_id=terminators,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.2,
        top_p=0.9,
    )

    # generate() blocks until done, so it runs on a worker thread while this
    # generator drains the streamer.
    worker = Thread(target=model.generate, kwargs=generation_kwargs)
    worker.start()
    for chunk in streamer:
        # Guard kept from the original code in case the loaded tokenizer does
        # not flag "<|eot_id|>" as special: drop the marker and what follows.
        text, _, _ = chunk.partition("<|eot_id|>")
        yield text
    # The streamer is exhausted once generation ends; reap the worker.
    worker.join()


def generate_both(system_prompt, input_text, base_chatbot, new_chatbot):
    """Stream replies from the base model and then the new model.

    This is a generator used as a Gradio event handler. The two models run
    one after the other (base first), each conditioned on its own chat history.

    Args:
        system_prompt: Text sent as the system message to both models.
        input_text: The new user message.
        base_chatbot: History of the base-model chat as [user, assistant] pairs.
        new_chatbot: History of the new-model chat as [user, assistant] pairs.

    Yields:
        `(base_chatbot, new_chatbot)` after the user turn is appended and again
        after every streamed chunk of text.
    """
    # Treat a missing history as empty.
    # NOTE(review): assumes Gradio may pass None for an empty chatbot - confirm.
    base_chatbot = base_chatbot if base_chatbot is not None else []
    new_chatbot = new_chatbot if new_chatbot is not None else []

    # Build both prompts before the new (still empty) turn is appended.
    base_messages = _build_messages(system_prompt, base_chatbot, input_text)
    new_messages = _build_messages(system_prompt, new_chatbot, input_text)

    base_chatbot.append([input_text, ""])
    new_chatbot.append([input_text, ""])
    # Show the user's message right away, before any token is generated.
    yield base_chatbot, new_chatbot

    runs = (
        (base_chatbot, base_model, base_messages),
        (new_chatbot, new_model, new_messages),
    )
    for chatbot, model, messages in runs:
        for chunk in _stream_reply(model, messages):
            chatbot[-1][-1] += chunk
            yield base_chatbot, new_chatbot
|
|
|
def clear():
    """Return two fresh empty histories, one per chatbot, to reset the UI."""
    return [], []
|
|
|
# Gradio UI: a system-prompt box, the two chatbots side by side, and a row
# holding the buttons and the input box.
# NOTE(review): SOURCE carried no indentation; the nesting below is
# reconstructed from the `scale=` hints (button column scale=1 next to the
# input box scale=3) - confirm against the intended layout.
with gr.Blocks(title="Arabic-ORPO-Llama3") as demo:
    with gr.Column():
        gr.HTML("<center><h1>Arabic Chatbot Comparison</h1></center>")
        system_prompt = gr.Textbox(lines=1, label="System Prompt", value="You are a pirate chatbot who always responds in pirate speak!")
        with gr.Row(variant="panel"):
            base_chatbot = gr.Chatbot(label=base_model_id, rtl=False, likeable=True, show_copy_button=True)
            new_chatbot = gr.Chatbot(label=new_model_id, rtl=True, likeable=True, show_copy_button=True)
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                submit_btn = gr.Button(value="Generate", variant="primary")
                clear_btn = gr.Button(value="Clear", variant="secondary")
            input_text = gr.Textbox(lines=1, label="", value="مرحبا", rtl=True, text_align="right", scale=3, show_copy_button=True)

        # Pressing Enter in the input box and clicking "Generate" run the same
        # handler; "Clear" resets both chat windows.
        input_text.submit(generate_both, inputs=[system_prompt, input_text, base_chatbot, new_chatbot], outputs=[base_chatbot, new_chatbot])
        submit_btn.click(generate_both, inputs=[system_prompt, input_text, base_chatbot, new_chatbot], outputs=[base_chatbot, new_chatbot])
        clear_btn.click(clear, outputs=[base_chatbot, new_chatbot])

demo.launch()
|
|