ljlmike commited on
Commit
9d09d3e
·
1 Parent(s): c3ca6f6

enable user saving preference to google sheets

Browse files
Files changed (1) hide show
  1. app.py +143 -94
app.py CHANGED
@@ -1,133 +1,182 @@
1
- import os
2
- import requests
3
- import gradio as gr
4
  from typing import List, Dict, Any
5
 
6
- """
7
- Lean Finder Retrieval Demo (Gradio Space)
8
- ----------------------------------------
9
- Secrets to add in **Settings ▸ Environment variables & secrets**
10
- ENDPOINT_URL – full HTTPS URL of your Hugging Face *Inference Endpoint*
11
- HF_TOKEN – (optional) Bearer token that can invoke the endpoint
12
- """
13
 
14
  ENDPOINT_ID = os.getenv("ENDPOINT_ID")
15
- HF_TOKEN = os.getenv("HF_TOKEN")
16
-
17
  if not ENDPOINT_ID:
18
  raise ValueError("ENDPOINT_ID is not set")
19
 
20
-
21
  def _call_endpoint(payload: Dict[str, Any]) -> Dict[str, Any]:
22
- """Call our private inference endpoint"""
23
  headers = {"Accept": "application/json", "Content-Type": "application/json"}
24
  if HF_TOKEN:
25
  headers["Authorization"] = f"Bearer {HF_TOKEN}"
26
-
27
- resp = requests.post(ENDPOINT_ID, json=payload, headers=headers, timeout=60)
28
- resp.raise_for_status()
29
- return resp.json()
30
-
31
 
32
  def _call_leansearch(query: str, k: int) -> List[Dict[str, Any]]:
33
- """Call LeanSearch public endpoint"""
34
  payload = {"query": [query], "num_results": str(k)}
35
- resp = requests.post("https://leansearch.net/search", json=payload, timeout=60)
36
- resp.raise_for_status()
37
- data = resp.json()
38
  return data[0] if isinstance(data, list) and data else []
39
 
40
-
41
- def _render_leanfinder(results: List[Dict[str, Any]]) -> str:
42
- if not results:
43
- return "<p>No results from Lean Finder.</p>"
44
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  rows = "\n".join(
46
  f"<tr><td>{i}</td><td>{r['score']:.4f}</td>"
47
  f"<td><code style='white-space:pre-wrap'>{r['formal_statement']}</code></td>"
48
  f"<td style='white-space:pre-wrap'>{r['informal_statement']}</td></tr>"
49
- for i, r in enumerate(results, 1)
50
  )
51
  return (
52
- "<h3>LeanFinder</h3>"
53
- "<table><thead><tr><th>Rank</th><th>Score</th><th>Formal statement</th><th>Informal statement</th></tr></thead>"
 
54
  f"<tbody>{rows}</tbody></table>"
55
  )
56
 
57
-
58
- def _render_leansearch(results: List[Dict[str, Any]]) -> str:
59
- if not results:
60
  return "<p>No results from LeanSearch.</p>"
61
-
62
- def row(idx: int, entry: Dict[str, Any]) -> str:
63
- res = entry.get("result", {})
64
- kind = res.get("kind", "").strip()
65
- name = ".".join(res.get("name", []))
66
- sig = res.get("signature", "")
67
- val = res.get("value", "").lstrip()
68
- formal = f"{kind} {name}{sig} {val}".strip()
69
- informal = res.get("informal_description", "")
70
- dist = entry.get("distance", 0.0)
71
- return (
72
- f"<tr><td>{idx}</td><td>{dist:.4f}</td>"
73
- f"<td><code style='white-space:pre-wrap'>{formal}</code></td>"
74
- f"<td style='white-space:pre-wrap'>{informal}</td></tr>"
75
- )
76
-
77
- rows = "\n".join(row(i, e) for i, e in enumerate(results, 1))
78
  return (
79
  "<h3>LeanSearch</h3>"
80
- "<table><thead><tr><th>Rank</th><th>Distance</th><th>Formal statement</th><th>Informal statement</th></tr></thead>"
 
81
  f"<tbody>{rows}</tbody></table>"
82
  )
