import os
import gradio as gr
from html import escape
from transformers import AutoTokenizer
def get_available_models():
    """Return the sorted names of model subdirectories under ``models/``.

    A subdirectory counts as a model if it contains a ``config.json`` file.
    Returns an empty list when the ``models`` directory does not exist.

    NOTE(review): the original comment said the scan looked for
    ``tokenizer.json``, but the code checks ``config.json`` — confirm which
    file is actually required for ``AutoTokenizer.from_pretrained``.
    """
    models_dir = "models"
    if not os.path.exists(models_dir):
        return []
    available_models = []
    for model_name in os.listdir(models_dir):
        model_path = os.path.join(models_dir, model_name)
        # Renamed from the misleading ``tokenizer_file`` — it points at config.json.
        config_file = os.path.join(model_path, "config.json")
        if os.path.isdir(model_path) and os.path.isfile(config_file):
            available_models.append(model_name)
    return sorted(available_models)
def tokenize_text(model_name, text):
    """Tokenize *text* with the chosen local model and render colored HTML.

    Args:
        model_name: Name of a subdirectory under ``models/``; falsy values
            short-circuit with a prompt message.
        text: Input text to tokenize; falls back to a prompt string if empty.

    Returns:
        A 3-tuple ``(html, token_count, char_count)`` matching the three
        Gradio outputs wired to the submit button.
    """
    if not model_name:
        return "Please choose a model and input some texts", 0, 0
    if not text:
        text = "Please choose a model and input some texts"
    try:
        # Load the tokenizer from the local models/<name> directory.
        # Bug fix: dropped the invalid ``device_map="cpu"`` kwarg — tokenizers
        # are CPU-only objects and from_pretrained does not accept it.
        model_path = os.path.join("models", model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        input_ids = tokenizer.encode(text, add_special_tokens=True)
        # Alternate background colors so adjacent tokens are distinguishable.
        colors = ["#A8D8EA", "#AA96DA", "#FCBAD3"]
        html_parts = []
        for i, token_id in enumerate(input_ids):
            # Escape HTML special characters in the decoded token text.
            safe_token = escape(tokenizer.decode(token_id))
            color = colors[i % len(colors)]
            # NOTE(review): the original markup was garbled (tags stripped);
            # reconstructed as a colored chip showing token text over its id.
            html_part = (
                f'<span style="background-color: {color}; display: inline-block; '
                f'border-radius: 4px; padding: 2px 4px;">'
                f"{safe_token}<br>"
                f"<sub>{token_id}</sub>"
                f"</span>"
            )
            html_parts.append(html_part)
        token_len = len(input_ids)
        char_len = len(text)
        return "".join(html_parts), token_len, char_len
    except Exception as e:
        # Bug fix: the original returned only 2 values here, but the Gradio
        # click handler expects 3 outputs (html, token_count, char_count).
        error_msg = f"Error: {str(e)}"
        return error_msg, 0, 0
# --- UI components (built detached, rendered later inside the Blocks layout) ---

# App banner shown at the top of the page.
banner_md = """# 🎨 Tokenize it!
Powerful token visualization tool for your text inputs. 🚀
Works for LLMs both online and *locally* on your machine!"""
banner = gr.Markdown(banner_md)

# Input side: model picker, free-form text box, and the action button.
model_selector = gr.Dropdown(
    label="Choose Model",
    choices=get_available_models(),
    interactive=True,
)
text_input = gr.Textbox(
    label="Input Text",
    placeholder="Hello World!",
    lines=4,
)
submit_btn = gr.Button("🚀 Tokenize!", variant="primary")

# Output side: rendered token chips plus read-only statistics.
output_html = gr.HTML(label="Tokenized Output", elem_classes="token-output")
token_count = gr.Number(label="Token Count", value=0, interactive=False)
char_count = gr.Number(label="Character Count", value=0, interactive=False)
# Assemble the page layout; components were constructed above and are
# rendered here in their final positions. (Indentation of this ``with``
# block was lost in the original and has been restored.)
with gr.Blocks(title="Token Visualizer", theme="NoCrypt/miku") as webui:
    banner.render()
    with gr.Column():
        model_selector.render()
        text_input.render()
        submit_btn.render()
    with gr.Column():
        with gr.Row():
            token_count.render()
            char_count.render()
        output_html.render()

# NOTE(review): assigning ``webui.css`` after construction is not the
# documented API — ``gr.Blocks(css=...)`` is; kept as-is to preserve the
# original wiring. Confirm the styles actually apply in the running app.
webui.css = """
.token-output span {
    margin: 3px;
    vertical-align: top;
}
.stats-output {
    font-weight: bold !important;
    color: #2c3e50 !important;
}
.gradio-container { /* 针对 Gradio 的主容器 */
    width: 100%; /* 根据需要调整宽度 */
    display: flex;
    justify-content: center;
    align-items: center;
}
.gradio-container > div { /* 直接子元素,通常包含你的内容 */
    width: 90%; /* 或者你想要的固定宽度 */
    max-width: 1200px; /* 设置最大宽度 */
}
"""

# Wire the button: tokenize_text returns (html, token_count, char_count),
# matching the three outputs in order.
submit_btn.click(
    fn=tokenize_text,
    inputs=[model_selector, text_input],
    outputs=[output_html, token_count, char_count],
)
if __name__ == "__main__":
    # Ensure the models directory exists so the dropdown scan has a target
    # even on a fresh checkout.
    os.makedirs("models", exist_ok=True)
    # pwa=True lets supporting browsers install the app as a PWA.
    webui.launch(pwa=True)