import argparse
import json
import math
import os
import random
import time

import numpy as np
import tqdm
import yaml

from models.models import GroqModel


def processdata(instance, noise_rate, passage_num, filename, correct_rate=0):
    """Assemble the retrieved passages for one instance.

    Selects positive and noisy (negative) passages according to noise_rate,
    with special handling for the information-integration ('_int') and
    counterfactual ('_fact') datasets, then shuffles the result.
    """
    query = instance['query']
    ans = instance['answer']

    neg_num = math.ceil(passage_num * noise_rate)
    pos_num = passage_num - neg_num

    if '_int' in filename:
        # Information-integration data: each positive entry is a list of
        # passages; take one passage per entry first, then fill up.
        for i in instance['positive']:
            random.shuffle(i)
        print(len(instance['positive']))
        docs = [i[0] for i in instance['positive']]
        if len(docs) < pos_num:
            maxnum = max([len(i) for i in instance['positive']])
            for i in range(1, maxnum):
                for j in instance['positive']:
                    if len(j) > i:
                        docs.append(j[i])
                        if len(docs) == pos_num:
                            break
                if len(docs) == pos_num:
                    break
        neg_num = passage_num - len(docs)
        if neg_num > 0:
            negative = instance['negative'][:neg_num]
            docs += negative
    elif '_fact' in filename:
        # Counterfactual data: mix factually wrong positives with (optionally)
        # correct positives and noisy passages.
        correct_num = math.ceil(passage_num * correct_rate)
        pos_num = passage_num - neg_num - correct_num
        indexs = list(range(len(instance['positive'])))
        selected = random.sample(indexs, min(len(indexs), pos_num))
        docs = [instance['positive_wrong'][i] for i in selected]
        remain = [i for i in indexs if i not in selected]
        if correct_num > 0 and len(remain) > 0:
            docs += [instance['positive'][i]
                     for i in random.sample(remain, min(len(remain), correct_num))]
        if neg_num > 0:
            docs += instance['negative'][:neg_num]
    else:
        if noise_rate == 1:
            neg_num = passage_num
            pos_num = 0
        else:
            # Clamp the counts to the passages actually available.
            if neg_num > len(instance['negative']):
                neg_num = len(instance['negative'])
                pos_num = passage_num - neg_num
            elif pos_num > len(instance['positive']):
                pos_num = len(instance['positive'])
                neg_num = passage_num - pos_num

        positive = instance['positive'][:pos_num]
        negative = instance['negative'][:neg_num]
        docs = positive + negative

    random.shuffle(docs)

    return query, ans, docs


def checkanswer(prediction, ground_truth):
    """Return a 0/1 label per ground-truth item: 1 if it occurs in the prediction.

    A list-valued item counts as matched if any of its aliases occurs.
    """
    prediction = prediction.lower()
    if type(ground_truth) is not list:
        ground_truth = [ground_truth]

    labels = []
    for instance in ground_truth:
        flag = True
        if type(instance) == list:
            flag = False
            instance = [i.lower() for i in instance]
            for i in instance:
                if i in prediction:
                    flag = True
                    break
        else:
            instance = instance.lower()
            if instance not in prediction:
                flag = False
        labels.append(int(flag))
    return labels


def getevalue(results):
    """Aggregate label lists across runs (element-wise max) and report whether
    every position was answered correctly in at least one run."""
    results = np.array(results)
    results = np.max(results, axis=0)
    if 0 in results:
        return False
    else:
        return True


def predict(query, ground_truth, docs, model, system, instruction, temperature, dataset):
    '''
    label: 0 for positive, 1 for negative, -1 for not enough information
    '''
    if len(docs) == 0:
        text = instruction.format(QUERY=query, DOCS='')
        prediction = model.generate(text, temperature)
    else:
        docs = '\n'.join(docs)
        text = instruction.format(QUERY=query, DOCS=docs)
        prediction = model.generate(text, temperature, system)

    if 'zh' in dataset:
        prediction = prediction.replace(" ", "")

    # '信息不足' is the Chinese equivalent of 'insufficient information'.
    if '信息不足' in prediction or 'insufficient information' in prediction:
        labels = [-1]
    else:
        labels = checkanswer(prediction, ground_truth)

    # '事实性错误' is the Chinese equivalent of 'factual errors'.
    factlabel = 0
    if '事实性错误' in prediction or 'factual errors' in prediction:
        factlabel = 1

    return labels, prediction, factlabel


def main3(args):
    print(args)
    return 'I am from evalue'


def main2(params):
    print('I am main2')
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--modelname', type=str, default='chatgpt',
        help='model name'
    )
    parser.add_argument(
        '--dataset', type=str, default='en',
        help='evaluation dataset',
        choices=['en', 'zh', 'en_int', 'zh_int', 'en_fact', 'zh_fact']
    )
    parser.add_argument(
        '--api_key', type=str, default='api_key',
        help='api key of chatgpt'
    )
    parser.add_argument(
        '--plm', type=str, default='THUDM/chatglm-6b',
        help='name of plm'
    )
    parser.add_argument(
        '--url', type=str, default='https://api.openai.com/v1/completions',
        help='url of chatgpt'
    )
    parser.add_argument(
        '--temp', type=float, default=0.7,
        help='sampling temperature'
    )
    parser.add_argument(
        '--noise_rate', type=float, default=0.0,
        help='rate of noisy passages'
    )
    parser.add_argument(
        '--correct_rate', type=float, default=0.0,
        help='rate of correct passages'
    )
    parser.add_argument(
        '--passage_num', type=int, default=5,
        help='number of external passages'
    )
    # Note: with type=bool, argparse treats any non-empty value as True.
    parser.add_argument(
        '--factchecking', type=bool, default=False,
        help='whether to do fact checking'
    )
    parser.add_argument(
        '--max_instances', type=int, default=None,
        help='limit the number of examples to evaluate'
    )
    args = parser.parse_args(params)

    modelname = args.modelname
    temperature = args.temp
    noise_rate = args.noise_rate
    passage_num = args.passage_num

    # Load the evaluation instances (one JSON object per line).
    instances = []
    with open(f'data/{args.dataset}.json', 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            if args.max_instances and i >= args.max_instances:
                break
            instances.append(json.loads(line))

    if 'en' in args.dataset:
        resultpath = 'result-en'
    elif 'zh' in args.dataset:
        resultpath = 'result-zh'
    if not os.path.exists(resultpath):
        os.mkdir(resultpath)

    if args.factchecking:
        prompt = yaml.load(open('config/instruction_fact.yaml', 'r', encoding='utf-8'),
                           Loader=yaml.FullLoader)[args.dataset[:2]]
        resultpath = resultpath + '/fact'
    else:
        prompt = yaml.load(open('config/instruction.yaml', 'r', encoding='utf-8'),
                           Loader=yaml.FullLoader)[args.dataset[:2]]

    system = prompt['system']
    instruction = prompt['instruction']

    model = GroqModel()

    filename = f'{resultpath}/prediction_{args.dataset}_{modelname}_temp{temperature}_noise{noise_rate}_passage{passage_num}_correct{args.correct_rate}.json'

    # Load any previous predictions so finished instances could be reused.
    useddata = {}
    if os.path.exists(filename):
        with open(filename) as f:
            for line in f:
                data = json.loads(line)
                useddata[data['id']] = data

    results = []
    with open(filename, 'w') as f:
        for instance in tqdm.tqdm(instances):
            # if instance['id'] in useddata and instance['query'] == useddata[instance['id']]['query'] and instance['answer'] == useddata[instance['id']]['answer']:
            #     results.append(useddata[instance['id']])
            #     f.write(json.dumps(useddata[instance['id']], ensure_ascii=False) + '\n')
            #     continue
            try:
                random.seed(2333)
                if passage_num == 0:
                    query = instance['query']
                    ans = instance['answer']
                    docs = []
                else:
                    query, ans, docs = processdata(instance, noise_rate, passage_num,
                                                   args.dataset, args.correct_rate)
                label, prediction, factlabel = predict(query, ans, docs, model, system,
                                                       instruction, temperature, args.dataset)
                instance['label'] = label
                newinstance = {
                    'id': instance['id'],
                    'query': query,
                    'ans': ans,
                    'label': label,
                    'prediction': prediction,
                    'docs': docs,
                    'noise_rate': noise_rate,
                    'factlabel': factlabel
                }
                results.append(newinstance)
                f.write(json.dumps(newinstance, ensure_ascii=False) + '\n')
                time.sleep(random.uniform(2, 4))
            except Exception as e:
                print("Error:", e)
                continue

    # Overall accuracy: when every passage is noise the model should report
    # insufficient information; otherwise it should contain every gold answer.
    tt = 0
    for i in results:
        label = i['label']
        if noise_rate == 1 and label[0] == -1:
            tt += 1
        elif 0 not in label and 1 in label:
            tt += 1
    scores = {
        'all_rate': tt / len(results),
        'noise_rate': noise_rate,
        'tt': tt,
        'nums': len(results),
    }

    if '_fact' in args.dataset:
        fact_tt = 0
        correct_tt = 0
        for i in results:
            if i['factlabel'] == 1:
                fact_tt += 1
                if 0 not in i['label']:
                    correct_tt += 1
        fact_check_rate = fact_tt / len(results)
        if fact_tt > 0:
            correct_rate = correct_tt / fact_tt
        else:
            correct_rate = 0
        scores['fact_check_rate'] = fact_check_rate
        scores['correct_rate'] = correct_rate
        scores['fact_tt'] = fact_tt
        scores['correct_tt'] = correct_tt

    print(scores)
    json.dump(
        scores,
        open(f'{resultpath}/prediction_{args.dataset}_{modelname}_temp{temperature}_noise{noise_rate}_passage{passage_num}_correct{args.correct_rate}_result.json', 'w'),
        ensure_ascii=False,
        indent=4
    )


# if __name__ == '__main__':
#     # main()
#     print('test me here')
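

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original script): the shape of one
# JSONL record in data/{dataset}.json that main2()/processdata() expect.
# The field names are the ones accessed above ('id', 'query', 'answer',
# 'positive', 'negative'); '_fact' datasets additionally carry
# 'positive_wrong', and '_int' datasets store each 'positive' entry as a
# list of passages. The concrete values below are invented placeholders.
_EXAMPLE_RECORD = {
    "id": 0,
    "query": "Who won the 2022 Nobel Prize in Literature?",
    "answer": "Annie Ernaux",
    "positive": [
        "The 2022 Nobel Prize in Literature was awarded to Annie Ernaux ..."
    ],
    "negative": [
        "The 2021 Nobel Prize in Literature was awarded to Abdulrazak Gurnah ..."
    ],
}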
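

# Illustrative sketch (assumption, not taken from the repo's config files):
# the structure that yaml.load(...)[args.dataset[:2]] is expected to return
# for config/instruction.yaml. The top-level keys are the two-letter language
# codes ('en', 'zh'); each entry provides the 'system' and 'instruction'
# strings read above, and 'instruction' must contain the {QUERY} and {DOCS}
# placeholders consumed by str.format() in predict(). The wording below is a
# placeholder, not the benchmark's actual prompt.
_EXAMPLE_PROMPT_CONFIG = {
    "en": {
        "system": "You are an accurate and reliable assistant that answers "
                  "questions using the provided documents.",
        "instruction": "Answer the question based on the documents. If the "
                       "documents do not contain the answer, reply with "
                       "'insufficient information'.\n"
                       "Documents:\n{DOCS}\n"
                       "Question: {QUERY}",
    },
}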
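

# Illustrative usage sketch (assumption): main2() takes an argv-style list of
# strings, so it can be driven from another module or a notebook without
# touching sys.argv. The argument values are placeholders; the function below
# is defined here only as an example and is never called.
def _example_run():
    main2([
        '--dataset', 'en',
        '--modelname', 'chatgpt',
        '--temp', '0.2',
        '--noise_rate', '0.4',
        '--passage_num', '5',
        '--max_instances', '20',
    ])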