import torch
from transformers import AutoModelForCausalLM, AutoProcessor, VisionEncoderDecoderModel
from huggingface_hub import snapshot_download
from qwen_vl_utils import process_vision_info


def load_model(model_name):
    """
    Load the specified model and its processor based on the model name.

    Args:
        model_name (str): Name of the model ("dots.ocr" or "Dolphin").

    Returns:
        tuple: (model, processor) for the specified model.
    """
    if model_name == "dots.ocr":
        model_id = "rednote-hilab/dots.ocr"
        model_path = "./models/dots-ocr-local"
        # Download the weights to a local directory before loading.
        snapshot_download(
            repo_id=model_id,
            local_dir=model_path,
            local_dir_use_symlinks=False,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    elif model_name == "Dolphin":
        model_id = "ByteDance/Dolphin"
        processor = AutoProcessor.from_pretrained(model_id)
        model = VisionEncoderDecoderModel.from_pretrained(model_id)
        model.eval()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        model = model.half()  # Use half precision
    else:
        raise ValueError(f"Unknown model: {model_name}")

    return model, processor


def inference_dots_ocr(model, processor, image, prompt, max_new_tokens):
    """
    Perform inference using the dots.ocr model.

    Args:
        model: The loaded dots.ocr model.
        processor: The corresponding processor.
        image (PIL.Image): Input image.
        prompt (str): Prompt for inference.
        max_new_tokens (int): Maximum number of tokens to generate.

    Returns:
        str: Generated text output.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)

    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # Greedy decoding; temperature has no effect when sampling is off.
        )

    # Strip the prompt tokens so only the newly generated text is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return output_text[0] if output_text else ""


def inference_dolphin(model, processor, image):
    """
    Perform inference using the Dolphin model.

    Args:
        model: The loaded Dolphin model.
        processor: The corresponding processor.
        image (PIL.Image): Input image.

    Returns:
        str: Generated text output.
    """
    # Match the model's half-precision weights when preparing pixel values.
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(model.device).half()
    with torch.no_grad():
        generated_ids = model.generate(pixel_values)
    generated_text = processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text
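

# --- Usage sketch (not part of the original utilities) ---
# A minimal example of how these helpers might be wired together, assuming a
# hypothetical local image file "sample_page.png" and an illustrative prompt;
# substitute your own image path, prompt, and token budget. Loading both
# models in one process is only for demonstration and may exceed GPU memory.
if __name__ == "__main__":
    from PIL import Image

    image = Image.open("sample_page.png").convert("RGB")  # hypothetical input file

    # dots.ocr: prompt-driven, returns text generated from the image + prompt.
    model, processor = load_model("dots.ocr")
    prompt = "Extract the text content from this image."  # example prompt, replace as needed
    print(inference_dots_ocr(model, processor, image, prompt, max_new_tokens=1024))

    # Dolphin: encoder-decoder model; this helper takes no text prompt.
    model, processor = load_model("Dolphin")
    print(inference_dolphin(model, processor, image))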