Spaces:
Sleeping
Sleeping
app.py
Browse files# app.py
import streamlit as st
from evalue import main as run_eval
st.title("RGB Evaluation Tool")
dataset = st.text_input("Dataset name (e.g., en, en_int):", "en")
noise = st.slider("Noise Rate", 0.0, 1.0, 0.6)
model = st.text_input("Model name:", "moonshotai/kimi-k2-instruct")
if st.button("Run Evaluation"):
args = ["--dataset", dataset, "--noise_rate", str(noise), "--modelname", model]
st.write("Running evaluation...")
run_eval(args)
st.success("Evaluation completed!")
evalue.py
ADDED
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import numpy as np
|
3 |
+
import random, math
|
4 |
+
import argparse,torch
|
5 |
+
import os
|
6 |
+
import json, tqdm, requests
|
7 |
+
import yaml
|
8 |
+
from models.models import *
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
def processdata(instance, noise_rate, passage_num, filename, correct_rate = 0):
|
13 |
+
query = instance['query']
|
14 |
+
ans = instance['answer']
|
15 |
+
|
16 |
+
neg_num = math.ceil(passage_num * noise_rate)
|
17 |
+
pos_num = passage_num - neg_num
|
18 |
+
|
19 |
+
if '_int' in filename:
|
20 |
+
for i in instance['positive']:
|
21 |
+
random.shuffle(i)
|
22 |
+
print(len(instance['positive']))
|
23 |
+
docs = [i[0] for i in instance['positive']]
|
24 |
+
if len(docs) < pos_num:
|
25 |
+
maxnum = max([len(i) for i in instance['positive']])
|
26 |
+
for i in range(1,maxnum):
|
27 |
+
for j in instance['positive']:
|
28 |
+
if len(j) > i:
|
29 |
+
docs.append(j[i])
|
30 |
+
if len(docs) == pos_num:
|
31 |
+
break
|
32 |
+
if len(docs) == pos_num:
|
33 |
+
break
|
34 |
+
neg_num = passage_num - len(docs)
|
35 |
+
if neg_num > 0:
|
36 |
+
negative = instance['negative'][:neg_num]
|
37 |
+
docs += negative
|
38 |
+
elif '_fact' in filename:
|
39 |
+
correct_num = math.ceil(passage_num * correct_rate)
|
40 |
+
pos_num = passage_num - neg_num - correct_num
|
41 |
+
indexs = list(range(len(instance['positive'])))
|
42 |
+
selected = random.sample(indexs,min(len(indexs),pos_num))
|
43 |
+
docs = [instance['positive_wrong'][i] for i in selected]
|
44 |
+
remain = [i for i in indexs if i not in selected]
|
45 |
+
if correct_num > 0 and len(remain) > 0:
|
46 |
+
docs += [instance['positive'][i] for i in random.sample(remain,min(len(remain),correct_num))]
|
47 |
+
if neg_num > 0:
|
48 |
+
docs += instance['negative'][:neg_num]
|
49 |
+
else:
|
50 |
+
if noise_rate == 1:
|
51 |
+
neg_num = passage_num
|
52 |
+
pos_num = 0
|
53 |
+
else:
|
54 |
+
if neg_num > len(instance['negative']):
|
55 |
+
neg_num = len(instance['negative'])
|
56 |
+
pos_num = passage_num - neg_num
|
57 |
+
elif pos_num > len(instance['positive']):
|
58 |
+
pos_num = len(instance['positive'])
|
59 |
+
neg_num = passage_num - pos_num
|
60 |
+
|
61 |
+
|
62 |
+
positive = instance['positive'][:pos_num]
|
63 |
+
negative = instance['negative'][:neg_num]
|
64 |
+
|
65 |
+
docs = positive + negative
|
66 |
+
|
67 |
+
random.shuffle(docs)
|
68 |
+
|
69 |
+
return query, ans, docs
|
70 |
+
|
71 |
+
|
72 |
+
def checkanswer(prediction, ground_truth):
|
73 |
+
prediction = prediction.lower()
|
74 |
+
if type(ground_truth) is not list:
|
75 |
+
ground_truth = [ground_truth]
|
76 |
+
labels = []
|
77 |
+
for instance in ground_truth:
|
78 |
+
flag = True
|
79 |
+
if type(instance) == list:
|
80 |
+
flag = False
|
81 |
+
instance = [i.lower() for i in instance]
|
82 |
+
for i in instance:
|
83 |
+
if i in prediction:
|
84 |
+
flag = True
|
85 |
+
break
|
86 |
+
else:
|
87 |
+
instance = instance.lower()
|
88 |
+
if instance not in prediction:
|
89 |
+
flag = False
|
90 |
+
labels.append(int(flag))
|
91 |
+
return labels
|
92 |
+
|
93 |
+
def getevalue(results):
|
94 |
+
results = np.array(results)
|
95 |
+
results = np.max(results,axis = 0)
|
96 |
+
if 0 in results:
|
97 |
+
return False
|
98 |
+
else:
|
99 |
+
return True
|
100 |
+
|
101 |
+
|
102 |
+
def predict(query, ground_truth, docs, model, system, instruction, temperature, dataset):
|
103 |
+
|
104 |
+
'''
|
105 |
+
label: 0 for positive, 1 for negative, -1 for not enough information
|
106 |
+
|
107 |
+
'''
|
108 |
+
if len(docs) == 0:
|
109 |
+
text = instruction.format(QUERY=query, DOCS='')
|
110 |
+
prediction = model.generate(text, temperature)
|
111 |
+
|
112 |
+
else:
|
113 |
+
docs = '\n'.join(docs)
|
114 |
+
text = instruction.format(QUERY=query, DOCS=docs)
|
115 |
+
prediction = model.generate(text, temperature, system)
|
116 |
+
|
117 |
+
if 'zh' in dataset:
|
118 |
+
prediction = prediction.replace(" ","")
|
119 |
+
if '信息不足' in prediction or 'insufficient information' in prediction:
|
120 |
+
labels = [-1]
|
121 |
+
else:
|
122 |
+
labels = checkanswer(prediction, ground_truth)
|
123 |
+
factlabel = 0
|
124 |
+
|
125 |
+
if '事实性错误' in prediction or 'factual errors' in prediction:
|
126 |
+
factlabel = 1
|
127 |
+
|
128 |
+
return labels,prediction, factlabel
|
129 |
+
|
130 |
+
if __name__ == '__main__':
|
131 |
+
|
132 |
+
parser = argparse.ArgumentParser()
|
133 |
+
|
134 |
+
parser.add_argument(
|
135 |
+
'--modelname', type=str, default='chatgpt',
|
136 |
+
help='model name'
|
137 |
+
)
|
138 |
+
parser.add_argument(
|
139 |
+
'--dataset', type=str, default='en',
|
140 |
+
help='evaluetion dataset',
|
141 |
+
choices=['en','zh','en_int','zh_int','en_fact','zh_fact']
|
142 |
+
)
|
143 |
+
parser.add_argument(
|
144 |
+
'--api_key', type=str, default='api_key',
|
145 |
+
help='api key of chatgpt'
|
146 |
+
)
|
147 |
+
parser.add_argument(
|
148 |
+
'--plm', type=str, default='THUDM/chatglm-6b',
|
149 |
+
help='name of plm'
|
150 |
+
)
|
151 |
+
parser.add_argument(
|
152 |
+
'--url', type=str, default='https://api.openai.com/v1/completions',
|
153 |
+
help='url of chatgpt'
|
154 |
+
)
|
155 |
+
parser.add_argument(
|
156 |
+
'--temp', type=float, default=0.7,
|
157 |
+
help='corpus id'
|
158 |
+
)
|
159 |
+
parser.add_argument(
|
160 |
+
'--noise_rate', type=float, default=0.0,
|
161 |
+
help='rate of noisy passages'
|
162 |
+
)
|
163 |
+
parser.add_argument(
|
164 |
+
'--correct_rate', type=float, default=0.0,
|
165 |
+
help='rate of correct passages'
|
166 |
+
)
|
167 |
+
parser.add_argument(
|
168 |
+
'--passage_num', type=int, default=5,
|
169 |
+
help='number of external passages'
|
170 |
+
)
|
171 |
+
parser.add_argument(
|
172 |
+
'--factchecking', type=bool, default=False,
|
173 |
+
help='whether to fact checking'
|
174 |
+
)
|
175 |
+
parser.add_argument(
|
176 |
+
'--max_instances', type=int, default=None,
|
177 |
+
help='Limit the number of examples to evaluate'
|
178 |
+
)
|
179 |
+
help='whether to fact checking'
|
180 |
+
|
181 |
+
|
182 |
+
|
183 |
+
args = parser.parse_args()
|
184 |
+
|
185 |
+
modelname = args.modelname
|
186 |
+
temperature = args.temp
|
187 |
+
noise_rate = args.noise_rate
|
188 |
+
passage_num = args.passage_num
|
189 |
+
|
190 |
+
instances = []
|
191 |
+
with open(f'data/{args.dataset}.json','r') as f:
|
192 |
+
for i, line in enumerate(f):
|
193 |
+
if args.max_instances and i >= args.max_instances:
|
194 |
+
break
|
195 |
+
instances.append(json.loads(line))
|
196 |
+
if 'en' in args.dataset:
|
197 |
+
resultpath = 'result-en'
|
198 |
+
elif 'zh' in args.dataset:
|
199 |
+
resultpath = 'result-zh'
|
200 |
+
if not os.path.exists(resultpath):
|
201 |
+
os.mkdir(resultpath)
|
202 |
+
|
203 |
+
if args.factchecking:
|
204 |
+
prompt = yaml.load(open('config/instruction_fact.yaml', 'r'), Loader=yaml.FullLoader)[args.dataset[:2]]
|
205 |
+
resultpath = resultpath + '/fact'
|
206 |
+
else:
|
207 |
+
prompt = yaml.load(open('config/instruction.yaml', 'r'), Loader=yaml.FullLoader)[args.dataset[:2]]
|
208 |
+
|
209 |
+
system = prompt['system']
|
210 |
+
instruction = prompt['instruction']
|
211 |
+
|
212 |
+
if modelname == 'chatgpt':
|
213 |
+
model = OpenAIAPIModel(api_key = args.api_key, url = args.url)
|
214 |
+
elif 'Llama-2' in modelname:
|
215 |
+
model = LLama2(plm = args.plm)
|
216 |
+
elif modelname == "groq":
|
217 |
+
model = GroqModel(api_key=args.api_key)
|
218 |
+
elif 'chatglm' in modelname:
|
219 |
+
model = ChatglmModel(plm = args.plm)
|
220 |
+
elif 'moss' in modelname:
|
221 |
+
model = Moss(plm = args.plm)
|
222 |
+
elif 'vicuna' in modelname:
|
223 |
+
model = Vicuna(plm = args.plm)
|
224 |
+
elif 'Qwen' in modelname:
|
225 |
+
model = Qwen(plm = args.plm)
|
226 |
+
elif 'Baichuan' in modelname:
|
227 |
+
model = Baichuan(plm = args.plm)
|
228 |
+
elif 'WizardLM' in modelname:
|
229 |
+
model = WizardLM(plm = args.plm)
|
230 |
+
elif 'BELLE' in modelname:
|
231 |
+
model = BELLE(plm = args.plm)
|
232 |
+
|
233 |
+
|
234 |
+
filename = f'{resultpath}/prediction_{args.dataset}_{modelname}_temp{temperature}_noise{noise_rate}_passage{passage_num}_correct{args.correct_rate}.json'
|
235 |
+
useddata = {}
|
236 |
+
if os.path.exists(filename):
|
237 |
+
with open(filename) as f:
|
238 |
+
data = json.loads(line)
|
239 |
+
useddata[data['id']] = data
|
240 |
+
|
241 |
+
results = []
|
242 |
+
with open(filename,'w') as f:
|
243 |
+
for instance in tqdm.tqdm(instances):
|
244 |
+
#if instance['id'] in useddata and instance['query'] == useddata[instance['id']]['query'] and instance['answer'] == useddata[instance['id']]['answer']:
|
245 |
+
#results.append(useddata[instance['id']])
|
246 |
+
#f.write(json.dumps(useddata[instance['id']], ensure_ascii=False)+'\n')
|
247 |
+
#continue
|
248 |
+
try:
|
249 |
+
random.seed(2333)
|
250 |
+
if passage_num == 0:
|
251 |
+
query = instance['query']
|
252 |
+
ans = instance['answer']
|
253 |
+
docs = []
|
254 |
+
else:
|
255 |
+
query, ans, docs = processdata(instance, noise_rate, passage_num, args.dataset, args.correct_rate)
|
256 |
+
label,prediction,factlabel = predict(query, ans, docs, model,system,instruction,temperature,args.dataset)
|
257 |
+
|
258 |
+
instance['label'] = label
|
259 |
+
newinstance = {
|
260 |
+
'id': instance['id'],
|
261 |
+
'query': query,
|
262 |
+
'ans': ans,
|
263 |
+
'label': [-1],
|
264 |
+
'label1': label,
|
265 |
+
'prediction': prediction,
|
266 |
+
'docs': docs,
|
267 |
+
'noise_rate': noise_rate,
|
268 |
+
'factlabel': factlabel
|
269 |
+
}
|
270 |
+
results.append(newinstance)
|
271 |
+
f.write(json.dumps(newinstance, ensure_ascii=False)+'\n')
|
272 |
+
except Exception as e:
|
273 |
+
print("Error:", e)
|
274 |
+
continue
|
275 |
+
tt = 0
|
276 |
+
for i in results:
|
277 |
+
label = i['label']
|
278 |
+
if noise_rate == 1 and label[0] == -1:
|
279 |
+
tt += 1
|
280 |
+
elif 0 not in label and 1 in label:
|
281 |
+
tt += 1
|
282 |
+
scores = {
|
283 |
+
'all_rate': (tt)/len(results),
|
284 |
+
'noise_rate': noise_rate,
|
285 |
+
'tt':tt,
|
286 |
+
'nums': len(results),
|
287 |
+
}
|
288 |
+
if '_fact' in args.dataset:
|
289 |
+
fact_tt = 0
|
290 |
+
correct_tt = 0
|
291 |
+
for i in results:
|
292 |
+
if i['factlabel'] == 1:
|
293 |
+
fact_tt += 1
|
294 |
+
if 0 not in i['label']:
|
295 |
+
correct_tt += 1
|
296 |
+
fact_check_rate = fact_tt/len(results)
|
297 |
+
if fact_tt > 0:
|
298 |
+
correct_rate = correct_tt/fact_tt
|
299 |
+
else:
|
300 |
+
correct_rate = 0
|
301 |
+
scores['fact_check_rate'] = fact_check_rate
|
302 |
+
scores['correct_rate'] = correct_rate
|
303 |
+
scores['fact_tt'] = fact_tt
|
304 |
+
scores['correct_tt'] = correct_tt
|
305 |
+
|
306 |
+
|
307 |
+
print(scores)
|
308 |
+
json.dump(scores,open(f'{resultpath}/prediction_{args.dataset}_{modelname}_temp{temperature}_noise{noise_rate}_passage{passage_num}_correct{args.correct_rate}_result.json','w'),ensure_ascii=False,indent=4)
|