m96tkmok committed
Commit e464683 · verified · 1 Parent(s): cf143e6

Update app.py


Use Streamlit interface

Files changed (1)
  1. app.py +118 -92
app.py CHANGED
@@ -1,99 +1,125 @@
- from threading import Thread
- from huggingface_hub import login
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-
- import gradio as gr
  import os

- # Log In
- HF_TOKEN = os.environ.get("HF_TOKEN", None)
- login(token=HF_TOKEN)
-
- TITLE = "<h1><center>Chat with lianghsun/Llama-3.2-Taiwan-3B</center></h1>"
-
- DESCRIPTION = "<h3><center>Visit <a href='https://huggingface.co/lianghsun/Llama-3.2-Taiwan-3B' target='_blank'> the model page</a> for details.</center></h3>"
-
- DEFAULT_SYSTEM = "你是一個說中文的聊天機械人, 使用正體中文回答問題."
-
- CSS = """
- .duplicate-button {
-     margin: auto !important;
-     color: white !important;
-     background: green !important;
-     border-radius: 100vh !important;
- }
- """
-
-
- tokenizer = AutoTokenizer.from_pretrained("shenzhi-wang/Gemma-2-9B-Chinese-Chat")
- model = AutoModelForCausalLM.from_pretrained("shenzhi-wang/Gemma-2-9B-Chinese-Chat", torch_dtype="auto", device_map="auto")
-
- def stream_chat(message: str, history: list, system: str, temperature: float, max_new_tokens: int):
-     conversation = [{"role": "system", "content": system or DEFAULT_SYSTEM}]
-     for prompt, answer in history:
-         conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
-
-     conversation.append({"role": "user", "content": message})

-     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(
-         model.device
-     )
-     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-
-     generate_kwargs = dict(
-         input_ids=input_ids,
-         streamer=streamer,
-         max_new_tokens=max_new_tokens,
-         temperature=temperature,
-         do_sample=True,
-     )
-     if temperature == 0:
-         generate_kwargs["do_sample"] = False
-
-     t = Thread(target=model.generate, kwargs=generate_kwargs)
-     t.start()
-
-     output = ""
-     for new_token in streamer:
-         output += new_token
-         yield output
-
-
- chatbot = gr.Chatbot(height=450)

- with gr.Blocks(css=CSS) as demo:
-     gr.HTML(TITLE)
-     gr.HTML(DESCRIPTION)
-     gr.ChatInterface(
-         fn=stream_chat,
-         chatbot=chatbot,
-         fill_height=True,
-         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
-         additional_inputs=[
-             gr.Text(
-                 value="",
-                 label="System",
-                 render=False,
-             ),
-             gr.Slider(
-                 minimum=0,
-                 maximum=1,
-                 step=0.1,
-                 value=0.8,
-                 label="Temperature",
-                 render=False,
-             ),
-             gr.Slider(
-                 minimum=128,
-                 maximum=4096,
-                 step=1,
-                 value=1024,
-                 label="Max new tokens",
-                 render=False,
-             ),
-         ],
      )
-

  if __name__ == "__main__":
-     demo.launch()
 
+ import streamlit as st
  import os

+ from typing import Iterator
+ from huggingface_hub import InferenceClient

+ HF_TOKEN = os.environ.get("HF_TOKEN", None)

+ # Configure page settings
+ st.set_page_config(
+     page_title="LLM Taiwan Chat",
+     page_icon="💬",
+     layout="centered"
+ )
+
+ # Initialize session state for chat history and system prompt
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+ if "system_prompt" not in st.session_state:
+     st.session_state.system_prompt = "你是一個產自台灣的聊天機械人, 你以台灣本地人的身份, 使用正體中文回答問題."
+ if "temperature" not in st.session_state:
+     st.session_state.temperature = 0.2
+ if "top_p" not in st.session_state:
+     st.session_state.top_p = 0.95
+
+ ## model="lianghsun/Llama-3.2-Taiwan-3B" to meta-llama/Llama-3.2-3B-Instruct
+
+ def stream_chat(prompt: str) -> Iterator[str]:
+     """Stream chat responses from the LLM API"""
+
+     client = InferenceClient(model="meta-llama/Llama-3.2-3B-Instruct", timeout=30, token=HF_TOKEN)
+
+
+     messages = []
+     if st.session_state.system_prompt:
+         messages.append({"role": "system", "content": st.session_state.system_prompt})
+     messages.extend(st.session_state.messages)
+
+     stream = client.chat.completions.create(
+         messages=messages,
+         model="meta-llama/Llama-3.2-3B-Instruct",
+         stream=True,
+         temperature=st.session_state.temperature,
+         top_p=st.session_state.top_p
      )
+
+     for chunk in stream:
+         if chunk.choices[0].delta.content is not None:
+             yield chunk.choices[0].delta.content
+
+ def clear_chat_history():
+     """Clear all chat messages and reset system prompt"""
+     st.session_state.messages = []
+     st.session_state.system_prompt = ""
+
+ def main():
+     st.title("💬 LLM Taiwan Chat")
+
+     # Add a clear chat button with custom styling
+     col1, col2 = st.columns([6, 1])
+     with col2:
+         if st.button("🗑️", type="secondary", use_container_width=True):
+             clear_chat_history()
+             st.rerun()
+
+     # Advanced options in expander
+     with st.expander("進階選項 ⚙️", expanded=False):
+         # System prompt input
+         system_prompt = st.text_area(
+             "System Prompt 設定:",
+             value=st.session_state.system_prompt,
+             help="設定 system prompt 來定義 AI 助理的行為和角色。開始對話後將無法修改。",
+             height=100,
+             disabled=len(st.session_state.messages) > 0  # 當有對話時設為唯讀
+         )
+         if not st.session_state.messages and system_prompt != st.session_state.system_prompt:
+             st.session_state.system_prompt = system_prompt
+
+         st.session_state.temperature = st.slider(
+             "Temperature",
+             min_value=0.0,
+             max_value=2.0,
+             value=st.session_state.temperature,
+             step=0.1,
+             help="較高的值會使輸出更加隨機,較低的值會使其更加集中和確定。"
+         )
+         st.session_state.top_p = st.slider(
+             "Top P",
+             min_value=0.1,
+             max_value=1.0,
+             value=st.session_state.top_p,
+             step=0.05,
+             help="控制模型輸出的多樣性,較低的值會使輸出更加保守。"
+         )
+
+     # Display chat messages
+     for message in st.session_state.messages:
+         with st.chat_message(message["role"]):
+             st.write(message["content"])
+
+     # Chat input
+     if prompt := st.chat_input("輸入您的訊息..."):
+         # Add user message to chat history
+         st.session_state.messages.append({"role": "user", "content": prompt})
+
+         # Display user message
+         with st.chat_message("user"):
+             st.write(prompt)
+
+         # Display assistant response with streaming
+         with st.chat_message("assistant"):
+             response_placeholder = st.empty()
+             full_response = ""
+
+             # Stream the response
+             for response_chunk in stream_chat(prompt):
+                 full_response += response_chunk
+                 response_placeholder.markdown(full_response + "▌")
+             response_placeholder.markdown(full_response)
+
+             # Add assistant response to chat history
+             st.session_state.messages.append({"role": "assistant", "content": full_response})

  if __name__ == "__main__":
+     main()
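
For reference, the streaming call that the new stream_chat() depends on can be exercised outside Streamlit. The sketch below is a minimal standalone check, not part of the commit: the model name, token handling, and sampling defaults mirror the new app.py, while the user message is only a placeholder.

# Minimal standalone check of the InferenceClient streaming call used in the new app.py.
# Not part of the commit; the user message below is an arbitrary placeholder.
import os

from huggingface_hub import InferenceClient

client = InferenceClient(
    model="meta-llama/Llama-3.2-3B-Instruct",
    timeout=30,
    token=os.environ.get("HF_TOKEN", None),
)

messages = [
    {"role": "system", "content": "你是一個產自台灣的聊天機械人, 你以台灣本地人的身份, 使用正體中文回答問題."},
    {"role": "user", "content": "你好"},  # placeholder test message
]

stream = client.chat.completions.create(
    messages=messages,
    model="meta-llama/Llama-3.2-3B-Instruct",
    stream=True,
    temperature=0.2,
    top_p=0.95,
)

# Print tokens as they arrive, mirroring how stream_chat() yields them.
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta is not None:
        print(delta, end="", flush=True)
print()

Locally, the updated app can be started with `streamlit run app.py` once HF_TOKEN is set in the environment.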