Spaces:
Running
Running
import os, json, datetime, threading, requests | |
from typing import List, Dict, Any | |
import gradio as gr | |
import gspread | |
from google.oauth2.service_account import Credentials | |
# Deployment settings, read once from the environment at import time.
# ENDPOINT_ID is the URL that _call_endpoint POSTs to; HF_TOKEN, when set,
# is sent with those requests as a Bearer token.
HF_TOKEN = os.getenv("HF_TOKEN")
ENDPOINT_ID = os.getenv("ENDPOINT_ID")
if ENDPOINT_ID is None or ENDPOINT_ID == "":
    # Fail fast: every retrieval goes through this endpoint.
    raise ValueError("ENDPOINT_ID is not set")
def _call_endpoint(payload: Dict[str, Any]) -> Dict[str, Any]:
    """POST ``payload`` as JSON to ``ENDPOINT_ID`` and return the decoded reply.

    An ``Authorization: Bearer`` header is attached only when ``HF_TOKEN`` is
    set. HTTP error statuses are raised by ``raise_for_status``; the request
    times out after 60 seconds.
    """
    auth_header = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
    response = requests.post(
        ENDPOINT_ID,
        json=payload,
        headers={
            "Accept": "application/json",
            "Content-Type": "application/json",
            **auth_header,
        },
        timeout=60,
    )
    response.raise_for_status()
    return response.json()
def _call_leansearch(query: str, k: int) -> List[Dict[str, Any]]:
    """Send ``query`` to the LeanSearch API and return its list of hits.

    ``k`` is transmitted as a string in the ``num_results`` field. An empty
    list is returned when the reply is not a non-empty JSON list.
    """
    response = requests.post(
        "https://leansearch.net/search",
        json={"query": [query], "num_results": str(k)},
        timeout=60,
    )
    response.raise_for_status()
    batches = response.json()
    if not isinstance(batches, list) or not batches:
        return []
    # Presumably one result list per submitted query; only one query is sent,
    # so the first entry is the one of interest.
    return batches[0]
# Google Sheets setup
# OAuth scopes requested for the service account.
SCOPES = [
    "https://www.googleapis.com/auth/spreadsheets",
    "https://www.googleapis.com/auth/drive",
]
# The service-account key is supplied as a JSON string in the environment.
SERVICE_ACCOUNT_INFO = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
if SERVICE_ACCOUNT_INFO is None or SERVICE_ACCOUNT_INFO == "":
    raise ValueError("Missing GCP_SERVICE_ACCOUNT_JSON env var")
