leedoming committed · verified
Commit fcd89ff · 1 Parent(s): a124f15

Update app.py

Files changed (1):
  1. app.py +154 -132
app.py CHANGED
@@ -12,10 +12,6 @@ from inference_sdk import InferenceHTTPClient
  import matplotlib.pyplot as plt
  import base64

- # Define an exception class for error handling
- class APIError(Exception):
-     pass
-
  # Load model and tokenizer
  @st.cache_resource
  def load_model():
@@ -26,12 +22,62 @@ def load_model():

  model, preprocess_val, tokenizer, device = load_model()

- # Roboflow client setup function
- def setup_roboflow_client(api_key):
-     return InferenceHTTPClient(
-         api_url="https://outline.roboflow.com",
-         api_key=api_key
-     )
+ # Load and process data
+ @st.cache_data
+ def load_data():
+     with open('musinsa-final.json', 'r', encoding='utf-8') as f:
+         return json.load(f)
+
+ data = load_data()
+
+ # Process database with segmentation
+ @st.cache_data
+ def process_database():
+     database_embeddings = []
+     database_info = []
+     for item in data:
+         image_url = item['이미지 링크'][0]
+         product_id = item.get('\ufeff상품 ID') or item.get('상품 ID')
+
+         image = download_and_process_image(image_url)
+         if image is None:
+             continue
+
+         # Save the image temporarily
+         temp_path = f"temp_{product_id}.jpg"
+         image.save(temp_path, 'JPEG')
+
+         segmented_image = segment_image(temp_path)
+         embedding = get_image_embedding(segmented_image)
+
+         database_embeddings.append(embedding)
+         database_info.append({
+             'id': product_id,
+             'category': item['카테고리'],
+             'brand': item['브랜드명'],
+             'name': item['제품명'],
+             'price': item['정가'],
+             'discount': item['할인율'],
+             'image_url': image_url
+         })
+
+     return np.vstack(database_embeddings), database_info
+
+ # Initialize database_embeddings and database_info
+ database_embeddings, database_info = process_database()
+
+ def find_similar_images(query_embedding, top_k=5):
+     similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
+     top_indices = np.argsort(similarities)[::-1][:top_k]
+
+     results = []
+     for idx in top_indices:
+         results.append({
+             'info': database_info[idx],
+             'similarity': similarities[idx]
+         })
+
+     return results

  # Streamlit app
  st.title("Fashion Search App with Segmentation")
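Note on the retrieval logic added in this hunk: get_image_embedding (its pre-commit definition is visible among the deletions in the next hunk, and it is presumably still defined in the updated file) L2-normalizes the CLIP image features, so the np.dot in find_similar_images computes cosine similarity between the query and every database embedding, and np.argsort(...)[::-1][:top_k] keeps the best matches first. A minimal, self-contained sketch of that ranking step with toy data (the array names and sizes here are illustrative, not taken from app.py):

import numpy as np

# Toy stand-ins: 5 database embeddings and 1 query embedding, all L2-normalized,
# mirroring the shapes produced by get_image_embedding and np.vstack in the app.
rng = np.random.default_rng(0)
db = rng.normal(size=(5, 4))
db /= np.linalg.norm(db, axis=1, keepdims=True)
query = rng.normal(size=(1, 4))
query /= np.linalg.norm(query, axis=1, keepdims=True)

similarities = np.dot(db, query.T).squeeze()      # cosine similarity per item
top_indices = np.argsort(similarities)[::-1][:3]  # highest-scoring items first
print(top_indices, similarities[top_indices])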
@@ -42,127 +88,6 @@ api_key = st.text_input("Enter your Roboflow API Key", type="password")
  if api_key:
      CLIENT = setup_roboflow_client(api_key)

-     def segment_image(image_path):
-         try:
-             # Read the image file
-             with open(image_path, "rb") as image_file:
-                 image_data = image_file.read()
-
-             # Encode the image as base64
-             encoded_image = base64.b64encode(image_data).decode('utf-8')
-
-             # Load the original image
-             image = cv2.imread(image_path)
-             image = cv2.resize(image, (800, 600))
-             mask = np.zeros(image.shape, dtype=np.uint8)
-
-             try:
-                 # Call the Roboflow API
-                 results = CLIENT.infer(encoded_image, model_id="closet/1")
-             except Exception as api_error:
-                 st.error(f"API Error: {str(api_error)}")
-                 return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
-
-             if 'predictions' in results:
-                 for prediction in results['predictions']:
-                     points = prediction['points']
-                     pts = np.array([[p['x'], p['y']] for p in points], np.int32)
-                     scale_x = image.shape[1] / results['image']['width']
-                     scale_y = image.shape[0] / results['image']['height']
-                     pts = pts * [scale_x, scale_y]
-                     pts = pts.astype(np.int32)
-                     pts = pts.reshape((-1, 1, 2))
-                     cv2.fillPoly(mask, [pts], color=(255, 255, 255))  # White mask
-
-                 segmented_image = cv2.bitwise_and(image, mask)
-             else:
-                 st.warning("No predictions found in the image. Returning original image.")
-                 segmented_image = image
-
-             return Image.fromarray(cv2.cvtColor(segmented_image, cv2.COLOR_BGR2RGB))
-         except Exception as e:
-             st.error(f"Error in segmentation: {str(e)}")
-             # Re-read and return the original image
-             return Image.open(image_path)
-     def get_image_embedding(image):
-         image_tensor = preprocess_val(image).unsqueeze(0).to(device)
-         with torch.no_grad():
-             image_features = model.encode_image(image_tensor)
-             image_features /= image_features.norm(dim=-1, keepdim=True)
-         return image_features.cpu().numpy()
-
-     # Load and process data
-     @st.cache_data
-     def load_data():
-         with open('musinsa-final.json', 'r', encoding='utf-8') as f:
-             return json.load(f)
-
-     data = load_data()
-
-     # Process database with segmentation
-     @st.cache_data
-     def download_and_process_image(image_url):
-         try:
-             response = requests.get(image_url)
-             response.raise_for_status()  # Raises an HTTPError for bad responses
-             image = Image.open(BytesIO(response.content))
-
-             # Convert image to RGB mode if it's in RGBA mode
-             if image.mode == 'RGBA':
-                 image = image.convert('RGB')
-
-             return image
-         except requests.RequestException as e:
-             st.error(f"Error downloading image: {e}")
-             return None
-         except Exception as e:
-             st.error(f"Error processing image: {e}")
-             return None
-
-     def process_database():
-         database_embeddings = []
-         database_info = []
-         for item in data:
-             image_url = item['이미지 링크'][0]
-             product_id = item.get('\ufeff상품 ID') or item.get('상품 ID')
-
-             image = download_and_process_image(image_url)
-             if image is None:
-                 continue
-
-             # Save the image temporarily
-             temp_path = f"temp_{product_id}.jpg"
-             image.save(temp_path, 'JPEG')
-
-             segmented_image = segment_image(temp_path)
-             embedding = get_image_embedding(segmented_image)
-
-             database_embeddings.append(embedding)
-             database_info.append({
-                 'id': product_id,
-                 'category': item['카테고리'],
-                 'brand': item['브랜드명'],
-                 'name': item['제품명'],
-                 'price': item['정가'],
-                 'discount': item['할인율'],
-                 'image_url': image_url
-             })
-
-         return np.vstack(database_embeddings), database_info
-
-     def find_similar_images(query_embedding, top_k=5):
-         similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
-         top_indices = np.argsort(similarities)[::-1][:top_k]
-
-         results = []
-         for idx in top_indices:
-             results.append({
-                 'info': database_info[idx],
-                 'similarity': similarities[idx]
-             })
-
-         return results
-
      uploaded_file = st.file_uploader("Choose an image...", type="jpg")
      if uploaded_file is not None:
          image = Image.open(uploaded_file)
@@ -195,4 +120,101 @@ if api_key:
                          st.write(f"Discount: {img['info']['discount']}%")
                          st.write(f"Similarity: {img['similarity']:.2f}")
  else:
-     st.warning("Please enter your Roboflow API Key to use the app.")
+     st.warning("Please enter your Roboflow API Key to use the app.")
+
+
+ # Process database with segmentation
+ @st.cache_data
+ def download_and_process_image(image_url):
+     try:
+         response = requests.get(image_url)
+         response.raise_for_status()  # Raises an HTTPError for bad responses
+         image = Image.open(BytesIO(response.content))
+
+         # Convert image to RGB mode if it's in RGBA mode
+         if image.mode == 'RGBA':
+             image = image.convert('RGB')
+
+         return image
+     except requests.RequestException as e:
+         st.error(f"Error downloading image: {e}")
+         return None
+     except Exception as e:
+         st.error(f"Error processing image: {e}")
+         return None
+
+ def process_database():
+     database_embeddings = []
+     database_info = []
+     for item in data:
+         image_url = item['이미지 링크'][0]
+         product_id = item.get('\ufeff상품 ID') or item.get('상품 ID')
+
+         image = download_and_process_image(image_url)
+         if image is None:
+             continue
+
+         # Save the image temporarily
+         temp_path = f"temp_{product_id}.jpg"
+         image.save(temp_path, 'JPEG')
+
+         segmented_image = segment_image(temp_path)
+         embedding = get_image_embedding(segmented_image)
+
+         database_embeddings.append(embedding)
+         database_info.append({
+             'id': product_id,
+             'category': item['카테고리'],
+             'brand': item['브랜드명'],
+             'name': item['제품명'],
+             'price': item['정가'],
+             'discount': item['할인율'],
+             'image_url': image_url
+         })
+
+     return np.vstack(database_embeddings), database_info
+
+ def find_similar_images(query_embedding, top_k=5):
+     similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
+     top_indices = np.argsort(similarities)[::-1][:top_k]
+
+     results = []
+     for idx in top_indices:
+         results.append({
+             'info': database_info[idx],
+             'similarity': similarities[idx]
+         })
+
+     return results
+
+ uploaded_file = st.file_uploader("Choose an image...", type="jpg")
+ if uploaded_file is not None:
+     image = Image.open(uploaded_file)
+     st.image(image, caption='Uploaded Image', use_column_width=True)
+
+     if st.button('Find Similar Items'):
+         with st.spinner('Processing...'):
+             # Save uploaded image temporarily
+             temp_path = "temp_upload.jpg"
+             image.save(temp_path)
+
+             # Segment the uploaded image
+             segmented_image = segment_image(temp_path)
+             st.image(segmented_image, caption='Segmented Image', use_column_width=True)
+
+             # Get embedding for segmented image
+             query_embedding = get_image_embedding(segmented_image)
+             similar_images = find_similar_images(query_embedding)
+
+             st.subheader("Similar Items:")
+             for img in similar_images:
+                 col1, col2 = st.columns(2)
+                 with col1:
+                     st.image(img['info']['image_url'], use_column_width=True)
+                 with col2:
+                     st.write(f"Name: {img['info']['name']}")
+                     st.write(f"Brand: {img['info']['brand']}")
+                     st.write(f"Category: {img['info']['category']}")
+                     st.write(f"Price: {img['info']['price']}")
+                     st.write(f"Discount: {img['info']['discount']}%")
+                     st.write(f"Similarity: {img['similarity']:.2f}")
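For context when reading the updated file: the added code still calls setup_roboflow_client, segment_image, and get_image_embedding, but those helpers only appear in this diff as deletions, so they are presumably still defined elsewhere in the updated app.py (in lines the diff does not show). The sketch below reconstructs them from the deleted lines above; it is an illustration of what those helpers did before the commit, not the exact code now in the file. CLIENT, model, preprocess_val, and device are the globals created earlier in app.py by setup_roboflow_client and load_model.

import base64
import cv2
import numpy as np
import torch
from PIL import Image
from inference_sdk import InferenceHTTPClient

def setup_roboflow_client(api_key):
    # Hosted-inference client for the Roboflow "outline" endpoint, as in the deleted helper.
    return InferenceHTTPClient(api_url="https://outline.roboflow.com", api_key=api_key)

def segment_image(image_path):
    # Run the "closet/1" segmentation model and keep only the predicted clothing regions.
    with open(image_path, "rb") as image_file:
        encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
    image = cv2.imread(image_path)
    image = cv2.resize(image, (800, 600))
    mask = np.zeros(image.shape, dtype=np.uint8)
    results = CLIENT.infer(encoded_image, model_id="closet/1")  # CLIENT: global from setup_roboflow_client
    for prediction in results.get('predictions', []):
        pts = np.array([[p['x'], p['y']] for p in prediction['points']], np.float64)
        pts[:, 0] *= image.shape[1] / results['image']['width']   # rescale polygon to the resized image
        pts[:, 1] *= image.shape[0] / results['image']['height']
        cv2.fillPoly(mask, [pts.astype(np.int32).reshape((-1, 1, 2))], color=(255, 255, 255))
    # Fall back to the original image when no clothing polygons were predicted.
    segmented = cv2.bitwise_and(image, mask) if results.get('predictions') else image
    return Image.fromarray(cv2.cvtColor(segmented, cv2.COLOR_BGR2RGB))

def get_image_embedding(image):
    # L2-normalized CLIP image features from the model cached by load_model().
    image_tensor = preprocess_val(image).unsqueeze(0).to(device)  # preprocess_val, device: globals from load_model()
    with torch.no_grad():
        image_features = model.encode_image(image_tensor)         # model: global from load_model()
        image_features /= image_features.norm(dim=-1, keepdim=True)
    return image_features.cpu().numpy()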