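"""Gradio chat Space for GLM-4-Z1-32B-0414: a multilingual analytical chatbot
served with 4-bit (NF4) quantization and token-by-token streaming output."""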
import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import os
from threading import Thread
MODEL_LIST = ["THUDM/GLM-4-Z1-32B-0414"]
HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = MODEL_LIST[0]
MODEL_NAME = "GLM-4-Z1-32B-0414"
TITLE = "<h1>3ML-bot (Text Only)</h1>"
DESCRIPTION = f"""
<center>
<p>😊 A Multi-Lingual Analytical Chatbot.
<br>
🚀 MODEL NOW: <a href="https://hf.co/{MODEL_ID}">{MODEL_NAME}</a></p>
</center>"""
CSS = """
h1 {
    text-align: center;
    display: block;
}
"""
# Configure BitsAndBytes 4-bit NF4 quantization: weights are stored in 4-bit
# NF4, matmuls run in bfloat16, and double quantization also compresses the
# quantization constants to save a little more memory.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
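# The tokenizer is loaded once at startup; the model itself is loaded inside
# the @spaces.GPU-decorated handler below, so it is created in the GPU
# context that ZeroGPU allocates for each request.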
@spaces.GPU()
def stream_chat(message, history: list, temperature: float, max_length: int, top_p: float, top_k: int, penalty: float):
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        quantization_config=quantization_config,
        device_map="auto",
    )
    print(f'message is - {message}')
    print(f'history is - {history}')

    # Rebuild the chat history in the role/content message format expected
    # by the chat template.
    conversation = []
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])
    conversation.append({"role": "user", "content": message})
    print(f"Conversation is -\n{conversation}")
    # apply_chat_template with return_dict=True returns a dict holding both
    # input_ids and attention_mask, moved here to the model's device.
    inputs = tokenizer.apply_chat_template(
        conversation,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        max_length=max_length,
        streamer=streamer,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=penalty,
        eos_token_id=[151329, 151336, 151338],  # GLM-4 stop token ids
    )
    gen_kwargs = {**inputs, **generate_kwargs}

    # Run generation on a background thread. Note: generate() disables
    # gradient tracking internally, and a torch.no_grad() opened here would
    # not propagate to the worker thread anyway (grad mode is thread-local).
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    # Stream partial output to the UI as tokens arrive.
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
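# Quick local smoke test (a sketch; assumes a local CUDA GPU — outside
# Spaces the @spaces.GPU decorator has no effect):
#   for partial in stream_chat("Hello!", [], 0.8, 4096, 1.0, 10, 1.0):
#       print(partial)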
chatbot = gr.Chatbot()
chat_input = gr.Textbox(
    interactive=True,
    placeholder="Enter your message here...",
    show_label=False,
)
# Multilingual example prompts, matching the Space's multi-lingual focus.
EXAMPLES = [
    ["Analyze the geopolitical implications of recent technological advancements in AI from a Chinese perspective."],
    ["¿Cuáles son los desafíos éticos más importantes en el desarrollo de la inteligencia artificial general?"],
    ["从经济学和社会学角度分析,人工智能将如何改变未来的就业市场?"],
    ["ما هي التحديات الرئيسية التي تواجه تطوير الذكاء الاصطناعي في العالم العربي؟"],
    ["नैतिक कृत्रिम बुद्धिमत्ता विकास में सबसे बड़ी चुनौतियाँ क्या हैं? विस्तार से समझाइए।"],
    ["Кои са основните предизвикателства пред разработването на изкуствен интелект в България и Източна Европа?"],
    ["Explain the potential risks and benefits of quantum computing in national security contexts."],
    ["分析气候变化对全球经济不平等的影响,并提出可能的解决方案。"],
]
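# Assemble the UI: title/description HTML plus a ChatInterface whose extra
# sliders are forwarded positionally to stream_chat after (message, history).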
with gr.Blocks(css=CSS, theme="soft", fill_height=True) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.ChatInterface(
        fn=stream_chat,
        textbox=chat_input,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=1024,
                maximum=8192,
                step=1,
                value=4096,
                label="Max Length",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=1.0,
                label="top_p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=10,
                label="top_k",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=2.0,
                step=0.1,
                value=1.0,
                label="Repetition penalty",
                render=False,
            ),
        ],
        examples=EXAMPLES,
    )
if __name__ == "__main__":
    demo.queue(api_open=False).launch(show_api=False, share=False)