vsrinivas commited on
Commit
53dc0c8
·
verified ·
1 Parent(s): 773c328

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -0
app.py CHANGED
@@ -11,6 +11,7 @@ model = SentenceTransformer('sentence-transformers/clip-ViT-B-32')
11
  fashion = load_dataset("ashraq/fashion-product-images-small", split="train")
12
  images = fashion['image']
13
  metadata = fashion.remove_columns('image')
 
14
 
15
  INDEX_NAME = 'srinivas-hybrid-search'
16
  PINECONE_API_KEY = os.getenv(pinecone_api_key)
@@ -19,3 +20,100 @@ index = pinecone.Index(INDEX_NAME)
19
  bm25 = BM25Encoder()
20
  bm25.fit(metadata['productDisplayName'])
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  fashion = load_dataset("ashraq/fashion-product-images-small", split="train")
12
  images = fashion['image']
13
  metadata = fashion.remove_columns('image')
14
+ item_list = list(set(metadata['productDisplayName']))
15
 
16
  INDEX_NAME = 'srinivas-hybrid-search'
17
  PINECONE_API_KEY = os.getenv(pinecone_api_key)
 
20
  bm25 = BM25Encoder()
21
  bm25.fit(metadata['productDisplayName'])
22
 
23
+ # Function to display images in a grid layout
24
+ def display_result(image_batch, match_batch):
25
+ figures = []
26
+ for img, title in zip(image_batch, match_batch):
27
+ # Ensure the image is in the correct format for encoding
28
+ if img.mode != 'RGB':
29
+ img = img.convert('RGB')
30
+
31
+ # Convert image to bytes and encode as base64
32
+ b = BytesIO()
33
+ img.save(b, format='PNG')
34
+ img_str = b64encode(b.getvalue()).decode('utf-8')
35
+
36
+ # Create HTML figure element with the image title
37
+ figures.append(f'''
38
+ <figure style="margin: 0; padding: 0; text-align: left;">
39
+ <figcaption style="font-weight: bold; margin:0;">{title}</figcaption>
40
+ <img src="data:image/png;base64,{img_str}" style="width: 180px; height: 240px; margin: 0;" >
41
+ </figure>
42
+ ''')
43
+
44
+ # Combine all figures into a single HTML string with reduced spacing
45
+ html_content = f'''
46
+ <div style="display: grid; grid-template-columns: repeat(4, 1fr); gap: 20px; align-items: start;">
47
+ {''.join(figures)}
48
+ </div>
49
+ '''
50
+ return html_content
51
+
52
+
53
+ # Function to scale vectors based on alpha for hybrid search
54
+ def hybrid_scale(dense, sparse, alpha):
55
+ if alpha < 0 or alpha > 1:
56
+ raise ValueError("Alpha must be between 0 and 1")
57
+
58
+ # Scale sparse and dense vectors to create hybrid search vectors
59
+ hsparse = {
60
+ 'indices': sparse['indices'],
61
+ 'values': [v * (1 - alpha) for v in sparse['values']]
62
+ }
63
+ hdense = [v * alpha for v in dense]
64
+
65
+ return hdense, hsparse
66
+
67
+
68
+ # Function to process the input text and slider value, with error handling
69
+ def process_input(query, slider_value):
70
+ try:
71
+ slider_value = float(slider_value)
72
+ sparse = bm25.encode_queries(query)
73
+ dense = model.encode(query).tolist()
74
+ hdense, hsparse = hybrid_scale(dense, sparse, slider_value)
75
+
76
+ result = index.query(
77
+ top_k=12,
78
+ vector=hdense, # Use hybrid dense vector
79
+ sparse_vector=hsparse, # Use hybrid sparse vector
80
+ include_metadata=True
81
+ )
82
+
83
+ imgs = [images[int(r["id"])] for r in result["matches"]]
84
+ matches = [x["metadata"]['productDisplayName'] for x in result["matches"]]
85
+
86
+ print(f"No. of matching images: {len(imgs)}")
87
+ print(matches)
88
+ return display_result(imgs, matches)
89
+
90
+ except Exception as e:
91
+ # Handle exceptions and return a friendly error message
92
+ return f"<p style='color:red;'>Not found. Try another search: {str(e)}</p>"
93
+
94
+
95
+ # Function to update the textbox value when a dropdown choice is selected
96
+ def update_textbox(choice):
97
+ return choice
98
+
99
+
100
+ # Gradio interface
101
+ with gr.Blocks() as demo:
102
+ gr.Markdown("# Search for Your Fashion Item")
103
+
104
+ with gr.Row():
105
+ dropdown = gr.Dropdown(choices=item_list, label="Select an item from here..", value= "Select an item from this list or start typing", interactive=True)
106
+ text_input = gr.Textbox(label="Alternatively, enter item text..", value="Type-in what you are looking for", interactive=True)
107
+ slider = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5, label="Adjust the Slider to get better results that suit what you are looking for..", interactive=True)
108
+
109
+ # Automatically update the text input when a dropdown selection is made
110
+ dropdown.change(fn=update_textbox, inputs=dropdown, outputs=text_input)
111
+
112
+ # HTML output box to display images
113
+ html_output = gr.HTML(label="Relevant Images")
114
+
115
+ # Process and display images based on text input or slider changes
116
+ text_input.change(fn=process_input, inputs=[text_input, slider], outputs=html_output)
117
+ slider.change(fn=process_input, inputs=[text_input, slider], outputs=html_output)
118
+
119
+ demo.launch(debug=True, share=True)