Update app.py
Browse files
app.py
CHANGED
@@ -11,6 +11,7 @@ model = SentenceTransformer('sentence-transformers/clip-ViT-B-32')
|
|
11 |
fashion = load_dataset("ashraq/fashion-product-images-small", split="train")
|
12 |
images = fashion['image']
|
13 |
metadata = fashion.remove_columns('image')
|
|
|
14 |
|
15 |
INDEX_NAME = 'srinivas-hybrid-search'
|
16 |
PINECONE_API_KEY = os.getenv(pinecone_api_key)
|
@@ -19,3 +20,100 @@ index = pinecone.Index(INDEX_NAME)
|
|
19 |
bm25 = BM25Encoder()
|
20 |
bm25.fit(metadata['productDisplayName'])
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
fashion = load_dataset("ashraq/fashion-product-images-small", split="train")
|
12 |
images = fashion['image']
|
13 |
metadata = fashion.remove_columns('image')
|
14 |
+
item_list = list(set(metadata['productDisplayName']))
|
15 |
|
16 |
INDEX_NAME = 'srinivas-hybrid-search'
|
17 |
PINECONE_API_KEY = os.getenv(pinecone_api_key)
|
|
|
20 |
bm25 = BM25Encoder()
|
21 |
bm25.fit(metadata['productDisplayName'])
|
22 |
|
23 |
+
# Function to display images in a grid layout
|
24 |
+
def display_result(image_batch, match_batch):
|
25 |
+
figures = []
|
26 |
+
for img, title in zip(image_batch, match_batch):
|
27 |
+
# Ensure the image is in the correct format for encoding
|
28 |
+
if img.mode != 'RGB':
|
29 |
+
img = img.convert('RGB')
|
30 |
+
|
31 |
+
# Convert image to bytes and encode as base64
|
32 |
+
b = BytesIO()
|
33 |
+
img.save(b, format='PNG')
|
34 |
+
img_str = b64encode(b.getvalue()).decode('utf-8')
|
35 |
+
|
36 |
+
# Create HTML figure element with the image title
|
37 |
+
figures.append(f'''
|
38 |
+
<figure style="margin: 0; padding: 0; text-align: left;">
|
39 |
+
<figcaption style="font-weight: bold; margin:0;">{title}</figcaption>
|
40 |
+
<img src="data:image/png;base64,{img_str}" style="width: 180px; height: 240px; margin: 0;" >
|
41 |
+
</figure>
|
42 |
+
''')
|
43 |
+
|
44 |
+
# Combine all figures into a single HTML string with reduced spacing
|
45 |
+
html_content = f'''
|
46 |
+
<div style="display: grid; grid-template-columns: repeat(4, 1fr); gap: 20px; align-items: start;">
|
47 |
+
{''.join(figures)}
|
48 |
+
</div>
|
49 |
+
'''
|
50 |
+
return html_content
|
51 |
+
|
52 |
+
|
53 |
+
# Function to scale vectors based on alpha for hybrid search
|
54 |
+
def hybrid_scale(dense, sparse, alpha):
|
55 |
+
if alpha < 0 or alpha > 1:
|
56 |
+
raise ValueError("Alpha must be between 0 and 1")
|
57 |
+
|
58 |
+
# Scale sparse and dense vectors to create hybrid search vectors
|
59 |
+
hsparse = {
|
60 |
+
'indices': sparse['indices'],
|
61 |
+
'values': [v * (1 - alpha) for v in sparse['values']]
|
62 |
+
}
|
63 |
+
hdense = [v * alpha for v in dense]
|
64 |
+
|
65 |
+
return hdense, hsparse
|
66 |
+
|
67 |
+
|
68 |
+
# Function to process the input text and slider value, with error handling
|
69 |
+
def process_input(query, slider_value):
|
70 |
+
try:
|
71 |
+
slider_value = float(slider_value)
|
72 |
+
sparse = bm25.encode_queries(query)
|
73 |
+
dense = model.encode(query).tolist()
|
74 |
+
hdense, hsparse = hybrid_scale(dense, sparse, slider_value)
|
75 |
+
|
76 |
+
result = index.query(
|
77 |
+
top_k=12,
|
78 |
+
vector=hdense, # Use hybrid dense vector
|
79 |
+
sparse_vector=hsparse, # Use hybrid sparse vector
|
80 |
+
include_metadata=True
|
81 |
+
)
|
82 |
+
|
83 |
+
imgs = [images[int(r["id"])] for r in result["matches"]]
|
84 |
+
matches = [x["metadata"]['productDisplayName'] for x in result["matches"]]
|
85 |
+
|
86 |
+
print(f"No. of matching images: {len(imgs)}")
|
87 |
+
print(matches)
|
88 |
+
return display_result(imgs, matches)
|
89 |
+
|
90 |
+
except Exception as e:
|
91 |
+
# Handle exceptions and return a friendly error message
|
92 |
+
return f"<p style='color:red;'>Not found. Try another search: {str(e)}</p>"
|
93 |
+
|
94 |
+
|
95 |
+
# Function to update the textbox value when a dropdown choice is selected
|
96 |
+
def update_textbox(choice):
|
97 |
+
return choice
|
98 |
+
|
99 |
+
|
100 |
+
# Gradio interface
|
101 |
+
with gr.Blocks() as demo:
|
102 |
+
gr.Markdown("# Search for Your Fashion Item")
|
103 |
+
|
104 |
+
with gr.Row():
|
105 |
+
dropdown = gr.Dropdown(choices=item_list, label="Select an item from here..", value= "Select an item from this list or start typing", interactive=True)
|
106 |
+
text_input = gr.Textbox(label="Alternatively, enter item text..", value="Type-in what you are looking for", interactive=True)
|
107 |
+
slider = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5, label="Adjust the Slider to get better results that suit what you are looking for..", interactive=True)
|
108 |
+
|
109 |
+
# Automatically update the text input when a dropdown selection is made
|
110 |
+
dropdown.change(fn=update_textbox, inputs=dropdown, outputs=text_input)
|
111 |
+
|
112 |
+
# HTML output box to display images
|
113 |
+
html_output = gr.HTML(label="Relevant Images")
|
114 |
+
|
115 |
+
# Process and display images based on text input or slider changes
|
116 |
+
text_input.change(fn=process_input, inputs=[text_input, slider], outputs=html_output)
|
117 |
+
slider.change(fn=process_input, inputs=[text_input, slider], outputs=html_output)
|
118 |
+
|
119 |
+
demo.launch(debug=True, share=True)
|