import json
from tqdm import tqdm
import os
from sklearn.utils import shuffle
import re
import argparse


def cut_sent(para):
    para = re.sub('([。,,!?\?])([^”’])', r"\1\n\2", para)  # 单字符断句符
    para = re.sub('(\.{6})([^”’])', r"\1\n\2", para)  # 英文省略号
    para = re.sub('(\…{2})([^”’])', r"\1\n\2", para)  # 中文省略号
    para = re.sub('([。!?\?][”’])([^,。!?\?])', r'\1\n\2', para)
    # 如果双引号前有终止符,那么双引号才是句子的终点,把分句符\n放到双引号后,注意前面的几句都小心保留了双引号
    para = para.rstrip()  # 段尾如果有多余的\n就去掉它
    # 很多规则中会考虑分号;,但是这里
    return para.split("\n")


def search(pattern, sequence):
    n = len(pattern)
    res=[]
    for i in range(len(sequence)):
        if sequence[i:i + n] == pattern:
            res.append([i,i + n-1])
    return res

max_length=512
stride=128
def stride_split(question, context, answer, start):
    end = start + len(answer) -1
    results, n = [], 0
    max_c_len = max_length - len(question) - 3
    while True:
        left, right = n * stride, n * stride + max_c_len
        if left <= start < end <= right:
            results.append((question, context[left:right], answer, start - left, end - left))
        elif right < start or end < right:
            results.append((question, context[left:right], '', -1, -1))
        if right >= len(context):
            return results
        n += 1


def load_data(file_path,is_training=False):
    task_type='抽取任务'
    subtask_type='抽取式阅读理解'
    with open(file_path, 'r', encoding='utf8') as f:
        lines = json.loads(''.join(f.readlines()))
        result=[]
        lines = lines['data']
        for line in tqdm(lines): 
            if line['paragraphs']==[]:
                continue
            data = line['paragraphs'][0]
            context=data['context'].strip()
            for qa in data['qas']:
                question=qa['question'].strip()
                rcv=[]
                for a in qa['answers']:
                    if a not in rcv:
                        rcv.append(a)
                        split=stride_split(question, context, a['text'], a['answer_start'])
                        for sp in split:
                            choices = []
                            
                            choice = {}
                            choice['id']=qa['id']
                            choice['entity_type'] = qa['question']
                            choice['label']=0
                            entity_list=[]
                            if sp[3]>=0 and sp[4]>=0:
                                entity_list.append({'entity_name':sp[2],'entity_type':'','entity_idx':[[sp[3],sp[4]]]})
                                
                            choice['entity_list']=entity_list
                            choices.append(choice)
                                
                            if choices==[]:
                                print(data)
                                continue
                            result.append({ 'task_type':task_type,
                                            'subtask_type':subtask_type,
                                            'text':sp[1],
                                            'choices':choices,
                                            'id':0}) 
                            
        return result


def save_data(data,file_path):
    with open(file_path, 'w', encoding='utf8') as f:
        for line in data:
            json_data=json.dumps(line,ensure_ascii=False)
            f.write(json_data+'\n')



if __name__=="__main__":
    parser = argparse.ArgumentParser(description="train")
    parser.add_argument("--data_path", type=str,default="")
    parser.add_argument("--save_path", type=str,default="")

    args = parser.parse_args()
    
    
    data_path = args.data_path
    save_path = args.save_path
    
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    
    file_list=['dev','train','trial','test']
    train_data = []
    dev_data = []
    for file in file_list:
        file_path = os.path.join(data_path,file+'.json')
        data=load_data(file_path=file_path)
        if 'train' in file or 'trial' in file:
            train_data.extend(data)
        else:
            output_path = os.path.join(save_path,file+'.json')
            save_data(data,output_path)
    
    output_path = os.path.join(save_path,'train.json')
    save_data(train_data,output_path)