PSNbst committed on
Commit eb0b8f5 · verified · 1 Parent(s): 8d8b4cc

Update app.py

Files changed (1): app.py +129 -43
app.py CHANGED
@@ -1,8 +1,9 @@
 import gradio as gr
 import torch
 from transformers import CLIPProcessor, CLIPModel, BlipProcessor, BlipForConditionalGeneration
-from PIL import Image
+from PIL import Image, ImageChops, ImageFilter
 import numpy as np
+import matplotlib.pyplot as plt
 from openai import OpenAI
 
 # Initialize the models
@@ -11,13 +12,77 @@ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
 blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
 
-# Define the helper function
+# Image-processing helpers
+def compute_difference_images(img_a, img_b):
+    # Sketch (line-art) extraction
+    def extract_sketch(image):
+        grayscale = image.convert("L")
+        inverted = ImageChops.invert(grayscale)
+        sketch = ImageChops.screen(grayscale, inverted)
+        return sketch
+
+    # "Normal map" image (approximated here as simple edge enhancement)
+    def compute_normal_map(image):
+        edges = image.filter(ImageFilter.FIND_EDGES)
+        return edges
+
+    # Pixel-wise difference between the two images
+    diff_overlay = ImageChops.difference(img_a, img_b)
+
+    return {
+        "original_a": img_a,
+        "original_b": img_b,
+        "sketch_a": extract_sketch(img_a),
+        "sketch_b": extract_sketch(img_b),
+        "normal_a": compute_normal_map(img_a),
+        "normal_b": compute_normal_map(img_b),
+        "diff_overlay": diff_overlay
+    }
+
+# BLIP: generate a more detailed caption
+def generate_detailed_caption(image):
+    inputs = blip_processor(image, return_tensors="pt")
+    caption = blip_model.generate(**inputs, max_length=128, num_beams=5, no_repeat_ngram_size=2)
+    return blip_processor.decode(caption[0], skip_special_tokens=True)
+
+# Visualize the feature differences
+def plot_feature_differences(latent_diff):
+    diff_magnitude = [abs(x) for x in latent_diff[0]]
+    indices = range(len(diff_magnitude))
+
+    # Bar chart
+    plt.figure(figsize=(8, 4))
+    plt.bar(indices, diff_magnitude, alpha=0.7)
+    plt.xlabel("Feature Index")
+    plt.ylabel("Magnitude of Difference")
+    plt.title("Feature Differences (Bar Chart)")
+    bar_chart_path = "bar_chart.png"
+    plt.savefig(bar_chart_path)
+    plt.close()
+
+    # Pie chart
+    plt.figure(figsize=(6, 6))
+    plt.pie(diff_magnitude[:10], labels=range(10), autopct="%1.1f%%", startangle=140)
+    plt.title("Top 10 Feature Differences (Pie Chart)")
+    pie_chart_path = "pie_chart.png"
+    plt.savefig(pie_chart_path)
+    plt.close()
+
+    return bar_chart_path, pie_chart_path
+
+# Analysis function
 def analyze_images(image_a, image_b, api_key):
+    # Create the OpenAI-compatible DeepSeek client
+    client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")
+
+    # Image difference processing
+    img_a = image_a.convert("RGB")
+    img_b = image_b.convert("RGB")
+    images_diff = compute_difference_images(img_a, img_b)
+
     # BLIP caption generation
-    def generate_caption(image):
-        inputs = blip_processor(image, return_tensors="pt")
-        caption = blip_model.generate(**inputs)
-        return blip_processor.decode(caption[0], skip_special_tokens=True)
+    caption_a = generate_detailed_caption(img_a)
+    caption_b = generate_detailed_caption(img_b)
 
     # CLIP feature extraction
     def extract_features(image):
@@ -25,24 +90,11 @@
         features = clip_model.get_image_features(**inputs)
        return features.detach().numpy()
 
-    # The images are already PIL.Image objects; process them directly
-    img_a = image_a.convert("RGB")
-    img_b = image_b.convert("RGB")
-
-    # Generate captions
-    caption_a = generate_caption(img_a)
-    caption_b = generate_caption(img_b)
-
-    # Extract features
     features_a = extract_features(img_a)
     features_b = extract_features(img_b)
-
-    # Compute embedding similarity
-    cosine_similarity = np.dot(features_a, features_b.T) / (np.linalg.norm(features_a) * np.linalg.norm(features_b))
     latent_diff = np.abs(features_a - features_b).tolist()
 
-    # Call the DeepSeek API to generate a detailed analysis
-    client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")
+    # Ask GPT for a more detailed description
     gpt_response = client.chat.completions.create(
         model="deepseek-chat",
         messages=[
@@ -51,45 +103,79 @@
         ],
         stream=False
     )
-    textual_analysis = gpt_response.choices[0].message.content.strip()
-
-    # Return the results
+    text_analysis = gpt_response.choices[0].message.content.strip()
+
+    # Visualize the feature differences
+    bar_chart_path, pie_chart_path = plot_feature_differences(latent_diff)
+
     return {
         "caption_a": caption_a,
         "caption_b": caption_b,
-        "similarity": cosine_similarity[0][0],
-        "latent_diff": latent_diff,
-        "text_analysis": textual_analysis
+        "text_analysis": text_analysis,
+        "images_diff": images_diff,
+        "bar_chart": bar_chart_path,
+        "pie_chart": pie_chart_path
     }
 
-# Define the Gradio UI
+# Gradio UI
 with gr.Blocks() as demo:
-    gr.Markdown("# 图片对比分析工具")
-
+    gr.Markdown("# 图像对比分析工具")
+    api_key_input = gr.Textbox(label="API Key", placeholder="输入您的 DeepSeek API Key", type="password")
+
     with gr.Row():
         with gr.Column():
-            image_a = gr.Image(label="图片A", type="pil")  # uses the PIL type
+            image_a = gr.Image(label="图片A", type="pil")
         with gr.Column():
-            image_b = gr.Image(label="图片B", type="pil")  # uses the PIL type
-
-    api_key_input = gr.Textbox(label="API Key", placeholder="输入您的 DeepSeek API Key", type="password")
-
+            image_b = gr.Image(label="图片B", type="pil")
+
     analyze_button = gr.Button("分析图片")
-    result_caption_a = gr.Textbox(label="图片A描述", interactive=False)
-    result_caption_b = gr.Textbox(label="图片B描述", interactive=False)
-    result_similarity = gr.Number(label="图片相似性", interactive=False)
-    result_latent_diff = gr.DataFrame(label="潜在特征差异", interactive=False)
-    result_text_analysis = gr.Textbox(label="详细分析", interactive=False, lines=5)
-
+
+    with gr.Row():
+        gr.Markdown("## 图像差异")
+        result_diff = gr.Gallery(label="混合差异图像").style(grid=3)
+
+    with gr.Row():
+        result_caption_a = gr.Textbox(label="图片A描述", interactive=False)
+        result_caption_b = gr.Textbox(label="图片B描述", interactive=False)
+
+    with gr.Row():
+        gr.Markdown("## 差异分析")
+        result_text_analysis = gr.Textbox(label="详细分析", interactive=False, lines=5)
+        result_bar_chart = gr.Image(label="特征差异柱状图")
+        result_pie_chart = gr.Image(label="特征差异饼图")
+
     # Analysis logic
     def process_analysis(img_a, img_b, api_key):
         results = analyze_images(img_a, img_b, api_key)
-        return results["caption_a"], results["caption_b"], results["similarity"], results["latent_diff"], results["text_analysis"]
-
+        # Gallery items as (image, caption) pairs
+        diff_images = [
+            (results["images_diff"]["original_a"], "Original A"),
+            (results["images_diff"]["original_b"], "Original B"),
+            (results["images_diff"]["sketch_a"], "Sketch A"),
+            (results["images_diff"]["sketch_b"], "Sketch B"),
+            (results["images_diff"]["normal_a"], "Normal A"),
+            (results["images_diff"]["normal_b"], "Normal B"),
+            (results["images_diff"]["diff_overlay"], "Difference Overlay"),
+        ]
+        return (
+            diff_images,
+            results["caption_a"],
+            results["caption_b"],
+            results["text_analysis"],
+            results["bar_chart"],
+            results["pie_chart"]
+        )
+
     analyze_button.click(
         fn=process_analysis,
         inputs=[image_a, image_b, api_key_input],
-        outputs=[result_caption_a, result_caption_b, result_similarity, result_latent_diff, result_text_analysis]
+        outputs=[
+            result_diff,
+            result_caption_a,
+            result_caption_b,
+            result_text_analysis,
+            result_bar_chart,
+            result_pie_chart
+        ]
     )
 
 demo.launch()
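
For a quick sanity check, the new helpers can be exercised without launching the Gradio UI or supplying a DeepSeek key. A minimal sketch, assuming the definitions from app.py above are in scope and that a.png and b.png are hypothetical local test images:

    from PIL import Image

    # Hypothetical local test inputs
    img_a = Image.open("a.png").convert("RGB")
    img_b = Image.open("b.png").convert("RGB")

    # Sketches, edge maps, and the pixel-wise difference overlay
    diff = compute_difference_images(img_a, img_b)
    diff["diff_overlay"].save("diff_overlay.png")

    # Beam-search BLIP caption for one image
    print(generate_detailed_caption(img_a))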