leedoming commited on
Commit
4e2381b
·
verified ·
1 Parent(s): fcd89ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -98
app.py CHANGED
@@ -30,6 +30,44 @@ def load_data():
30
 
31
  data = load_data()
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  # Process database with segmentation
34
  @st.cache_data
35
  def process_database():
@@ -120,101 +158,4 @@ if api_key:
120
  st.write(f"Discount: {img['info']['discount']}%")
121
  st.write(f"Similarity: {img['similarity']:.2f}")
122
  else:
123
- st.warning("Please enter your Roboflow API Key to use the app.")
124
-
125
-
126
- # Process database with segmentation
127
- @st.cache_data
128
- def download_and_process_image(image_url):
129
- try:
130
- response = requests.get(image_url)
131
- response.raise_for_status() # Raises an HTTPError for bad responses
132
- image = Image.open(BytesIO(response.content))
133
-
134
- # Convert image to RGB mode if it's in RGBA mode
135
- if image.mode == 'RGBA':
136
- image = image.convert('RGB')
137
-
138
- return image
139
- except requests.RequestException as e:
140
- st.error(f"Error downloading image: {e}")
141
- return None
142
- except Exception as e:
143
- st.error(f"Error processing image: {e}")
144
- return None
145
-
146
- def process_database():
147
- database_embeddings = []
148
- database_info = []
149
- for item in data:
150
- image_url = item['이미지 링크'][0]
151
- product_id = item.get('\ufeff상품 ID') or item.get('상품 ID')
152
-
153
- image = download_and_process_image(image_url)
154
- if image is None:
155
- continue
156
-
157
- # Save the image temporarily
158
- temp_path = f"temp_{product_id}.jpg"
159
- image.save(temp_path, 'JPEG')
160
-
161
- segmented_image = segment_image(temp_path)
162
- embedding = get_image_embedding(segmented_image)
163
-
164
- database_embeddings.append(embedding)
165
- database_info.append({
166
- 'id': product_id,
167
- 'category': item['카테고리'],
168
- 'brand': item['브랜드명'],
169
- 'name': item['제품명'],
170
- 'price': item['정가'],
171
- 'discount': item['할인율'],
172
- 'image_url': image_url
173
- })
174
-
175
- return np.vstack(database_embeddings), database_info
176
-
177
- def find_similar_images(query_embedding, top_k=5):
178
- similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
179
- top_indices = np.argsort(similarities)[::-1][:top_k]
180
-
181
- results = []
182
- for idx in top_indices:
183
- results.append({
184
- 'info': database_info[idx],
185
- 'similarity': similarities[idx]
186
- })
187
-
188
- return results
189
-
190
- uploaded_file = st.file_uploader("Choose an image...", type="jpg")
191
- if uploaded_file is not None:
192
- image = Image.open(uploaded_file)
193
- st.image(image, caption='Uploaded Image', use_column_width=True)
194
-
195
- if st.button('Find Similar Items'):
196
- with st.spinner('Processing...'):
197
- # Save uploaded image temporarily
198
- temp_path = "temp_upload.jpg"
199
- image.save(temp_path)
200
-
201
- # Segment the uploaded image
202
- segmented_image = segment_image(temp_path)
203
- st.image(segmented_image, caption='Segmented Image', use_column_width=True)
204
-
205
- # Get embedding for segmented image
206
- query_embedding = get_image_embedding(segmented_image)
207
- similar_images = find_similar_images(query_embedding)
208
-
209
- st.subheader("Similar Items:")
210
- for img in similar_images:
211
- col1, col2 = st.columns(2)
212
- with col1:
213
- st.image(img['info']['image_url'], use_column_width=True)
214
- with col2:
215
- st.write(f"Name: {img['info']['name']}")
216
- st.write(f"Brand: {img['info']['brand']}")
217
- st.write(f"Category: {img['info']['category']}")
218
- st.write(f"Price: {img['info']['price']}")
219
- st.write(f"Discount: {img['info']['discount']}%")
220
- st.write(f"Similarity: {img['similarity']:.2f}")
 
30
 
31
  data = load_data()
32
 
33
+ # Helper functions
34
+ @st.cache_data
35
+ def download_and_process_image(image_url):
36
+ try:
37
+ response = requests.get(image_url)
38
+ response.raise_for_status() # Raises an HTTPError for bad responses
39
+ image = Image.open(BytesIO(response.content))
40
+
41
+ # Convert image to RGB mode if it's in RGBA mode
42
+ if image.mode == 'RGBA':
43
+ image = image.convert('RGB')
44
+
45
+ return image
46
+ except requests.RequestException as e:
47
+ st.error(f"Error downloading image: {e}")
48
+ return None
49
+ except Exception as e:
50
+ st.error(f"Error processing image: {e}")
51
+ return None
52
+
53
+ def segment_image(image_path):
54
+ # Implement your segmentation logic here
55
+ # For now, we'll just return the original image
56
+ return Image.open(image_path)
57
+
58
+ def get_image_embedding(image):
59
+ image_tensor = preprocess_val(image).unsqueeze(0).to(device)
60
+ with torch.no_grad():
61
+ image_features = model.encode_image(image_tensor)
62
+ image_features /= image_features.norm(dim=-1, keepdim=True)
63
+ return image_features.cpu().numpy()
64
+
65
+ def setup_roboflow_client(api_key):
66
+ return InferenceHTTPClient(
67
+ api_url="https://outline.roboflow.com",
68
+ api_key=api_key
69
+ )
70
+
71
  # Process database with segmentation
72
  @st.cache_data
73
  def process_database():
 
158
  st.write(f"Discount: {img['info']['discount']}%")
159
  st.write(f"Similarity: {img['similarity']:.2f}")
160
  else:
161
+ st.warning("Please enter your Roboflow API Key to use the app.")