"""Gradio demo for Microsoft's MAIRA-2 chest X-ray report generation and grounding.

The model is gated on the Hugging Face Hub, so users first authenticate with
their own access token; the "Report Generation" and "Phrase Grounding" tabs
only work after a successful login.
"""

import io
import logging
import os
import tempfile

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

logger = logging.getLogger(__name__)

MODEL_ID = "microsoft/maira-2"
# Device the model and preprocessed inputs are moved to.
DEVICE = "cpu"

# Generation budgets (max_new_tokens).
REPORT_MAX_NEW_TOKENS = 300
GROUNDED_REPORT_MAX_NEW_TOKENS = 450
PHRASE_MAX_NEW_TOKENS = 150

# Seconds to wait for each sample-image download before giving up.
DOWNLOAD_TIMEOUT = 30

SAMPLE_FRONTAL_URL = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png"
SAMPLE_LATERAL_URL = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-2001.png"
SAMPLE_PHRASE = "Pleural effusion."

# Module-level state shared by every request handled by this process;
# populated by ``login`` and read by the inference handlers.
MODEL_STATE = {
    "model": None,
    "processor": None,
    "authenticated": False,
}


def login(hf_token):
    """Authenticate against the Hugging Face Hub and load MAIRA-2.

    Args:
        hf_token: Access token pasted by the user. Surrounding whitespace is
            ignored; an empty value is passed as ``None`` so the Hub client can
            fall back to a token configured in the environment / local login.

    Returns:
        A status message for the "Authentication Status" textbox.
    """
    # Drop any previously loaded model first so a failed re-login never leaves
    # the app in an "authenticated" state.
    MODEL_STATE.update({"model": None, "processor": None, "authenticated": False})
    # Pasted tokens frequently carry a trailing newline or space.
    token = (hf_token or "").strip() or None
    try:
        # ``token`` replaces the deprecated ``use_auth_token`` argument.
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, trust_remote_code=True, token=token
        )
        processor = AutoProcessor.from_pretrained(
            MODEL_ID, trust_remote_code=True, token=token
        )
        model = model.eval().to(DEVICE)
    except Exception as e:
        logger.exception("Failed to load %s", MODEL_ID)
        return f"❌ Login failed: {str(e)}"
    # Publish the model only after everything loaded successfully.
    MODEL_STATE.update({"model": model, "processor": processor, "authenticated": True})
    return "🔓 Login successful! You can now use the model."


def _download_image(url):
    """Download ``url`` and return it as a fully loaded PIL image.

    Raises:
        requests.RequestException: on connection errors, timeouts or HTTP errors.
        PIL.UnidentifiedImageError: if the payload is not a readable image.
    """
    response = requests.get(
        url, headers={"User-Agent": "MAIRA-2"}, timeout=DOWNLOAD_TIMEOUT
    )
    # Fail loudly on HTTP errors instead of handing an error page to PIL.
    response.raise_for_status()
    # ``.content`` is the decoded body (unlike ``.raw``), and reading it fully
    # releases the connection back to the pool.
    image = Image.open(io.BytesIO(response.content))
    image.load()
    return image


def get_sample_data():
    """Download the sample chest X-ray study and its accompanying text fields.

    Returns:
        A dict with PIL images under ``"frontal"`` and ``"lateral"`` and the
        strings ``"indication"``, ``"technique"``, ``"comparison"`` and
        ``"phrase"``.

    Raises:
        requests.RequestException: if either download fails or times out.
    """
    return {
        "frontal": _download_image(SAMPLE_FRONTAL_URL),
        "lateral": _download_image(SAMPLE_LATERAL_URL),
        "indication": "Dyspnea.",
        "technique": "PA and lateral views of the chest.",
        "comparison": "None.",
        "phrase": SAMPLE_PHRASE,
    }


def save_temp_image(img):
    """Write ``img`` to a new ``.png`` temporary file and return its path.

    The file is created with ``delete=False`` so it still exists when Gradio
    reads it; this app never removes it.
    """
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        # Save through the open handle (closed by ``with``) instead of
        # reopening by name: no leaked descriptor, and reopening an open
        # NamedTemporaryFile by name is not possible on Windows.
        img.save(temp_file, format="PNG")
    return temp_file.name


