giridhar99 committed
Commit 3229aff · verified · 1 Parent(s): 2d6b18f

# app.py
import subprocess
import sys

import streamlit as st

st.title("RGB Evaluation Tool")

dataset = st.text_input("Dataset name (e.g., en, en_int):", "en")
noise = st.slider("Noise Rate", 0.0, 1.0, 0.6)
model = st.text_input("Model name:", "moonshotai/kimi-k2-instruct")

if st.button("Run Evaluation"):
    # evalue.py is a plain command-line script (it defines no main()),
    # so run it as a subprocess with the chosen flags.
    args = [sys.executable, "evalue.py", "--dataset", dataset,
            "--noise_rate", str(noise), "--modelname", model]
    st.write("Running evaluation...")
    subprocess.run(args, check=True)
    st.success("Evaluation completed!")

Files changed (1)
  1. evalue.py +308 -0
evalue.py ADDED
@@ -0,0 +1,308 @@
+ import argparse
+ import json
+ import math
+ import os
+ import random
+
+ import numpy as np
+ import requests
+ import torch
+ import tqdm
+ import yaml
+
+ from models.models import *
+
+
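+ # Assemble the retrieval context for one instance: ceil(passage_num *
+ # noise_rate) noisy passages mixed with supporting positives; the '_int'
+ # and '_fact' dataset variants get special handling below.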
+ def processdata(instance, noise_rate, passage_num, filename, correct_rate=0):
+     query = instance['query']
+     ans = instance['answer']
+
+     neg_num = math.ceil(passage_num * noise_rate)
+     pos_num = passage_num - neg_num
+
+     if '_int' in filename:
+         for i in instance['positive']:
+             random.shuffle(i)
+         print(len(instance['positive']))
+         docs = [i[0] for i in instance['positive']]
+         if len(docs) < pos_num:
+             maxnum = max([len(i) for i in instance['positive']])
+             for i in range(1, maxnum):
+                 for j in instance['positive']:
+                     if len(j) > i:
+                         docs.append(j[i])
+                         if len(docs) == pos_num:
+                             break
+                 if len(docs) == pos_num:
+                     break
+         neg_num = passage_num - len(docs)
+         if neg_num > 0:
+             negative = instance['negative'][:neg_num]
+             docs += negative
+     elif '_fact' in filename:
+         correct_num = math.ceil(passage_num * correct_rate)
+         pos_num = passage_num - neg_num - correct_num
+         indexs = list(range(len(instance['positive'])))
+         selected = random.sample(indexs, min(len(indexs), pos_num))
+         docs = [instance['positive_wrong'][i] for i in selected]
+         remain = [i for i in indexs if i not in selected]
+         if correct_num > 0 and len(remain) > 0:
+             docs += [instance['positive'][i] for i in random.sample(remain, min(len(remain), correct_num))]
+         if neg_num > 0:
+             docs += instance['negative'][:neg_num]
+     else:
+         if noise_rate == 1:
+             neg_num = passage_num
+             pos_num = 0
+         else:
+             if neg_num > len(instance['negative']):
+                 neg_num = len(instance['negative'])
+                 pos_num = passage_num - neg_num
+             elif pos_num > len(instance['positive']):
+                 pos_num = len(instance['positive'])
+                 neg_num = passage_num - pos_num
+
+         positive = instance['positive'][:pos_num]
+         negative = instance['negative'][:neg_num]
+
+         docs = positive + negative
+
+     random.shuffle(docs)
+
+     return query, ans, docs
+
+
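+ # One 0/1 label per ground-truth item: 1 if the lowercased answer string
+ # occurs in the prediction; a nested list means any alias counts as a hit.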
+ def checkanswer(prediction, ground_truth):
+     prediction = prediction.lower()
+     if type(ground_truth) is not list:
+         ground_truth = [ground_truth]
+     labels = []
+     for instance in ground_truth:
+         flag = True
+         if type(instance) == list:
+             flag = False
+             instance = [i.lower() for i in instance]
+             for i in instance:
+                 if i in prediction:
+                     flag = True
+                     break
+         else:
+             instance = instance.lower()
+             if instance not in prediction:
+                 flag = False
+         labels.append(int(flag))
+     return labels
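+
+ # e.g. checkanswer("The capital is Paris.", ["Paris"]) -> [1]
+ #      checkanswer("Not sure.", [["London", "Paris"]]) -> [0]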
+
+ # Collapse labels across repeated runs: elementwise max, then True only
+ # if every answer slot was matched at least once.
+ def getevalue(results):
+     results = np.array(results)
+     results = np.max(results, axis=0)
+     if 0 in results:
+         return False
+     else:
+         return True
+
+
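+ # Query the model with the instruction template and grade its reply; an
+ # explicit refusal ('insufficient information' / '信息不足') maps to [-1].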
+ def predict(query, ground_truth, docs, model, system, instruction, temperature, dataset):
+     '''
+     label: 0 for positive, 1 for negative, -1 for not enough information
+     '''
+     if len(docs) == 0:
+         text = instruction.format(QUERY=query, DOCS='')
+         prediction = model.generate(text, temperature)
+     else:
+         docs = '\n'.join(docs)
+         text = instruction.format(QUERY=query, DOCS=docs)
+         prediction = model.generate(text, temperature, system)
+
+     if 'zh' in dataset:
+         prediction = prediction.replace(" ", "")
+
+     if '信息不足' in prediction or 'insufficient information' in prediction:
+         labels = [-1]
+     else:
+         labels = checkanswer(prediction, ground_truth)
+
+     factlabel = 0
+     if '事实性错误' in prediction or 'factual errors' in prediction:
+         factlabel = 1
+
+     return labels, prediction, factlabel
+
+ if __name__ == '__main__':
+
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument(
+         '--modelname', type=str, default='chatgpt',
+         help='model name'
+     )
+     parser.add_argument(
+         '--dataset', type=str, default='en',
+         help='evaluation dataset',
+         choices=['en', 'zh', 'en_int', 'zh_int', 'en_fact', 'zh_fact']
+     )
+     parser.add_argument(
+         '--api_key', type=str, default='api_key',
+         help='API key for the hosted model'
+     )
+     parser.add_argument(
+         '--plm', type=str, default='THUDM/chatglm-6b',
+         help='name of the local PLM checkpoint'
+     )
+     parser.add_argument(
+         '--url', type=str, default='https://api.openai.com/v1/completions',
+         help='endpoint URL for the ChatGPT API'
+     )
+     parser.add_argument(
+         '--temp', type=float, default=0.7,
+         help='sampling temperature'
+     )
+     parser.add_argument(
+         '--noise_rate', type=float, default=0.0,
+         help='rate of noisy passages'
+     )
+     parser.add_argument(
+         '--correct_rate', type=float, default=0.0,
+         help='rate of correct passages'
+     )
+     parser.add_argument(
+         '--passage_num', type=int, default=5,
+         help='number of external passages'
+     )
+     parser.add_argument(
+         '--factchecking', action='store_true',
+         help='enable the fact-checking prompt and scoring'
+     )
+     parser.add_argument(
+         '--max_instances', type=int, default=None,
+         help='limit the number of examples to evaluate'
+     )
+
+     args = parser.parse_args()
+
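+     # Example run (hypothetical API key; expects data/<dataset>.json and
+     # the config/ prompt files to be present):
+     #   python evalue.py --dataset en --modelname chatgpt --api_key sk-... \
+     #       --noise_rate 0.6 --passage_num 5
+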
+     modelname = args.modelname
+     temperature = args.temp
+     noise_rate = args.noise_rate
+     passage_num = args.passage_num
+
+     # Datasets are JSON Lines: one instance per line.
+     instances = []
+     with open(f'data/{args.dataset}.json', 'r') as f:
+         for i, line in enumerate(f):
+             if args.max_instances and i >= args.max_instances:
+                 break
+             instances.append(json.loads(line))
+
+     if 'en' in args.dataset:
+         resultpath = 'result-en'
+     elif 'zh' in args.dataset:
+         resultpath = 'result-zh'
+     if not os.path.exists(resultpath):
+         os.mkdir(resultpath)
+
+     if args.factchecking:
+         prompt = yaml.load(open('config/instruction_fact.yaml', 'r'), Loader=yaml.FullLoader)[args.dataset[:2]]
+         resultpath = resultpath + '/fact'
+         if not os.path.exists(resultpath):
+             os.mkdir(resultpath)
+     else:
+         prompt = yaml.load(open('config/instruction.yaml', 'r'), Loader=yaml.FullLoader)[args.dataset[:2]]
+
+     system = prompt['system']
+     instruction = prompt['instruction']
+
+     if modelname == 'chatgpt':
+         model = OpenAIAPIModel(api_key=args.api_key, url=args.url)
+     elif 'Llama-2' in modelname:
+         model = LLama2(plm=args.plm)
+     elif modelname == "groq":
+         model = GroqModel(api_key=args.api_key)
+     elif 'chatglm' in modelname:
+         model = ChatglmModel(plm=args.plm)
+     elif 'moss' in modelname:
+         model = Moss(plm=args.plm)
+     elif 'vicuna' in modelname:
+         model = Vicuna(plm=args.plm)
+     elif 'Qwen' in modelname:
+         model = Qwen(plm=args.plm)
+     elif 'Baichuan' in modelname:
+         model = Baichuan(plm=args.plm)
+     elif 'WizardLM' in modelname:
+         model = WizardLM(plm=args.plm)
+     elif 'BELLE' in modelname:
+         model = BELLE(plm=args.plm)
+     else:
+         # Fail fast instead of hitting a NameError at predict() time.
+         raise ValueError(f'unknown model name: {modelname}')
+
+     filename = f'{resultpath}/prediction_{args.dataset}_{modelname}_temp{temperature}_noise{noise_rate}_passage{passage_num}_correct{args.correct_rate}.json'
+     # Load any previous predictions so the resume logic below can reuse them.
+     useddata = {}
+     if os.path.exists(filename):
+         with open(filename) as f:
+             for line in f:
+                 data = json.loads(line)
+                 useddata[data['id']] = data
+
+     results = []
+     with open(filename, 'w') as f:
+         for instance in tqdm.tqdm(instances):
+             #if instance['id'] in useddata and instance['query'] == useddata[instance['id']]['query'] and instance['answer'] == useddata[instance['id']]['answer']:
+             #    results.append(useddata[instance['id']])
+             #    f.write(json.dumps(useddata[instance['id']], ensure_ascii=False)+'\n')
+             #    continue
+             try:
+                 random.seed(2333)
+                 if passage_num == 0:
+                     query = instance['query']
+                     ans = instance['answer']
+                     docs = []
+                 else:
+                     query, ans, docs = processdata(instance, noise_rate, passage_num, args.dataset, args.correct_rate)
+                 label, prediction, factlabel = predict(query, ans, docs, model, system, instruction, temperature, args.dataset)
+
+                 instance['label'] = label
+                 newinstance = {
+                     'id': instance['id'],
+                     'query': query,
+                     'ans': ans,
+                     'label': label,
+                     'prediction': prediction,
+                     'docs': docs,
+                     'noise_rate': noise_rate,
+                     'factlabel': factlabel
+                 }
+                 results.append(newinstance)
+                 f.write(json.dumps(newinstance, ensure_ascii=False) + '\n')
+             except Exception as e:
+                 print("Error:", e)
+                 continue
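+
+     # Accuracy: with noise_rate == 1 (all passages noisy) the desired
+     # behaviour is a refusal, so label [-1] counts as correct; otherwise an
+     # instance counts only when every ground-truth item was matched (no 0s).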
+     tt = 0
+     for i in results:
+         label = i['label']
+         if noise_rate == 1 and label[0] == -1:
+             tt += 1
+         elif 0 not in label and 1 in label:
+             tt += 1
+     scores = {
+         'all_rate': tt / len(results),
+         'noise_rate': noise_rate,
+         'tt': tt,
+         'nums': len(results),
+     }
+     if '_fact' in args.dataset:
+         fact_tt = 0
+         correct_tt = 0
+         for i in results:
+             if i['factlabel'] == 1:
+                 fact_tt += 1
+                 if 0 not in i['label']:
+                     correct_tt += 1
+         fact_check_rate = fact_tt / len(results)
+         if fact_tt > 0:
+             correct_rate = correct_tt / fact_tt
+         else:
+             correct_rate = 0
+         scores['fact_check_rate'] = fact_check_rate
+         scores['correct_rate'] = correct_rate
+         scores['fact_tt'] = fact_tt
+         scores['correct_tt'] = correct_tt
+
+     print(scores)
+     json.dump(scores, open(f'{resultpath}/prediction_{args.dataset}_{modelname}_temp{temperature}_noise{noise_rate}_passage{passage_num}_correct{args.correct_rate}_result.json', 'w'), ensure_ascii=False, indent=4)
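+     # e.g. the *_result.json summary might contain:
+     #   {"all_rate": 0.85, "noise_rate": 0.6, "tt": 255, "nums": 300}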