import argparse
import io
import os

import gradio as gr
import numpy as np
import torch
from PIL import Image

import unimernet.tasks as tasks
from unimernet.common.config import Config
from unimernet.processors import load_processor

# Bounding box (in pixels) that uploaded images are shrunk to fit before
# being handed to the model.
MAX_WIDTH = 872
MAX_HEIGHT = 1024


class ImageProcessor:
    """ImageProcessor handles model loading and image processing."""

    def __init__(self, cfg_path):
        """Load the model and visual processor described by *cfg_path*.

        Args:
            cfg_path: Path to the configuration file passed to ``Config``.
        """
        self.cfg_path = cfg_path
        # Use the GPU when one is available, otherwise fall back to the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model, self.vis_processor = self.load_model_and_processor()

    def load_model_and_processor(self):
        """Build the model and the evaluation-time image processor.

        Returns:
            A ``(model, vis_processor)`` tuple. The model has been moved to
            ``self.device`` and put in evaluation mode.
        """
        # Config expects an argparse-style namespace rather than a bare path.
        args = argparse.Namespace(cfg_path=self.cfg_path, options=None)
        cfg = Config(args)
        task = tasks.setup_task(cfg)
        model = task.build_model(cfg).to(self.device)
        # This object is only ever used for inference, so make sure
        # train-only layers (e.g. dropout) are disabled.
        model.eval()
        vis_processor = load_processor(
            "formula_image_eval",
            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
        )
        return model, vis_processor

    def process_single_image(self, pil_image):
        """Run the model on one image and return the first predicted string.

        Args:
            pil_image: Image accepted by the visual processor (callers in
                this file pass an RGB PIL image).

        Returns:
            The first entry of the ``"pred_str"`` field of the model output.
        """
        # Add a leading batch dimension and move the tensor to the model's device.
        image = self.vis_processor(pil_image).unsqueeze(0).to(self.device)
        # No gradients are needed for inference; skip autograd bookkeeping.
        with torch.no_grad():
            output = self.model.generate({"image": image})
        return output["pred_str"][0]


# Initialize the model once at import time; both prediction helpers share it.
cfg_path = "demo.yaml"
processor = ImageProcessor(cfg_path)


def _prepare_image(img):
    """Return an RGB version of *img* shrunk to fit MAX_WIDTH x MAX_HEIGHT."""
    rgb = img.convert("RGB")
    # thumbnail() resizes in place and only ever shrinks the image.
    rgb.thumbnail((MAX_WIDTH, MAX_HEIGHT), Image.Resampling.LANCZOS)
    return rgb


# Single-image prediction
def predict_single(img):
    """Return the predicted LaTeX for one image.

    Args:
        img: A PIL image, or ``None`` when nothing was uploaded.

    Returns:
        The predicted string, or ``"No image uploaded."`` if *img* is ``None``.
    """
    if img is None:
        return "No image uploaded."
    return processor.process_single_image(_prepare_image(img))


# Batch prediction
def predict_batch(img_list):
    """Return one prediction per entry of *img_list*.

    Args:
        img_list: Sequence of PIL images; ``None`` entries are allowed.

    Returns:
        A list of strings in input order. ``None`` entries yield
        ``"Invalid image"``. An empty or ``None`` *img_list* yields
        ``["No images uploaded."]``.
    """
    if not img_list:
        return ["No images uploaded."]
    return [
        "Invalid image"
        if img is None
        else processor.process_single_image(_prepare_image(img))
        for img in img_list
    ]


# UI text
title = "UniMERNet Formula Recognition"
description = "Upload an image (or multiple images) containing math formulas. The model will return LaTeX code."
# Build the two-tab Gradio interface: single-image and batch recognition.
# NOTE(review): the original indentation was lost, so the exact nesting of
# components inside gr.Row() is a reconstruction; confirm the intended layout.
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)

    with gr.Tab("Single Image Recognition"):
        with gr.Row():
            input_img = gr.Image(type="pil", label="Upload a single formula image")
            output_text = gr.Textbox(label="Predicted LaTeX code", lines=5)
        btn_single = gr.Button("Recognize Single Image")
        btn_single.click(fn=predict_single, inputs=input_img, outputs=output_text)

    with gr.Tab("Batch Image Recognition"):
        with gr.Row():
            input_imgs = gr.File(
                file_types=["image"],
                file_count="multiple",
                label="Upload multiple images (png/jpg/jpeg/webp)",
            )

        # Intentionally created without headers/datatype arguments.
        batch_outputs = gr.Dataframe()

        def batch_process(files):
            """Recognize every uploaded file and return table rows.

            Args:
                files: Uploaded files as delivered by ``gr.File``; each item
                    is either a path string or an object whose ``name``
                    attribute is the path. May be ``None`` or empty when
                    nothing was uploaded.

            Returns:
                A list of ``[file_name, prediction]`` rows, one per upload.
                Files that cannot be opened as images get the
                ``"Invalid image"`` result produced by ``predict_batch``.
            """
            # Clicking the button with no upload passes None; show an empty table.
            if not files:
                return []

            imgs = []
            file_names = []
            for file in files:
                # Accept both file-wrapper objects and plain path strings.
                path = getattr(file, "name", file)
                file_names.append(os.path.basename(path))
                try:
                    with Image.open(path) as img:
                        # copy() loads the pixel data so the image stays
                        # usable after the underlying file is closed.
                        imgs.append(img.copy())
                except OSError:
                    # Unreadable or non-image file: keep the row and let
                    # predict_batch report it instead of failing the batch.
                    imgs.append(None)

            preds = predict_batch(imgs)
            return [[name, pred] for name, pred in zip(file_names, preds)]

        btn_batch = gr.Button("Recognize Batch Images")
        btn_batch.click(fn=batch_process, inputs=input_imgs, outputs=batch_outputs)

if __name__ == "__main__":
    demo.launch()