import json
import re
import unicodedata
from functools import lru_cache
import wget
import os
from constants import (
    GPT2_FOLDER, ENCODER_FILE, VOCAB_FILE, END_OF_TEXT_TOKEN,
    ENCODER_URL, VOCAB_URL,  # download URLs; assumed to be defined in constants alongside the file names
)
import regex  # third-party `regex` package; needed for the \p{...} classes used in the CodeGen pattern below
import nltk
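# bytes_to_unicode builds the reversible byte <-> printable-unicode table used by GPT-2's
# byte-level BPE: every byte value 0-255 is mapped to a visible character, so arbitrary
# UTF-8 input can be represented without an "unknown" token.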
@lru_cache()
def bytes_to_unicode():
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
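# get_pairs returns the set of adjacent symbol pairs in `word`, e.g. ("h", "e", "ll", "o")
# -> {("h", "e"), ("e", "ll"), ("ll", "o")}; these are the candidate merges for one BPE step.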
def get_pairs(word):
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
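# Encoder implements GPT-2-style byte-level BPE: text is pre-tokenized with a regex, each
# piece is mapped into the unicode byte alphabet above, greedily merged according to
# bpe_ranks, and finally looked up in the vocabulary to produce integer token ids.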
class Encoder:
    def __init__(self, encoder, bpe_merges, errors='replace', tokenize=None,
                 byte_encoder=None, byte_decoder=None):
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # error handling passed to bytes.decode() when detokenizing
        self.byte_encoder = byte_encoder if byte_encoder is not None else bytes_to_unicode()
        self.byte_decoder = byte_decoder if byte_decoder is not None else {v: k for k, v in self.byte_encoder.items()}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))  # merge pair -> priority (lower merges first)
        self.cache = {}
        if tokenize is None:
            # Default GPT-2-style pre-tokenization pattern, approximated with \w so it
            # compiles with the stdlib `re` module (the original uses \p{L}/\p{N} via `regex`).
            self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\w+| ?[^\s\w]+|\s+(?!\S)|\s+""", re.UNICODE)
            self.tokenize = lambda text: self.pat.findall(text)
else:
self.tokenize = tokenize
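    # bpe applies the learned merges to a single pre-token: starting from individual
    # characters, repeatedly merge the adjacent pair with the best (lowest) rank until no
    # rankable pair remains; the result is cached and returned as space-joined symbols.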
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except ValueError:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
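    # encode: normalize the text (NFKC, drop non-ASCII and control characters), split it
    # with the pre-tokenization regex, map each piece into the byte alphabet, run BPE, and
    # look up the resulting symbols in the vocabulary.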
def encode(self, text):
bpe_tokens = []
normalized_text = unicodedata.normalize('NFKC', text)
normalized_text = ''.join(c for c in normalized_text if c.isascii() and c != '\t')
normalized_text = ''.join(c for c in normalized_text if not unicodedata.category(c).startswith('C'))
for token in self.tokenize(normalized_text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8', errors='ignore'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
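    # decode: map token ids back to BPE symbols, invert the byte mapping, then lightly
    # detokenize punctuation and re-join sentences with nltk (newlines pass through as <br>).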
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors='replace')
        decoded_text = (text.replace(" .", ".").replace(" ,", ",").replace(" '", "'")
                        .replace(" ?", "?").replace(" !", "!").replace(" :", ":")
                        .replace('\n', '<br>'))
sentences = nltk.sent_tokenize(decoded_text)
return ' '.join(sentences).replace("<br>", "<br>\n")
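# get_encoder_gpt2 downloads GPT-2's encoder.json and vocab.bpe into GPT2_FOLDER on first
# use, builds an Encoder from them, and registers END_OF_TEXT_TOKEN as an extra id at the
# end of the vocabulary.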
def get_encoder_gpt2():
encoder_path = os.path.join(GPT2_FOLDER, ENCODER_FILE)
vocab_path = os.path.join(GPT2_FOLDER, VOCAB_FILE)
if not os.path.exists(GPT2_FOLDER):
os.makedirs(GPT2_FOLDER)
if not os.path.exists(encoder_path):
wget.download(ENCODER_URL, out=encoder_path)
if not os.path.exists(vocab_path):
wget.download(VOCAB_URL, out=vocab_path)
    with open(encoder_path, 'r', encoding="utf-8") as f:
encoder = json.load(f)
with open(vocab_path, 'r', encoding="utf-8") as f:
bpe_data = f.read()
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
encoder_obj = Encoder(encoder=encoder, bpe_merges=bpe_merges)
encoder_obj.encoder[END_OF_TEXT_TOKEN] = len(encoder_obj.encoder)
encoder_obj.decoder[len(encoder_obj.decoder)] = END_OF_TEXT_TOKEN
return encoder_obj
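# get_codegen_tokenizer_pure builds the same Encoder from local CodeGen vocab/merges files,
# using a \p{L}/\p{N}-based pattern that also matches the <|endoftext|> special token.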
def get_codegen_tokenizer_pure(vocab_file, merges_file):
    with open(vocab_file, 'r', encoding="utf-8") as f:
        vocab = json.load(f)
    with open(merges_file, 'r', encoding="utf-8") as f:
        merges = f.read().split('\n')[1:-1]
bpe_merges = [tuple(m.split()) for m in merges]
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}
    # \p{L} / \p{N} are not supported by the stdlib `re` module; compile this pattern with `regex` instead.
    tokenizer_regex = regex.compile(r'''<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+''')
    tokenize = lambda text: tokenizer_regex.findall(text)
encoder_obj = Encoder(
encoder=vocab,
bpe_merges=bpe_merges,
byte_encoder=byte_encoder,
byte_decoder=byte_decoder,
tokenize=tokenize
)
return encoder_obj
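# Thin convenience wrappers around Encoder.encode / Encoder.decode.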
def codegen_tokenize(text, tokenizer):
return tokenizer.encode(text)
def codegen_decode(tokens, tokenizer):
return tokenizer.decode(tokens)
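# Minimal usage sketch (assumes constants.py also defines ENCODER_URL / VOCAB_URL, and that
# the NLTK "punkt" sentence tokenizer data is available; newer NLTK releases may also need
# "punkt_tab"):
if __name__ == "__main__":
    nltk.download('punkt', quiet=True)
    enc = get_encoder_gpt2()
    ids = enc.encode("Hello world! This is a byte-level BPE round trip.")
    print(ids)
    print(enc.decode(ids))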