import json
import os
import re
import unicodedata
from functools import lru_cache

import nltk
import regex  # third-party `regex` module, needed for the \p{L}/\p{N} classes in the CodeGen pattern
import wget

# ENCODER_URL and VOCAB_URL are assumed to be defined in `constants` alongside the other settings.
from constants import (GPT2_FOLDER, ENCODER_FILE, VOCAB_FILE, ENCODER_URL,
                       VOCAB_URL, END_OF_TEXT_TOKEN)

@lru_cache()
def bytes_to_unicode():
    """Map every byte value (0-255) to a printable unicode character, as in GPT-2's byte-level BPE."""
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))

def get_pairs(word):
    """Return the set of adjacent symbol pairs in `word` (a tuple of symbols)."""
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs

class Encoder:
    """Byte-level BPE encoder/decoder for GPT-2 style vocabularies."""

    def __init__(self, encoder, bpe_merges, errors='replace', tokenize=None,
                 byte_encoder=None, byte_decoder=None):
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors
        self.byte_encoder = byte_encoder if byte_encoder is not None else bytes_to_unicode()
        self.byte_decoder = byte_decoder if byte_decoder is not None else {v: k for k, v in self.byte_encoder.items()}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}
        if tokenize is None:
            # Default GPT-2 style pre-tokenizer built on the stdlib `re` module.
            self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\w+| ?[^\s\w]+|\s+(?!\S)|\s+""", re.UNICODE)
            self.tokenize = lambda text: re.findall(self.pat, text)
        else:
            self.tokenize = tokenize

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)
        if not pairs:
            return token
        while True:
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break
                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        normalized_text = unicodedata.normalize('NFKC', text)
        normalized_text = ''.join(c for c in normalized_text if c.isascii() and c != '\t')
        normalized_text = ''.join(c for c in normalized_text if not unicodedata.category(c).startswith('C'))
        for token in self.tokenize(normalized_text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8', errors='ignore'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors='replace')
        decoded_text = text.replace(" .", ".").replace(" ,", ",").replace(" '", "'").replace(" ?", "?").replace(" !", "!").replace(" :", ":").replace('\n', '<br>')
        sentences = nltk.sent_tokenize(decoded_text)
        return ' '.join(sentences).replace("<br>", "<br>\n")

def get_encoder_gpt2():
    encoder_path = os.path.join(GPT2_FOLDER, ENCODER_FILE)
    vocab_path = os.path.join(GPT2_FOLDER, VOCAB_FILE)
    if not os.path.exists(GPT2_FOLDER):
        os.makedirs(GPT2_FOLDER)
    if not os.path.exists(encoder_path):
        wget.download(ENCODER_URL, out=encoder_path)
    if not os.path.exists(vocab_path):
        wget.download(VOCAB_URL, out=vocab_path)

    with open(encoder_path, 'r', encoding="utf-8") as f:
        encoder = json.load(f)
    with open(vocab_path, 'r', encoding="utf-8") as f:
        bpe_data = f.read()
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
    encoder_obj = Encoder(encoder=encoder, bpe_merges=bpe_merges)
    encoder_obj.encoder[END_OF_TEXT_TOKEN] = len(encoder_obj.encoder)
    encoder_obj.decoder[len(encoder_obj.decoder)] = END_OF_TEXT_TOKEN
    return encoder_obj

def get_codegen_tokenizer_pure(vocab_file, merges_file):
    with open(vocab_file, 'r', encoding="utf-8") as f:
        vocab = json.load(f)
    with open(merges_file, 'r', encoding="utf-8") as f:
        merges = f.read().split('\n')[1:-1]
    bpe_merges = [tuple(m.split()) for m in merges]
    byte_encoder = bytes_to_unicode()
    byte_decoder = {v: k for k, v in byte_encoder.items()}
    # \p{L} / \p{N} are not supported by the stdlib `re`, so compile this pattern with `regex`.
    tokenizer_regex = regex.compile(r'''<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+''')
    tokenize = lambda text: tokenizer_regex.findall(text)
    encoder_obj = Encoder(
        encoder=vocab,
        bpe_merges=bpe_merges,
        byte_encoder=byte_encoder,
        byte_decoder=byte_decoder,
        tokenize=tokenize
    )
    return encoder_obj

def codegen_tokenize(text, tokenizer):
    return tokenizer.encode(text)

def codegen_decode(tokens, tokenizer):
    return tokenizer.decode(tokens)
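
if __name__ == "__main__":
    # Minimal usage sketch (an illustration, not part of the library surface):
    # assumes the encoder/vocab files can be downloaded via the URLs in `constants`
    # and that NLTK's 'punkt' sentence tokenizer data is installed for decode().
    enc = get_encoder_gpt2()
    token_ids = enc.encode("Hello world. This is a quick BPE round-trip check.")
    print(token_ids)
    print(enc.decode(token_ids))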