Update app.py
app.py CHANGED
@@ -1,41 +1,59 @@
-from datasets import load_dataset
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 
-#
-
-
-
+# Initialize the Chinese emotion-detection model (make sure this model is suitable for Chinese emotion classification)
+emotion_classifier = pipeline(
+    "text-classification",
+    model="uer/roberta-base-finetuned-emotion-chinese",
+    return_all_scores=False
+)
+
+# Emotions from the child that call for comforting, and the reply to use for each
+SAFE_EMOTIONS = {
+    "生气": "当你生气的时候,可以试着深呼吸哦,放松一下自己。",
+    "郁闷": "听起来你有点郁闷,不如试试做些你喜欢的事情,让心情变好一点。",
+    "难受": "觉得难受的时候,记得告诉家人或老师,他们一定会帮助你的。",
+    "难过": "难过的时候可以试试和朋友聊聊,或者做点轻松的活动,让自己慢慢好起来。"
+}
 
-#
+# Initialize the Chinese dialogue-generation model (a Chinese GPT-2 model is used here)
 tokenizer = AutoTokenizer.from_pretrained("uer/gpt2-chinese-cluecorpussmall")
 model = AutoModelForCausalLM.from_pretrained("uer/gpt2-chinese-cluecorpussmall")
 
 def generate_response(user_input, history):
-    #
+    # If the input is empty, prompt the child to say something
     if not user_input:
-        return "
+        return "请跟我说说你的心情哦!"
+
+    # Detect the child's emotion
+    emotion_result = emotion_classifier(user_input)[0]
+    detected_emotion = emotion_result["label"]
+    print("检测到情绪:", detected_emotion, ",置信度:", emotion_result["score"])
+
+    # If the detected emotion is a negative one, return the comforting reply directly
+    if detected_emotion in SAFE_EMOTIONS:
+        return SAFE_EMOTIONS[detected_emotion]
+
+    # Otherwise, generate a reply with the dialogue model
     try:
-        # Encode the user input and append the end-of-sequence token
         inputs = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
         reply_ids = model.generate(
-            inputs,
-            max_length=
+            inputs,
+            max_length=100,
             pad_token_id=tokenizer.eos_token_id,
-            no_repeat_ngram_size=
+            no_repeat_ngram_size=2
         )
-
-        return
+        response_text = tokenizer.decode(reply_ids[:, inputs.shape[-1]:][0], skip_special_tokens=True)
+        return response_text
     except Exception as e:
-        print("
-        return "
-
+        print("生成回复时发生异常:", e)
+        return "抱歉,我现在有点小问题,能再跟我说说吗?"
 
+# Create the Gradio chat interface
 demo = gr.ChatInterface(
     fn=generate_response,
-
-
+    title="儿童情绪安抚助手",
+    description="跟我聊聊天,如果你觉得生气、郁闷或难受,我会帮你放松心情哦!"
 )
 
 demo.launch()
-
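
The comforting branch only fires when the classifier's output label is an exact key of SAFE_EMOTIONS; with return_all_scores=False the text-classification pipeline returns a one-element list of {"label", "score"} dicts, which is why the code indexes [0]. Below is a minimal sanity-check sketch: the model id is taken from the diff and is assumed to exist on the Hub and to emit Chinese labels such as "生气"; if its label set differs, every input silently falls through to the GPT-2 branch.

from transformers import pipeline

# Rebuild the classifier exactly as app.py does; the model id is the one the
# script assumes, and its labels must match the SAFE_EMOTIONS keys below.
emotion_classifier = pipeline(
    "text-classification",
    model="uer/roberta-base-finetuned-emotion-chinese",
)
SAFE_EMOTION_KEYS = {"生气", "郁闷", "难受", "难过"}  # keys of app.py's reply table

result = emotion_classifier("我今天特别生气")[0]   # e.g. {"label": "生气", "score": 0.97}
print(result["label"], round(result["score"], 3))
print("covered by SAFE_EMOTIONS:", result["label"] in SAFE_EMOTION_KEYS)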
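
gr.ChatInterface passes the latest user message plus the accumulated chat history to fn, i.e. it calls generate_response(message, history); the function ignores history, so it can be smoke-tested directly. A quick local check, as a sketch assuming the definitions from app.py above are in scope (for example, run in a REPL before the demo.launch() line):

history = []  # no earlier (user, bot) turns on the first message
print(generate_response("我有点难过", history))      # comfort reply, if the classifier labels this "难过"
print(generate_response("给我讲个故事吧", history))  # unmatched label falls through to the GPT-2 branch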