83
 
84
-
85
- def retrieve(query: str, top_k: int, mode: str):
86
- query = query.strip()
87
- if not query:
88
- return "<p>Please enter a query.</p>"
89
-
90
- payload = {"inputs": query, "top_k": top_k}
91
- try:
92
- lf_json = _call_endpoint(payload)
93
- lf_html = _render_leanfinder(lf_json.get("results", []))
94
- except Exception as e:
95
- lf_html = f"<p>Error contacting Lean Finder endpoint: {e}</p>"
96
-
97
- if mode == "Normal":
98
- return lf_html
99
-
100
- try:
101
- ls_html = _render_leansearch(_call_leansearch(query, top_k))
102
- except Exception as e:
103
- ls_html = f"<p>Error contacting LeanSearch: {e}</p>"
104
-
105
- return (
106
- "<div style='display:flex; gap:0.5rem; width:100%; box-sizing:border-box;'>"
107
- f"<div style='flex:1 1 0;'>{lf_html}</div>"
108
- f"<div style='flex:1 1 0;'>{ls_html}</div>"
109
- "</div>"
110
- )
111
-
112
-
113
  CUSTOM_CSS = """
114
  html,body{margin:0;padding:0;width:100%;}
115
- .gradio-container,.gradio-container .block{max-width:none!important;width:100%!important;margin:0;padding:0 0.5rem;}
116
- table{width:100%;border-collapse:collapse;font-size:0.9rem;}
117
- th,td{border:1px solid #ddd;padding:6px;vertical-align:top;}
118
- th{background:#f5f5f5;font-weight:600;}
119
- code{background:#f8f8f8;border:1px solid #eee;border-radius:4px;padding:2px 4px;}
120
  """
121
 
122
- with gr.Blocks(title="LeanFinder Retrieval", css=CUSTOM_CSS) as demo:
123
- gr.Markdown("""# LeanFinder – Retrieval Demo \nChoose **Normal** for standard retrieval or **Arena** to compare side‑by‑side with LeanSearch.""")
 
124
  with gr.Row():
125
- query_box = gr.Textbox(label="Query", lines=2, placeholder="Type your Lean query here …")
126
- topk_slider = gr.Slider(label="Topk", minimum=1, maximum=20, step=1, value=3)
127
- mode_selector = gr.Radio(["Normal", "Arena"], value="Normal", label="Mode")
128
- run_button = gr.Button("Retrieve")
129
- results_html = gr.HTML()
130
- run_button.click(retrieve, inputs=[query_box, topk_slider, mode_selector], outputs=results_html)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  if __name__ == "__main__":
133
  demo.launch()
 
1
+ import os, json, datetime, threading, requests
 
 
2
  from typing import List, Dict, Any
3
 
4
+ import gradio as gr
5
+ import gspread
6
+ from google.oauth2.service_account import Credentials
 
 
 
 
7
 
8
  ENDPOINT_ID = os.getenv("ENDPOINT_ID")
9
+ HF_TOKEN = os.getenv("HF_TOKEN")
 
10
  if not ENDPOINT_ID:
11
  raise ValueError("ENDPOINT_ID is not set")
12
 
 
13
  def _call_endpoint(payload: Dict[str, Any]) -> Dict[str, Any]:
 
14
  headers = {"Accept": "application/json", "Content-Type": "application/json"}
15
  if HF_TOKEN:
16
  headers["Authorization"] = f"Bearer {HF_TOKEN}"
17
+ r = requests.post(ENDPOINT_ID, json=payload, headers=headers, timeout=60)
18
+ r.raise_for_status()
19
+ return r.json()
 
 
20
 
21
  def _call_leansearch(query: str, k: int) -> List[Dict[str, Any]]:
 
22
  payload = {"query": [query], "num_results": str(k)}
