# MAIRA-2 / app.py
# Author: ayyuce — commit f475e3a "Update app.py" (verified), 9.04 kB.
from transformers import AutoModelForCausalLM, AutoProcessor
from PIL import Image
import torch
import gradio as gr
import requests
import tempfile
import os
# Shared, process-wide state: login() fills it in, the generation handlers
# read it and refuse to run until "authenticated" is True.
MODEL_STATE = dict(
    model=None,           # AutoModelForCausalLM once login() succeeds
    processor=None,       # AutoProcessor once login() succeeds
    authenticated=False,  # set True only after both loaded successfully
)
def login(hf_token):
    """Authenticate with the Hugging Face Hub and load the MAIRA-2 model.

    Loads ``microsoft/maira-2`` (model and processor, with remote code) using
    the given access token, switches the model to eval mode on the CPU and
    stores both in ``MODEL_STATE``. Any previously loaded model is discarded
    first, so a failed login always leaves the app unauthenticated.

    Args:
        hf_token: Hugging Face access token entered in the UI; surrounding
            whitespace is ignored.

    Returns:
        A status message for the "Authentication Status" textbox. Failures are
        reported in the message rather than raised.
    """
    # Drop any earlier model up front so a failed re-login cannot leave a
    # stale model usable.
    MODEL_STATE.update({"model": None, "processor": None, "authenticated": False})
    token = (hf_token or "").strip()
    if not token:
        return "❌ Login failed: please enter a Hugging Face access token."
    try:
        # `token=` supersedes the deprecated `use_auth_token=` keyword.
        model = AutoModelForCausalLM.from_pretrained(
            "microsoft/maira-2",
            trust_remote_code=True,
            token=token,
        )
        processor = AutoProcessor.from_pretrained(
            "microsoft/maira-2",
            trust_remote_code=True,
            token=token,
        )
        model = model.eval().to("cpu")
    except Exception as e:
        # UI boundary: show the error instead of raising into Gradio.
        return f"❌ Login failed: {str(e)}"
    # Publish model and processor together so handlers never see a
    # half-populated state.
    MODEL_STATE.update({"model": model, "processor": processor, "authenticated": True})
    return "🔓 Login successful! You can now use the model."
def get_sample_data():
    """Download a sample chest X-ray study plus fixed demo text fields.

    Fetches the frontal and lateral views of Open-i study CXR145 and pairs
    them with canned indication / technique / comparison / phrase strings.

    Returns:
        dict with PIL images (fully loaded into memory) under "frontal" and
        "lateral", and strings under "indication", "technique", "comparison"
        and "phrase".

    Raises:
        requests.RequestException: if a download fails, times out, or the
            server answers with an HTTP error status.
        PIL.UnidentifiedImageError: if a response body is not a readable image.
    """
    from io import BytesIO

    frontal_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png"
    lateral_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-2001.png"

    def download_image(url, timeout=30):
        """Fetch *url* and return it as an in-memory PIL image."""
        # A timeout keeps a stalled server from hanging the UI handler forever.
        response = requests.get(url, headers={"User-Agent": "MAIRA-2"}, timeout=timeout)
        # Fail loudly on 4xx/5xx instead of handing an HTML error page to PIL.
        response.raise_for_status()
        # `.content` is transparently decompressed (unlike `.raw`), and loading
        # eagerly means the image no longer depends on the HTTP connection.
        image = Image.open(BytesIO(response.content))
        image.load()
        return image

    return {
        "frontal": download_image(frontal_url),
        "lateral": download_image(lateral_url),
        "indication": "Dyspnea.",
        "technique": "PA and lateral views of the chest.",
        "comparison": "None.",
        "phrase": "Pleural effusion.",
    }
def save_temp_image(img):
    """Write *img* to a new temporary ``.png`` file and return its path.

    The file is created with ``delete=False`` so it outlives this call; the
    path is handed to Gradio image components, which read it back. Nothing
    here removes the file afterwards.

    Args:
        img: PIL image (anything with a ``save(fp, format=...)`` method).

    Returns:
        Absolute path of the written PNG file.
    """
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        # Write through the open handle: the `with` closes it (no descriptor
        # leak), and this also works on Windows, where re-opening an
        # still-open temporary file by name fails.
        img.save(tmp, format="PNG")
    return tmp.name