credentials = Credentials.from_service_account_info(
    json.loads(SERVICE_ACCOUNT_INFO),
    scopes=SCOPES,
)
gc = gspread.authorize(credentials)
# Votes are appended to the first sheet of the "arena_votes" spreadsheet.
worksheet = gc.open("arena_votes").sheet1
# Held around every append_row call in _save_vote.
SHEET_LOCK = threading.Lock()
def _save_vote(choice: str, | |
query: str, | |
lf_json: Dict[str, Any], | |
ls_json: List[Dict[str, Any]]) -> str: | |
"""Append one row: timestamp | query | choice | results_json""" | |
if not choice: | |
return "β οΈ Please pick a system before submitting." | |
payload = { | |
"lean_finder": lf_json, | |
"lean_search": ls_json, | |
} | |
row = [ | |
datetime.datetime.utcnow().isoformat(timespec="seconds"), | |
query, | |
choice, | |
json.dumps(payload, ensure_ascii=False), | |
] | |
with SHEET_LOCK: | |
worksheet.append_row(row, value_input_option="RAW") | |
return "β Vote recorded β thanks!" | |
# Rendering helpers: build the HTML tables shown in the results pane
def _render_leanfinder(res: List[Dict[str, Any]]) -> str: | |
if not res: | |
return "<p>No results from Lean Finder.</p>" | |
rows = "\n".join( | |
f"<tr><td>{i}</td><td>{r['score']:.4f}</td>" | |
f"<td><code style='white-space:pre-wrap'>{r['formal_statement']}</code></td>" | |
f"<td style='white-space:pre-wrap'>{r['informal_statement']}</td></tr>" | |
for i, r in enumerate(res, 1) | |
) | |
return ( | |
"<h3>Lean Finder</h3>" | |
"<table><thead><tr><th>Rank</th><th>Score</th>" | |
"<th>Formal statement</th><th>Informal statement</th></tr></thead>" | |
f"<tbody>{rows}</tbody></table>" | |
) | |
def _render_leansearch(res: List[Dict[str, Any]]) -> str: | |
if not res: | |
return "<p>No results from LeanSearch.</p>" | |
def row(i: int, e: Dict[str, Any]) -> str: | |
r = e.get("result", {}) | |
kind, name = r.get("kind","").strip(), ".".join(r.get("name", [])) | |
sig, val = r.get("signature", ""), r.get("value","").lstrip() | |
dist = e.get("distance", 0.0) | |
formal = f"{kind} {name}{sig} {val}".strip() | |
return (f"<tr><td>{i}</td><td>{dist:.4f}</td>" | |
f"<td><code style='white-space:pre-wrap'>{formal}</code></td>" | |
f"<td style='white-space:pre-wrap'>{r.get('informal_description','')}</td></tr>") | |
rows = "\n".join(row(i,e) for i,e in enumerate(res,1)) | |
return ( | |
"<h3>LeanSearch</h3>" | |
"<table><thead><tr><th>Rank</th><th>Distance</th>" | |
"<th>Formal statement</th><th>Informal statement</th></tr></thead>" | |
f"<tbody>{rows}</tbody></table>" | |
) | |
# Gradio app
# Stylesheet passed to gr.Blocks(css=...): removes the container width cap and
# styles the result tables and <code> cells emitted by the render helpers.
CUSTOM_CSS = """
html,body{margin:0;padding:0;width:100%;}
.gradio-container,.gradio-container .block{max-width:none!important;width:100%!important;padding:0 0.5rem;}
table{width:100%;border-collapse:collapse;font-size:0.9rem;}
th,td{border:1px solid #ddd;padding:6px;vertical-align:top;}
th{background:#f5f5f5;font-weight:600;}
code{background:#f8f8f8;border:1px solid #eee;border-radius:4px;padding:2px 4px;}
"""
# UI definition: query inputs, results pane, and (Arena mode only) voting.
# NOTE(review): the nesting of components inside the two gr.Row() groups was
# reconstructed from a copy that had lost its indentation; confirm the layout.
with gr.Blocks(title="Lean Finder Retrieval", css=CUSTOM_CSS) as demo:
    gr.Markdown("# Lean Finder — Retrieval Demo\n"
                "Choose **Normal** for standard retrieval or **Arena** to compare and vote.")

    with gr.Row():
        query_box = gr.Textbox(label="Query", lines=2, placeholder="Type your Lean query here …")
        topk_slider = gr.Slider(label="Top-k", minimum=1, maximum=20, step=1, value=3)
        mode_sel = gr.Radio(["Normal", "Arena"], value="Normal", label="Mode")
    run_btn = gr.Button("Retrieve")
    results_html = gr.HTML()

    # voting widgets (hidden until Arena mode is selected)
    with gr.Row():
        vote_radio = gr.Radio(
            ["Lean Finder better", "LeanSearch better", "Tie", "Both are bad"],
            label="Which result is better?", visible=False
        )
        submit_btn = gr.Button("Submit vote", visible=False)
    vote_status = gr.Textbox(label="", interactive=False, max_lines=1, visible=False)

    # internal state: stored per browser session
    st_query = gr.State("")
    # Both result states hold lists, matching what retrieve() stores in them.
    st_lf_js = gr.State([])
    st_ls_js = gr.State([])

    # callbacks
    def retrieve(query: str, k: int, mode: str):
        """Run the search and return (html, query, lean_finder_results, leansearch_results).

        In "Normal" mode only the Lean Finder endpoint is queried and the
        LeanSearch list is empty; in any other mode both backends are queried
        and rendered side by side.
        """
        query = query.strip()
        if not query:
            return gr.update(value="<p>Please enter a query.</p>"), query, [], []

        # Guard against the slider delivering a float (e.g. 3.0), which
        # _call_leansearch would otherwise serialize as "3.0".
        k = int(k)
        payload = {"inputs": query, "top_k": k}
        lf_json = _call_endpoint(payload).get("results", [])
        lf_html = _render_leanfinder(lf_json)
        if mode == "Normal":
            return lf_html, query, lf_json, []  # ls_json empty

        # Arena: get LeanSearch as well
        ls_json = _call_leansearch(query, k)
        ls_html = _render_leansearch(ls_json)
        html = (
            "<div style='display:flex; gap:0.5rem;'>"
            f"<div style='flex:1 1 0;'>{lf_html}</div>"
            f"<div style='flex:1 1 0;'>{ls_html}</div>"
            "</div>"
        )
        return html, query, lf_json, ls_json

    run_btn.click(
        retrieve,
        inputs=[query_box, topk_slider, mode_sel],
        outputs=[results_html, st_query, st_lf_js, st_ls_js],
    )

    # show/hide voting widgets when mode changes
    def _toggle_widgets(mode):
        """Return visibility updates for the three voting widgets."""
        vis = (mode == "Arena")
        return [gr.update(visible=vis), gr.update(visible=vis), gr.update(visible=vis)]

    mode_sel.change(_toggle_widgets, inputs=mode_sel,
                    outputs=[vote_radio, submit_btn, vote_status])

    # submit vote
    submit_btn.click(
        _save_vote,
        inputs=[vote_radio, st_query, st_lf_js, st_ls_js],
        outputs=vote_status
    )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()