leedoming commited on
Commit
dfc8c96
·
verified ·
1 Parent(s): 4e2381b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -24
app.py CHANGED
@@ -50,11 +50,6 @@ def download_and_process_image(image_url):
50
  st.error(f"Error processing image: {e}")
51
  return None
52
 
53
- def segment_image(image_path):
54
- # Implement your segmentation logic here
55
- # For now, we'll just return the original image
56
- return Image.open(image_path)
57
-
58
  def get_image_embedding(image):
59
  image_tensor = preprocess_val(image).unsqueeze(0).to(device)
60
  with torch.no_grad():
@@ -68,9 +63,49 @@ def setup_roboflow_client(api_key):
68
  api_key=api_key
69
  )
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  # Process database with segmentation
72
  @st.cache_data
73
- def process_database():
74
  database_embeddings = []
75
  database_info = []
76
  for item in data:
@@ -85,7 +120,7 @@ def process_database():
85
  temp_path = f"temp_{product_id}.jpg"
86
  image.save(temp_path, 'JPEG')
87
 
88
- segmented_image = segment_image(temp_path)
89
  embedding = get_image_embedding(segmented_image)
90
 
91
  database_embeddings.append(embedding)
@@ -101,22 +136,6 @@ def process_database():
101
 
102
  return np.vstack(database_embeddings), database_info
103
 
104
- # Initialize database_embeddings and database_info
105
- database_embeddings, database_info = process_database()
106
-
107
- def find_similar_images(query_embedding, top_k=5):
108
- similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
109
- top_indices = np.argsort(similarities)[::-1][:top_k]
110
-
111
- results = []
112
- for idx in top_indices:
113
- results.append({
114
- 'info': database_info[idx],
115
- 'similarity': similarities[idx]
116
- })
117
-
118
- return results
119
-
120
  # Streamlit app
121
  st.title("Fashion Search App with Segmentation")
122
 
@@ -125,6 +144,9 @@ api_key = st.text_input("Enter your Roboflow API Key", type="password")
125
 
126
  if api_key:
127
  CLIENT = setup_roboflow_client(api_key)
 
 
 
128
 
129
  uploaded_file = st.file_uploader("Choose an image...", type="jpg")
130
  if uploaded_file is not None:
@@ -138,7 +160,7 @@ if api_key:
138
  image.save(temp_path)
139
 
140
  # Segment the uploaded image
141
- segmented_image = segment_image(temp_path)
142
  st.image(segmented_image, caption='Segmented Image', use_column_width=True)
143
 
144
  # Get embedding for segmented image
 
50
  st.error(f"Error processing image: {e}")
51
  return None
52
 
 
 
 
 
 
53
  def get_image_embedding(image):
54
  image_tensor = preprocess_val(image).unsqueeze(0).to(device)
55
  with torch.no_grad():
 
63
  api_key=api_key
64
  )
65
 
66
+ def segment_image(image_path, client):
67
+ try:
68
+ # 이미지 파일 읽기
69
+ with open(image_path, "rb") as image_file:
70
+ image_data = image_file.read()
71
+
72
+ # 이미지를 base64로 인코딩
73
+ encoded_image = base64.b64encode(image_data).decode('utf-8')
74
+
75
+ # 원본 이미지 로드
76
+ image = cv2.imread(image_path)
77
+ image = cv2.resize(image, (800, 600))
78
+ mask = np.zeros(image.shape, dtype=np.uint8)
79
+
80
+ # Roboflow API 호출
81
+ results = client.infer(encoded_image, model_id="closet/1")
82
+ results = json.loads(results)
83
+
84
+ if 'predictions' in results:
85
+ for prediction in results['predictions']:
86
+ points = prediction['points']
87
+ pts = np.array([[p['x'], p['y']] for p in points], np.int32)
88
+ scale_x = image.shape[1] / results['image']['width']
89
+ scale_y = image.shape[0] / results['image']['height']
90
+ pts = pts * [scale_x, scale_y]
91
+ pts = pts.astype(np.int32)
92
+ pts = pts.reshape((-1, 1, 2))
93
+ cv2.fillPoly(mask, [pts], color=(255, 255, 255)) # White mask
94
+
95
+ segmented_image = cv2.bitwise_and(image, mask)
96
+ else:
97
+ st.warning("No predictions found in the image. Returning original image.")
98
+ segmented_image = image
99
+
100
+ return Image.fromarray(cv2.cvtColor(segmented_image, cv2.COLOR_BGR2RGB))
101
+ except Exception as e:
102
+ st.error(f"Error in segmentation: {str(e)}")
103
+ # 원본 이미지를 다시 읽어 반환
104
+ return Image.open(image_path)
105
+
106
  # Process database with segmentation
107
  @st.cache_data
108
+ def process_database(client):
109
  database_embeddings = []
110
  database_info = []
111
  for item in data:
 
120
  temp_path = f"temp_{product_id}.jpg"
121
  image.save(temp_path, 'JPEG')
122
 
123
+ segmented_image = segment_image(temp_path, client)
124
  embedding = get_image_embedding(segmented_image)
125
 
126
  database_embeddings.append(embedding)
 
136
 
137
  return np.vstack(database_embeddings), database_info
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  # Streamlit app
140
  st.title("Fashion Search App with Segmentation")
141
 
 
144
 
145
  if api_key:
146
  CLIENT = setup_roboflow_client(api_key)
147
+
148
+ # Initialize database_embeddings and database_info
149
+ database_embeddings, database_info = process_database(CLIENT)
150
 
151
  uploaded_file = st.file_uploader("Choose an image...", type="jpg")
152
  if uploaded_file is not None:
 
160
  image.save(temp_path)
161
 
162
  # Segment the uploaded image
163
+ segmented_image = segment_image(temp_path, CLIENT)
164
  st.image(segmented_image, caption='Segmented Image', use_column_width=True)
165
 
166
  # Get embedding for segmented image