import gradio as gr
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
from openai import OpenAI
from huggingface_hub import hf_hub_download
from transformers import (
    BlipProcessor, BlipForConditionalGeneration,
    CLIPProcessor, CLIPModel,
    AutoProcessor, AutoModelForImageClassification,
)
from segment_anything import SamPredictor, sam_model_registry
from yolo_world.models.detectors import build_detector
from mmcv import Config
# Initialize models
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
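# Note (assumption): sam_model_registry expects the original segment_anything
# .pth checkpoints; the transformers-format model.safetensors published in
# facebook/sam-vit-large may need conversion before it loads below.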
# Download the SAM checkpoint from Hugging Face
sam_checkpoint = hf_hub_download(
    repo_id="facebook/sam-vit-large",  # repository ID
    filename="model.safetensors",      # weight file name
    use_auth_token=False               # public repo, no authentication required
)
sam = sam_model_registry["vit_l"](checkpoint=sam_checkpoint)  # "vit_l" matches the ViT-Large weights
sam_predictor = SamPredictor(sam)
# Download the YOLO-World weights from Hugging Face
yolo_checkpoint = hf_hub_download(
    repo_id="stevengrove/YOLO-World",  # Hugging Face repository ID
    filename="yolo_world_v2_xl_obj365v1_goldg_cc3mlite_pretrain.pth",  # weight file name
    use_auth_token=False               # public repo, no authentication required
)
# Load the YOLO-World config file
yolo_config = Config.fromfile('path/to/yolo_world_config.py')  # replace with the actual config path
# Build the YOLO-World model
yolo_model = build_detector(yolo_config.model)
# Load the weights into the model
checkpoint = torch.load(yolo_checkpoint, map_location="cpu")  # load on CPU; move to GPU later if needed
yolo_model.load_state_dict(checkpoint["state_dict"])
yolo_model.eval()  # switch to evaluation mode
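# Note (assumption): the exact build entry point varies with the installed
# YOLO-World/mmdet version; newer stacks build detectors via the mmengine
# MODELS registry, and mmcv 2.x moved Config into mmengine as well.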
wd_processor = AutoProcessor.from_pretrained("SmilingWolf/wd-vit-tagger-v3")
wd_model = AutoModelForImageClassification.from_pretrained("SmilingWolf/wd-vit-tagger-v3")
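# Note (assumption): SmilingWolf's WD taggers are multi-label models; whether
# this repo loads through AutoProcessor/AutoModelForImageClassification is an
# assumption, and sigmoid scores are usually more appropriate than softmax.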
# Automatically classify the image type (anime vs. photo)
def classify_image_type(image):
    inputs = wd_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = wd_model(**inputs)
    scores = torch.softmax(outputs.logits, dim=1)[0]
    anime_score = scores[wd_model.config.label2id["anime"]].item()  # label2id lives on the model config
    return "anime" if anime_score > 0.5 else "real"
# Segment objects in the image with SAM
def segment_objects(image, boxes):
    image_np = np.array(image)
    sam_predictor.set_image(image_np)
    masks = []
    for box in boxes:
        mask, _, _ = sam_predictor.predict(
            point_coords=None, point_labels=None, box=box, multimask_output=False
        )
        masks.append(mask)
    return masks
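# Note: SamPredictor.predict above expects each box as a length-4 np.ndarray
# in XYXY pixel coordinates, and returns masks along with IoU predictions
# and low-resolution mask logits.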
# Detect objects
def detect_objects(image, image_type):
    if image_type == "real":
        # Assumes the detector wrapper returns a list of dicts with
        # "class"/"bbox"/"confidence" keys; raw mmdet outputs are structured
        # differently and would need adapting here.
        results = yolo_model.predict(np.array(image), conf=0.25)
        objects = [{"label": r["class"], "box": r["bbox"], "confidence": r["confidence"]} for r in results]
    else:
        inputs = wd_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = wd_model(**inputs)
        scores = torch.softmax(outputs.logits, dim=1)[0]
        top_k = torch.topk(scores, k=5)
        objects = [
            {"label": wd_model.config.id2label[top_k.indices[i].item()],  # id2label, not processor.decode
             "confidence": top_k.values[i].item()}
            for i in range(5)
        ]
    return objects
# Generate a semantic description for each detected object
def generate_object_descriptions(image, objects):
    descriptions = []
    for obj in objects:
        box = obj.get("box", None)
        # Tag-only results (anime branch) carry no box, so caption the full image
        cropped = image.crop(box) if box is not None else image
        inputs = blip_processor(cropped, return_tensors="pt")
        caption = blip_model.generate(**inputs, max_length=128, num_beams=5, no_repeat_ngram_size=2)
        description = blip_processor.decode(caption[0], skip_special_tokens=True)
        descriptions.append({"label": obj["label"], "description": description})
    return descriptions
# Visualize feature differences
def plot_feature_differences(latent_diff, descriptions, prefix):
    diff_magnitude = [abs(x) for x in latent_diff[0]]
    indices = range(len(diff_magnitude))
    top_indices = np.argsort(diff_magnitude)[-10:][::-1]
    plt.figure(figsize=(8, 4))
    plt.bar(indices, diff_magnitude, alpha=0.7)
    plt.xlabel("Feature Index")
    plt.ylabel("Magnitude of Difference")
    plt.title("Feature Differences (Bar Chart)")
    bar_chart_path = f"{prefix}_bar_chart.png"
    plt.savefig(bar_chart_path)
    plt.close()
    plt.figure(figsize=(6, 6))
    plt.pie(
        [diff_magnitude[i] for i in top_indices],
        # There are far more feature dimensions than object labels, so fall
        # back to the raw dimension index to avoid an IndexError
        labels=[descriptions[i] if i < len(descriptions) else f"dim {i}" for i in top_indices],
        autopct="%1.1f%%",
        startangle=140
    )
    plt.title("Top 10 Feature Differences (Pie Chart)")
    pie_chart_path = f"{prefix}_pie_chart.png"
    plt.savefig(pie_chart_path)
    plt.close()
    return bar_chart_path, pie_chart_path
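# Note: clip-vit-base-patch32 image embeddings are 512-dimensional, so the
# bar chart above plots 512 magnitudes and the pie chart keeps the 10 largest.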
# Generate a detailed comparative analysis via an LLM API
def generate_text_analysis(api_key, api_type, caption_a, caption_b):
    if api_type == "DeepSeek":
        client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")
    else:
        client = OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model="gpt-4" if api_type == "GPT" else "deepseek-chat",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": f"Description of image A: {caption_a}\nDescription of image B: {caption_b}\nPlease give a detailed comparative analysis of the two images."}
        ]
    )
    return response.choices[0].message.content.strip()
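# DeepSeek exposes an OpenAI-compatible endpoint, which is why the same
# client class serves both providers with only base_url and model swapped.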
# Analyze a single pair of images
def analyze_images(img_a, img_b, api_key, api_type, prefix):
    type_a = classify_image_type(img_a)
    type_b = classify_image_type(img_b)
    objects_a = detect_objects(img_a, type_a)
    objects_b = detect_objects(img_b, type_b)
    descriptions_a = generate_object_descriptions(img_a, objects_a)
    descriptions_b = generate_object_descriptions(img_b, objects_b)
    inputs = clip_processor(images=img_a, return_tensors="pt")
    features_a = clip_model.get_image_features(**inputs).detach().numpy()
    inputs = clip_processor(images=img_b, return_tensors="pt")
    features_b = clip_model.get_image_features(**inputs).detach().numpy()
    latent_diff = np.abs(features_a - features_b).tolist()
    bar_chart, pie_chart = plot_feature_differences(latent_diff, [d["label"] for d in descriptions_a], prefix)
    text_analysis = generate_text_analysis(api_key, api_type, descriptions_a, descriptions_b)
    return {
        "bar_chart": bar_chart,
        "pie_chart": pie_chart,
        "text_analysis": text_analysis
    }
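# Example usage (hypothetical file names):
#   out = analyze_images(Image.open("a.png"), Image.open("b.png"),
#                        api_key="sk-...", api_type="GPT", prefix="pair_1")
#   out["bar_chart"], out["pie_chart"], out["text_analysis"]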
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Comprehensive Image Comparison Tool")
    api_key_input = gr.Textbox(label="API Key", placeholder="Enter API key", type="password")
    api_type_input = gr.Radio(label="API Type", choices=["GPT", "DeepSeek"], value="GPT")
    images_a_input = gr.File(label="Upload images for folder A", file_types=[".png", ".jpg"], file_count="multiple")
    images_b_input = gr.File(label="Upload images for folder B", file_types=[".png", ".jpg"], file_count="multiple")
    analyze_button = gr.Button("Start Analysis")
    result_gallery = gr.Gallery(label="Difference Visualizations")
    result_text = gr.Textbox(label="Analysis Results", lines=5)

    def process_batch(images_a, images_b, api_key, api_type):
        images_a = [Image.open(img).convert("RGB") for img in images_a]
        images_b = [Image.open(img).convert("RGB") for img in images_b]
        results = [
            analyze_images(img_a, img_b, api_key, api_type, f"comparison_{i+1}")
            for i, (img_a, img_b) in enumerate(zip(images_a, images_b))
        ]
        # Flatten into the two outputs the UI expects: a gallery of chart
        # paths and one combined analysis text
        charts = [p for r in results for p in (r["bar_chart"], r["pie_chart"])]
        analyses = "\n\n".join(r["text_analysis"] for r in results)
        return charts, analyses

    analyze_button.click(process_batch, inputs=[images_a_input, images_b_input, api_key_input, api_type_input], outputs=[result_gallery, result_text])

demo.launch()