def load_sample_findings():
    """Fill the report-generation form with the sample study.

    Returns the nine values the "Load Sample Data" button writes, in order:
    temp-file paths of the frontal and lateral images; the indication,
    technique and comparison texts; ``None`` for the prior frontal image,
    prior lateral image and prior report (clearing them); and ``False`` for
    the grounding checkbox.
    """
    study = get_sample_data()
    image_paths = [save_temp_image(study[view]) for view in ("frontal", "lateral")]
    text_fields = [study[field] for field in ("indication", "technique", "comparison")]
    cleared_prior_fields = [None] * 3
    return image_paths + text_fields + cleared_prior_fields + [False]
def load_sample_phrase():
    """Fill the phrase-grounding form with the sample study.

    Returns:
        ``[frontal_image_temp_path, sample_phrase]`` for the two outputs of the
        grounding tab's "Load Sample Data" button.
    """
    study = get_sample_data()
    frontal_file = save_temp_image(study["frontal"])
    sample_phrase = study["phrase"]
    return [frontal_file, sample_phrase]
def generate_report(frontal_path, lateral_path, indication, technique, comparison,
                    prior_frontal_path, prior_lateral_path, prior_report, grounding):
    """Generate a radiology findings report for the current study with MAIRA-2.

    Args:
        frontal_path: File path of the current frontal view (required).
        lateral_path: File path of the current lateral view (required).
        indication: Clinical indication text.
        technique: Imaging technique text.
        comparison: Comparison text.
        prior_frontal_path: Optional file path of the prior frontal view.
        prior_lateral_path: Optional file path of the prior lateral view.
        prior_report: Optional prior report text; a falsy value is sent as "".
        grounding: When truthy, request a grounded report and allow up to 450
            new tokens instead of 300.

    Returns:
        The result of the processor's
        ``convert_output_to_plaintext_or_grounded_sequence`` for the generated
        text. If the user is not authenticated, a required image is missing, or
        any step fails, a user-facing status string is returned instead;
        failures appear as "❌ Generation error: ...". Nothing is raised.
    """
    if not MODEL_STATE["authenticated"]:
        return "⚠️ Please authenticate with your Hugging Face token first!"
    # Validate required inputs before opening any file.
    if not frontal_path or not lateral_path:
        return "❌ Missing required current study images"

    def _load_image(path):
        """Read the image at *path* into memory and close the file; falsy -> None."""
        if not path:
            return None
        # Image.open is lazy and keeps the file handle open. load() inside the
        # context manager pulls the pixels in so the handle can be released.
        with Image.open(path) as img:
            img.load()
        return img

    try:
        # Read the shared state once so a concurrent re-login cannot swap the
        # model or processor halfway through a request.
        processor = MODEL_STATE["processor"]
        model = MODEL_STATE["model"]
        processed = processor.format_and_preprocess_reporting_input(
            current_frontal=_load_image(frontal_path),
            current_lateral=_load_image(lateral_path),
            prior_frontal=_load_image(prior_frontal_path),
            prior_lateral=_load_image(prior_lateral_path),
            indication=indication,
            technique=technique,
            comparison=comparison,
            prior_report=prior_report or "",
            return_tensors="pt",
            get_grounding=grounding,
        ).to("cpu")
        # Not forwarded to generate(). Presumably the remote-code model's
        # generate() does not accept this key — TODO confirm.
        processed.pop("image_sizes", None)
        outputs = model.generate(
            **processed,
            max_new_tokens=450 if grounding else 300,
            use_cache=True,
        )
        # Keep only the newly generated tokens, dropping the echoed prompt.
        prompt_length = processed["input_ids"].shape[-1]
        decoded = processor.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        # NOTE: mirrors the MAIRA-2 model-card example, which lstrip()s
        # findings completions (they begin with a leading space).
        return processor.convert_output_to_plaintext_or_grounded_sequence(decoded.lstrip())
    except Exception as e:
        # UI boundary: show the failure in the output box instead of raising.
        return f"❌ Generation error: {str(e)}"