23
+ r = requests.post("https://leansearch.net/search", json=payload, timeout=60)
24
+ r.raise_for_status()
25
+ data = r.json()
26
  return data[0] if isinstance(data, list) and data else []
27
 
28
+ # Google Sheets setup
29
+ SERVICE_ACCOUNT_INFO = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
30
+ if not SERVICE_ACCOUNT_INFO:
31
+ raise ValueError("Missing GCP_SERVICE_ACCOUNT_JSON env var")
32
+
33
+ SCOPES = [
34
+ "https://www.googleapis.com/auth/spreadsheets",
35
+ "https://www.googleapis.com/auth/drive",
36
+ ]
37
+
38
+ credentials = Credentials.from_service_account_info(
39
+ json.loads(SERVICE_ACCOUNT_INFO), scopes=SCOPES
40
+ )
41
+ gc = gspread.authorize(credentials)
42
+ worksheet = gc.open("arena_votes").sheet1
43
+ SHEET_LOCK = threading.Lock()
44
+
45
+ def _save_vote(choice: str,
46
+ query: str,
47
+ lf_json: Dict[str, Any],
48
+ ls_json: List[Dict[str, Any]]) -> str:
49
+ """Append one row: timestamp | query | choice | results_json"""
50
+ if not choice:
51
+ return "⚠️ Please pick a system before submitting."
52
+ payload = {
53
+ "lean_finder": lf_json,
54
+ "lean_search": ls_json,
55
+ }
56
+ row = [
57
+ datetime.datetime.utcnow().isoformat(timespec="seconds"),
58
+ query,
59
+ choice,
60
+ json.dumps(payload, ensure_ascii=False),
61
+ ]
62
+ with SHEET_LOCK:
63
+ worksheet.append_row(row, value_input_option="RAW")
64
+ return "✅ Vote recorded — thanks!"
65
+
66
+ # Rendering contents
67
+ def _render_leanfinder(res: List[Dict[str, Any]]) -> str:
68
+ if not res:
69
+ return "<p>No results from Lean Finder.</p>"
70
  rows = "\n".join(
71
  f"<tr><td>{i}</td><td>{r['score']:.4f}</td>"
72
  f"<td><code style='white-space:pre-wrap'>{r['formal_statement']}</code></td>"
73
  f"<td style='white-space:pre-wrap'>{r['informal_statement']}</td></tr>"
74
+ for i, r in enumerate(res, 1)
75
  )
76
  return (
77
+ "<h3>Lean Finder</h3>"
78
+ "<table><thead><tr><th>Rank</th><th>Score</th>"
79
+ "<th>Formal statement</th><th>Informal statement</th></tr></thead>"
80
  f"<tbody>{rows}</tbody></table>"
81
  )
82
 
83
+ def _render_leansearch(res: List[Dict[str, Any]]) -> str:
84
+ if not res:
 
85
  return "<p>No results from LeanSearch.</p>"
86
+ def row(i: int, e: Dict[str, Any]) -> str:
87
+ r = e.get("result", {})
88
+ kind, name = r.get("kind","").strip(), ".".join(r.get("name", []))
89
+ sig, val = r.get("signature", ""), r.get("value","").lstrip()
90
+ dist = e.get("distance", 0.0)
91
+ formal = f"{kind} {name}{sig} {val}".strip()
92
+ return (f"<tr><td>{i}</td><td>{dist:.4f}</td>"
93
+ f"<td><code style='white-space:pre-wrap'>{formal}</code></td>"
94
+ f"<td style='white-space:pre-wrap'>{r.get('informal_description','')}</td></tr>")
95
+ rows = "\n".join(row(i,e) for i,e in enumerate(res,1))
 
 
 
 
 
 
 
96
  return (
97
  "<h3>LeanSearch</h3>"
98
+ "<table><thead><tr><th>Rank</th><th>Distance</th>"
99
+ "<th>Formal statement</th><th>Informal statement</th></tr></thead>"
100
  f"<tbody>{rows}</tbody></table>"
101
  )