def load_sample_findings():
    """Fill the Report Generation tab with the sample study.

    Returns:
        Values in the order of the tab's input components: current frontal and
        lateral image paths, indication, technique, comparison, then cleared
        prior frontal / prior lateral / prior report, and grounding unchecked.
    """
    sample = get_sample_data()
    return [
        save_temp_image(sample["frontal"]),
        save_temp_image(sample["lateral"]),
        sample["indication"],
        sample["technique"],
        sample["comparison"],
        None,  # prior frontal
        None,  # prior lateral
        None,  # prior report
        False,  # grounding
    ]


def load_sample_phrase():
    """Fill the Phrase Grounding tab with the sample frontal image and phrase."""
    # Only the frontal view is used here, so skip downloading the lateral one.
    return [save_temp_image(_download_image(SAMPLE_FRONTAL_URL)), SAMPLE_PHRASE]


def _open_image(path):
    """Return the image at ``path`` loaded into memory, or ``None`` if no path."""
    if not path:
        return None
    # ``copy()`` reads the pixel data; the ``with`` then closes the file handle
    # immediately instead of leaving it open until garbage collection.
    with Image.open(path) as image:
        return image.copy()


def _generate_text(processed, max_new_tokens):
    """Run generation on preprocessed model inputs and decode the completion.

    Args:
        processed: Output of one of the processor's ``format_and_preprocess_*``
            methods (must provide ``.to()`` and an ``"input_ids"`` entry).
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        The decoded completion with prompt tokens and special tokens removed.
    """
    inputs = dict(processed.to(DEVICE))
    # NOTE(review): entries whose key contains "image_sizes" are dropped before
    # generate(); presumably the remote-code model does not accept them — confirm.
    inputs = {key: value for key, value in inputs.items() if "image_sizes" not in key}
    outputs = MODEL_STATE["model"].generate(
        **inputs, max_new_tokens=max_new_tokens, use_cache=True
    )
    # The returned sequence starts with the prompt; keep only the new tokens.
    prompt_length = inputs["input_ids"].shape[-1]
    return MODEL_STATE["processor"].decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    )


def generate_report(frontal_path, lateral_path, indication, technique, comparison,
                    prior_frontal_path, prior_lateral_path, prior_report, grounding):
    """Generate a radiology report for the current study.

    Args:
        frontal_path: Path of the current frontal image (required).
        lateral_path: Path of the current lateral image (required).
        indication: Clinical indication text.
        technique: Imaging technique text.
        comparison: Comparison text.
        prior_frontal_path: Optional path of the prior frontal image.
        prior_lateral_path: Optional path of the prior lateral image.
        prior_report: Optional prior report text; a falsy value is sent as "".
        grounding: If true, request a grounded report with a larger token budget.

    Returns:
        The processor's plaintext/grounded rendering of the generated report, or
        a user-facing warning/error message string.
    """
    if not MODEL_STATE["authenticated"]:
        return "⚠️ Please authenticate with your Hugging Face token first!"
    # Validate required inputs before touching the filesystem.
    if not frontal_path or not lateral_path:
        return "❌ Missing required current study images"
    try:
        processed = MODEL_STATE["processor"].format_and_preprocess_reporting_input(
            current_frontal=_open_image(frontal_path),
            current_lateral=_open_image(lateral_path),
            prior_frontal=_open_image(prior_frontal_path),
            prior_lateral=_open_image(prior_lateral_path),
            indication=indication,
            technique=technique,
            comparison=comparison,
            prior_report=prior_report or "",
            return_tensors="pt",
            get_grounding=grounding,
        )
        max_new_tokens = (
            GROUNDED_REPORT_MAX_NEW_TOKENS if grounding else REPORT_MAX_NEW_TOKENS
        )
        decoded = _generate_text(processed, max_new_tokens)
        # Leading whitespace of the report completion is stripped before
        # conversion (the phrase-grounding path passes its text through as-is).
        return MODEL_STATE["processor"].convert_output_to_plaintext_or_grounded_sequence(
            decoded.lstrip()
        )
    except Exception as e:
        # UI boundary: show the message to the user, keep the traceback in logs.
        logger.exception("Report generation failed")
        return f"❌ Generation error: {str(e)}"


