import os
from html import escape

import gradio as gr
from transformers import AutoTokenizer


def get_available_models() -> list[str]:
    """Return the sorted names of all local models under ``models/``.

    A directory counts as a model only if it contains a ``config.json``.
    Returns an empty list when the ``models`` directory does not exist.
    """
    models_dir = "models"
    if not os.path.exists(models_dir):
        return []
    available_models = []
    for model_name in os.listdir(models_dir):
        model_path = os.path.join(models_dir, model_name)
        config_file = os.path.join(model_path, "config.json")
        if os.path.isdir(model_path) and os.path.isfile(config_file):
            available_models.append(model_name)
    return sorted(available_models)


def tokenize_text(
    model_name: str, text: str
) -> tuple[str | None, str | None, int | None, dict | None, int, int]:
    """Tokenize *text* with the tokenizer of *model_name* and build an HTML view.

    Args:
        model_name: Either a directory name under ``models/`` or a Hub model id.
        text: The text to tokenize; a placeholder message is used when empty.

    Returns:
        A 6-tuple of (token HTML, tokenizer class name, vocab size,
        special-token mapping, token count, character count).  On failure the
        first element carries an error message and the counts are zero.
    """
    if not model_name:
        return "Please choose a model and input some texts", None, None, None, 0, 0
    if not text:
        text = "Please choose a model and input some texts"

    try:
        # Prefer a local model directory; otherwise treat the name as a Hub id.
        model_path = os.path.join("models", model_name)
        if os.path.isdir(model_path):
            tokenizer = AutoTokenizer.from_pretrained(
                model_path, trust_remote_code=True, device_map="cpu"
            )
        else:
            tokenizer = AutoTokenizer.from_pretrained(
                model_name, trust_remote_code=True, device_map="cpu"
            )

        tokenizer_type = tokenizer.__class__.__name__

        # Vocab size: fall back to the full vocab mapping, then to -1 (unknown).
        if hasattr(tokenizer, "vocab_size"):
            vocab_size = tokenizer.vocab_size
        elif hasattr(tokenizer, "get_vocab"):
            vocab_size = len(tokenizer.get_vocab())
        else:
            vocab_size = -1

        # Special-token attribute names to probe on the tokenizer, including
        # multimodal markers used by vision/audio/video models.
        sp_token_list = [
            "pad_token",
            "eos_token",
            "bos_token",
            "sep_token",
            "cls_token",
            "unk_token",
            "mask_token",
            "image_token",
            "audio_token",
            "video_token",
            "vision_bos_token",
            "vision_eos_token",
            "audio_bos_token",
            "audio_eos_token",
        ]
        special_tokens = {}
        for token_name in sp_token_list:
            if (
                hasattr(tokenizer, token_name)
                and getattr(tokenizer, token_name) is not None
            ):
                token_value = getattr(tokenizer, token_name)
                # Skip empty/whitespace-only token strings.
                if token_value and str(token_value).strip():
                    special_tokens[token_name] = str(token_value)

        # Tokenize the input, keeping the model's special tokens.
        input_ids = tokenizer.encode(text, add_special_tokens=True)

        # Build colored HTML: one span per token, alternating backgrounds.
        # NOTE(review): the original markup was lost in transit; this span/sub
        # structure is reconstructed to match the `.token-output span` CSS.
        colors = ["#A8D8EA", "#AA96DA", "#FCBAD3"]
        html_parts = []
        for i, token_id in enumerate(input_ids):
            # Escape HTML special characters in the decoded token text.
            safe_token = escape(tokenizer.decode(token_id))
            # Alternate colors so adjacent tokens are visually distinct.
            color = colors[i % len(colors)]
            html_part = (
                f'<span style="background-color: {color}; border-radius: 4px; '
                f'padding: 2px 4px; display: inline-block; text-align: center;">'
                f"{safe_token}<br>"
                f'<sub>{token_id}</sub>'
                f"</span>"
            )
            html_parts.append(html_part)

        # Summary statistics.
        token_len = len(input_ids)
        char_len = len(text)

        return (
            "".join(html_parts),
            tokenizer_type,
            vocab_size,
            special_tokens,
            token_len,
            char_len,
        )

    except Exception as e:
        error_msg = f"Error: {str(e)}"
        return error_msg, None, None, None, 0, 0


banner_md = """# 🎨 Tokenize it! Powerful token visualization tool for your text inputs. 🚀 Works for LLMs both online and *locally* on your machine!"""
banner = gr.Markdown(banner_md)
model_selector = gr.Dropdown(
    label="Choose or enter model name",
    choices=get_available_models(),
    interactive=True,
    allow_custom_value=True,
)
text_input = gr.Textbox(label="Input Text", placeholder="Hello World!", lines=4)
submit_btn = gr.Button("🚀 Tokenize!", variant="primary")
tokenizer_type = gr.Textbox(label="Tokenizer Type", interactive=False)
vocab_size = gr.Number(label="Vocab Size", interactive=False)
sp_tokens = gr.JSON(label="Special Tokens")
output_html = gr.HTML(label="Tokenized Output", elem_classes="token-output")
token_count = gr.Number(label="Token Count", value=0, interactive=False)
char_count = gr.Number(label="Character Count", value=0, interactive=False)

# Lay out the pre-built components inside the Blocks context via .render().
with gr.Blocks(title="Token Visualizer", theme="NoCrypt/miku") as webui:
    banner.render()
    with gr.Row(scale=2):
        with gr.Column():
            model_selector.render()
            text_input.render()
            submit_btn.render()
            output_html.render()
        with gr.Column():
            with gr.Accordion("Details", open=False):
                with gr.Row():
                    tokenizer_type.render()
                    vocab_size.render()
                sp_tokens.render()
            with gr.Row():
                token_count.render()
                char_count.render()

    # CSS styling for the token spans and statistic fields.
    webui.css = """
    .token-output span { margin: 3px; vertical-align: top; }
    .stats-output { font-weight: bold !important; color: #2c3e50 !important; }
    """

    submit_btn.click(
        fn=tokenize_text,
        inputs=[model_selector, text_input],
        outputs=[
            output_html,
            tokenizer_type,
            vocab_size,
            sp_tokens,
            token_count,
            char_count,
        ],
    )

if __name__ == "__main__":
    os.makedirs("models", exist_ok=True)
    webui.launch(pwa=True)