102
 
103
+ # Gradio app
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  CUSTOM_CSS = """
105
  html,body{margin:0;padding:0;width:100%;}
106
+ .gradio-container,.gradio-container .block{max-width:none!important;width:100%!important;padding:0 0.5rem;}
107
+ table{width:100%;border-collapse:collapse;font-size:0.9rem;}
108
+ th,td{border:1px solid #ddd;padding:6px;vertical-align:top;}
109
+ th{background:#f5f5f5;font-weight:600;}
110
+ code{background:#f8f8f8;border:1px solid #eee;border-radius:4px;padding:2px 4px;}
111
  """
112
 
113
+ with gr.Blocks(title="Lean Finder Retrieval", css=CUSTOM_CSS) as demo:
114
+ gr.Markdown("# Lean Finder – Retrieval Demo\n"
115
+ "Choose **Normal** for standard retrieval or **Arena** to compare and vote.")
116
  with gr.Row():
117
+ query_box = gr.Textbox(label="Query", lines=2, placeholder="Type your Lean query here …")
118
+ topk_slider = gr.Slider(label="Top-k", minimum=1, maximum=20, step=1, value=3)
119
+ mode_sel = gr.Radio(["Normal", "Arena"], value="Normal", label="Mode")
120
+ run_btn = gr.Button("Retrieve")
121
+ results_html= gr.HTML()
122
+
123
+ # voting widgets
124
+ with gr.Row():
125
+ vote_radio = gr.Radio(
126
+ ["Lean Finder better", "LeanSearch better", "Tie", "Both are bad"],
127
+ label="Which result is better?", visible=False
128
+ )
129
+ submit_btn = gr.Button("Submit vote", visible=False)
130
+ vote_status = gr.Textbox(label="", interactive=False, max_lines=1, visible=False)
131
+
132
+ # internal state: stored per browser session
133
+ st_query = gr.State("")
134
+ st_lf_js = gr.State({})
135
+ st_ls_js = gr.State([])
136
+
137
+ # callbacks
138
+ def retrieve(query: str, k: int, mode: str):
139
+ query = query.strip()
140
+ if not query:
141
+ return gr.update(value="<p>Please enter a query.</p>"), query, {}, []
142
+
143
+ payload = {"inputs": query, "top_k": k}
144
+ lf_json = _call_endpoint(payload).get("results", [])
145
+ lf_html = _render_leanfinder(lf_json)
146
+
147
+ if mode == "Normal":
148
+ return lf_html, query, lf_json, [] # ls_json empty
149
+
150
+ # Arena: get LeanSearch as well
151
+ ls_json = _call_leansearch(query, k)
152
+ ls_html = _render_leansearch(ls_json)
153
+ html = (
154
+ "<div style='display:flex; gap:0.5rem;'>"
155
+ f"<div style='flex:1 1 0;'>{lf_html}</div>"
156
+ f"<div style='flex:1 1 0;'>{ls_html}</div>"
157
+ "</div>"
158
+ )
159
+ return html, query, lf_json, ls_json
160
+
161
+ run_btn.click(
162
+ retrieve,
163
+ inputs=[query_box, topk_slider, mode_sel],
164
+ outputs=[results_html, st_query, st_lf_js, st_ls_js],
165
+ )
166
+
167
+ # show/hide voting widgets when mode changes
168
+ def _toggle_widgets(mode):
169
+ vis = (mode == "Arena")
170
+ return [gr.update(visible=vis), gr.update(visible=vis), gr.update(visible=vis)]
171
+ mode_sel.change(_toggle_widgets, inputs=mode_sel,
172
+ outputs=[vote_radio, submit_btn, vote_status])
173
+
174
+ # submit vote
175
+ submit_btn.click(
176
+ _save_vote,
177
+ inputs=[vote_radio, st_query, st_lf_js, st_ls_js],
178
+ outputs=vote_status
179
+ )
180
 
181
  if __name__ == "__main__":
182
  demo.launch()