def ground_phrase(frontal_path, phrase):
    """Locate ``phrase`` on the frontal image.

    Args:
        frontal_path: Path of the frontal image (required).
        phrase: Finding text to ground (required, non-blank).

    Returns:
        The processor's grounded rendering of the model output, or a
        user-facing warning/error message string.
    """
    if not MODEL_STATE["authenticated"]:
        return "⚠️ Please authenticate with your Hugging Face token first!"
    if not frontal_path:
        return "❌ Missing frontal view image"
    if not phrase or not phrase.strip():
        return "❌ Missing phrase to ground"
    try:
        processed = MODEL_STATE["processor"].format_and_preprocess_phrase_grounding_input(
            frontal_image=_open_image(frontal_path),
            phrase=phrase,
            return_tensors="pt",
        )
        decoded = _generate_text(processed, PHRASE_MAX_NEW_TOKENS)
        return MODEL_STATE["processor"].convert_output_to_plaintext_or_grounded_sequence(
            decoded
        )
    except Exception as e:
        # UI boundary: show the message to the user, keep the traceback in logs.
        logger.exception("Phrase grounding failed")
        return f"❌ Grounding error: {str(e)}"


with gr.Blocks(title="MAIRA-2 Medical Assistant") as demo:
    gr.Markdown(
        """# MAIRA-2 Medical Assistant

**Authentication required** - You need a Hugging Face account and access token to use this model.

1. Get your access token from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
2. Request model access at [https://huggingface.co/microsoft/maira-2](https://huggingface.co/microsoft/maira-2)
3. Paste your token below to begin
"""
    )

    with gr.Row():
        hf_token = gr.Textbox(
            label="Hugging Face Token",
            placeholder="hf_xxxxxxxxxxxxxxxxxxxx",
            type="password",
        )
        login_btn = gr.Button("Authenticate")
    login_status = gr.Textbox(label="Authentication Status", interactive=False)

    login_btn.click(login, inputs=hf_token, outputs=login_status)

    with gr.Tabs():
        with gr.Tab("Report Generation"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("## Current Study")
                    frontal = gr.Image(label="Frontal View", type="filepath")
                    lateral = gr.Image(label="Lateral View", type="filepath")
                    indication = gr.Textbox(label="Clinical Indication")
                    technique = gr.Textbox(label="Imaging Technique")
                    comparison = gr.Textbox(label="Comparison")

                    gr.Markdown("## Prior Study (Optional)")
                    prior_frontal = gr.Image(label="Prior Frontal View", type="filepath")
                    prior_lateral = gr.Image(label="Prior Lateral View", type="filepath")
                    prior_report = gr.Textbox(label="Prior Report")

                    grounding = gr.Checkbox(label="Include Grounding")
                    sample_btn = gr.Button("Load Sample Data")

                with gr.Column():
                    report_output = gr.Textbox(label="Generated Report", lines=10)
                    generate_btn = gr.Button("Generate Report")

            # The sample loader fills exactly the components generate_report
            # reads, in the same order.
            report_components = [
                frontal, lateral, indication, technique, comparison,
                prior_frontal, prior_lateral, prior_report, grounding,
            ]
            sample_btn.click(load_sample_findings, outputs=report_components)
            generate_btn.click(
                generate_report, inputs=report_components, outputs=report_output
            )

        with gr.Tab("Phrase Grounding"):
            with gr.Row():
                with gr.Column():
                    pg_frontal = gr.Image(label="Frontal View", type="filepath")
                    phrase = gr.Textbox(label="Phrase to Ground")
                    pg_sample_btn = gr.Button("Load Sample Data")
                with gr.Column():
                    pg_output = gr.Textbox(label="Grounding Result", lines=3)
                    pg_btn = gr.Button("Find Phrase")

            pg_sample_btn.click(load_sample_phrase, outputs=[pg_frontal, phrase])
            pg_btn.click(ground_phrase, inputs=[pg_frontal, phrase], outputs=pg_output)


if __name__ == "__main__":
    demo.launch()