def ground_phrase(frontal_path, phrase):
    """Localize *phrase* on a frontal chest X-ray with MAIRA-2.

    Args:
        frontal_path: File path of the frontal view image (required).
        phrase: Finding to ground; must contain non-whitespace text.

    Returns:
        The result of the processor's
        ``convert_output_to_plaintext_or_grounded_sequence`` for the generated
        text. If the user is not authenticated, an input is missing, or any
        step fails, a user-facing status string is returned instead; failures
        appear as "❌ Grounding error: ...". Nothing is raised.
    """
    if not MODEL_STATE["authenticated"]:
        return "⚠️ Please authenticate with your Hugging Face token first!"
    if not frontal_path:
        return "❌ Missing frontal view image"
    # Don't spend a generate() call on an empty request.
    if not phrase or not phrase.strip():
        return "❌ Missing phrase to ground"
    try:
        # Load pixels eagerly and release the file handle (Image.open is lazy).
        with Image.open(frontal_path) as frontal:
            frontal.load()
        # Read the shared state once for a consistent model/processor pair.
        processor = MODEL_STATE["processor"]
        model = MODEL_STATE["model"]
        processed = processor.format_and_preprocess_phrase_grounding_input(
            frontal_image=frontal,
            phrase=phrase,
            return_tensors="pt",
        ).to("cpu")
        # Not forwarded to generate(). Presumably the remote-code model's
        # generate() does not accept this key — TODO confirm.
        processed.pop("image_sizes", None)
        outputs = model.generate(
            **processed,
            max_new_tokens=150,
            use_cache=True,
        )
        # Keep only the newly generated tokens, dropping the echoed prompt.
        prompt_length = processed["input_ids"].shape[-1]
        decoded = processor.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        return processor.convert_output_to_plaintext_or_grounded_sequence(decoded)
    except Exception as e:
        # UI boundary: show the failure in the output box instead of raising.
        return f"❌ Grounding error: {str(e)}"
# Gradio UI: an authentication row followed by two tabs (report generation and
# phrase grounding). Each button is wired to one of the handler functions.
with gr.Blocks(title="MAIRA-2 Medical Assistant") as demo:
    gr.Markdown("""# MAIRA-2 Medical Assistant
**Authentication required** - You need a Hugging Face account and access token to use this model.
1. Get your access token from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
2. Request model access at [https://huggingface.co/microsoft/maira-2](https://huggingface.co/microsoft/maira-2)
3. Paste your token below to begin
""")
    with gr.Row():
        hf_token = gr.Textbox(
            label="Hugging Face Token",
            placeholder="hf_xxxxxxxxxxxxxxxxxxxx",
            type="password",
        )
        login_btn = gr.Button("Authenticate")
    login_status = gr.Textbox(label="Authentication Status", interactive=False)
    login_btn.click(
        login,
        inputs=hf_token,
        outputs=login_status,
    )

    with gr.Tabs():
        with gr.Tab("Report Generation"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("## Current Study")
                    frontal = gr.Image(label="Frontal View", type="filepath")
                    lateral = gr.Image(label="Lateral View", type="filepath")
                    indication = gr.Textbox(label="Clinical Indication")
                    technique = gr.Textbox(label="Imaging Technique")
                    comparison = gr.Textbox(label="Comparison")
                    gr.Markdown("## Prior Study (Optional)")
                    prior_frontal = gr.Image(label="Prior Frontal View", type="filepath")
                    prior_lateral = gr.Image(label="Prior Lateral View", type="filepath")
                    prior_report = gr.Textbox(label="Prior Report")
                    grounding = gr.Checkbox(label="Include Grounding")
                    sample_btn = gr.Button("Load Sample Data")
                with gr.Column():
                    report_output = gr.Textbox(label="Generated Report", lines=10)
                    generate_btn = gr.Button("Generate Report")
            # One ordered list serves as both the sample loader's outputs and
            # generate_report's inputs, so the two wirings cannot drift apart.
            # The order must match generate_report's positional parameters.
            report_fields = [frontal, lateral, indication, technique, comparison,
                             prior_frontal, prior_lateral, prior_report, grounding]
            sample_btn.click(
                load_sample_findings,
                outputs=report_fields,
            )
            generate_btn.click(
                generate_report,
                inputs=report_fields,
                outputs=report_output,
            )

        with gr.Tab("Phrase Grounding"):
            with gr.Row():
                with gr.Column():
                    pg_frontal = gr.Image(label="Frontal View", type="filepath")
                    phrase = gr.Textbox(label="Phrase to Ground")
                    pg_sample_btn = gr.Button("Load Sample Data")
                with gr.Column():
                    pg_output = gr.Textbox(label="Grounding Result", lines=3)
                    pg_btn = gr.Button("Find Phrase")
            pg_sample_btn.click(
                load_sample_phrase,
                outputs=[pg_frontal, phrase],
            )
            pg_btn.click(
                ground_phrase,
                inputs=[pg_frontal, phrase],
                outputs=pg_output,
            )

# Launch only when run as a script (`python app.py`). Importing the module,
# for example under Gradio's reload mode, must not start a server.
if __name__ == "__main__":
    demo.launch()