Commit d6d7173 · Shiyu Zhao committed · 1 parent: ca736d2
Update space
app.py CHANGED
@@ -15,6 +15,9 @@ from huggingface_hub import HfApi
 import shutil
 import tempfile
 import time
+from concurrent.futures import ThreadPoolExecutor
+from queue import Queue
+import threading
 
 from stark_qa import load_qa
 from stark_qa.evaluator import Evaluator
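Note: the three imports added here support the rewrite of compute_metrics further down — ThreadPoolExecutor replaces the previous ProcessPoolExecutor, while Queue and threading drive a dedicated progress-monitor thread.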
@@ -80,16 +83,16 @@ def compute_metrics(csv_path: str, dataset: str, split: str, num_workers: int =
     if split not in ['test', 'test-0.1', 'human_generated_eval']:
         raise ValueError(f"Invalid split '{split}', expected one of ['test', 'test-0.1', 'human_generated_eval'].")
 
-    print("Initializing evaluator...")
+    # print("Initializing evaluator...")
     evaluator = Evaluator(candidate_ids_dict[dataset])
     eval_metrics = ['hit@1', 'hit@5', 'recall@20', 'mrr']
-    print("Loading QA dataset...")
+    # print("Loading QA dataset...")
     qa_dataset = load_qa(dataset, human_generated_eval=split == 'human_generated_eval')
     split_idx = qa_dataset.get_idx_split()
     all_indices = split_idx[split].tolist()
     print(f"Dataset loaded, processing {len(all_indices)} instances")
 
-    results_list = []
+    # results_list = []
     # query_ids = []
 
     # # Prepare args for each worker
@@ -108,34 +111,56 @@ def compute_metrics(csv_path: str, dataset: str, split: str, num_workers: int =
     #     metric: np.mean(eval_csv[eval_csv['query_id'].isin(query_ids)][metric]) for metric in eval_metrics
     # }
     # return final_result
-
+
+    batch_size = 100
+    results_list = []
+    progress_queue = Queue()
+
+    def process_batch(batch_indices):
+        batch_results = []
+        with ThreadPoolExecutor(max_workers=num_workers) as executor:
+            futures = [
+                executor.submit(process_single_instance,
+                                (idx, eval_csv, qa_dataset, evaluator, eval_metrics))
+                for idx in batch_indices
+            ]
+            for future in futures:
+                result = future.result()
+                if result is not None:
+                    batch_results.append(result)
+                progress_queue.put(1)
+        return batch_results
+
+    # Process batches
     total_batches = (len(all_indices) + batch_size - 1) // batch_size
+    remaining_indices = len(all_indices)
 
-
-
-
-
+    def update_progress():
+        with tqdm(total=len(all_indices), desc="Processing instances") as pbar:
+            completed = 0
+            while completed < len(all_indices):
+                progress_queue.get()
+                completed += 1
+                pbar.update(1)
+
+    # Start progress monitoring thread
+    progress_thread = threading.Thread(target=update_progress)
+    progress_thread.start()
+
+    # Process batches
+    for i in range(0, len(all_indices), batch_size):
+        batch_indices = all_indices[i:min(i + batch_size, len(all_indices))]
+        batch_results = process_batch(batch_indices)
+        results_list.extend(batch_results)
+        remaining_indices -= len(batch_indices)
+        print(f"\rBatch {i//batch_size + 1}/{total_batches} completed. Remaining: {remaining_indices}")
+
+    progress_thread.join()
+
+    # Compute final metrics
+    if not results_list:
+        raise ValueError("No valid results were produced")
 
-        print(f"\nProcessing batch {batch_num + 1}/{total_batches}")
-        print(f"Batch size: {len(batch_indices)}")
-
-        args = [(idx, eval_csv, qa_dataset, evaluator, eval_metrics)
-                for idx in batch_indices]
-
-        with ProcessPoolExecutor(max_workers=num_workers) as executor:
-            futures = [executor.submit(process_single_instance, arg)
-                       for arg in args]
-            for future in tqdm(as_completed(futures),
-                               total=len(futures),
-                               desc=f"Batch {batch_num + 1}"):
-                try:
-                    result = future.result()
-                    results_list.append(result)
-                except Exception as e:
-                    print(f"Error processing result: {str(e)}")
-                    raise
-
-    print("\nComputing final metrics...")
     results_df = pd.DataFrame(results_list)
     final_results = {
         metric: results_df[metric].mean()
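This hunk is the heart of the commit: per-batch thread pools replace the process pool, and each finished instance pushes a token onto a Queue that feeds a monitor thread owning the single tqdm bar. Below is a minimal, self-contained sketch of that pattern with a trivial stand-in workload; process_item, run_batched, and the numbers are illustrative, not names from the Space.

from concurrent.futures import ThreadPoolExecutor
from queue import Queue
import threading

from tqdm import tqdm


def process_item(idx):
    # Placeholder for the per-instance work (process_single_instance in the Space).
    return idx * idx


def run_batched(items, num_workers=4, batch_size=100):
    results = []
    progress_queue = Queue()

    def monitor():
        # Drain exactly one token per item and advance the shared bar.
        with tqdm(total=len(items), desc="Processing instances") as pbar:
            for _ in range(len(items)):
                progress_queue.get()
                pbar.update(1)

    monitor_thread = threading.Thread(target=monitor)
    monitor_thread.start()

    for start in range(0, len(items), batch_size):
        batch = items[start:start + batch_size]
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = [executor.submit(process_item, idx) for idx in batch]
            for future in futures:
                result = future.result()
                if result is not None:
                    results.append(result)
                progress_queue.put(1)  # one token per attempted item

    monitor_thread.join()
    return results


if __name__ == "__main__":
    print(len(run_batched(list(range(250)))))

One design note: a token is put whether or not the result is kept, so the bar counts attempted instances and always reaches the len(all_indices) total the monitor expects.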
@@ -151,13 +176,7 @@ def compute_metrics(csv_path: str, dataset: str, split: str, num_workers: int =
         error_msg = f"Error in compute_metrics ({elapsed_time:.2f}s): {str(error)}"
         print(error_msg)
         return error_msg
-
-    except pd.errors.EmptyDataError:
-        return "Error: The CSV file is empty or could not be read. Please check the file and try again."
-    except FileNotFoundError:
-        return f"Error: The file {csv_path} could not be found. Please check the file path and try again."
-    except Exception as error:
-        return f"{error}"
+
 
 
 # Data dictionaries for leaderboard
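This hunk consolidates error handling: the specific pd.errors.EmptyDataError and FileNotFoundError clauses are dropped, so failures now surface through the single handler above that formats and returns error_msg.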
@@ -666,6 +685,7 @@ def process_submission(
     code_repo, csv_file, model_description, hardware, paper_link, model_type
 ):
     """Process submission with progress updates"""
+    temp_files = []
     try:
         # 1. Initial validation
         yield "Validating submission details..."
@@ -680,6 +700,7 @@ def process_submission(
         else:
             try:
                 temp_fd, temp_csv_path = tempfile.mkstemp(suffix='.csv')
+                temp_files.append(temp_csv_path)
                 os.close(temp_fd)
                 shutil.copy2(csv_file.name, temp_csv_path)
             except Exception as e:
@@ -700,7 +721,7 @@ def process_submission(
             csv_path=temp_csv_path,
             dataset=dataset.lower(),
             split=split,
-            num_workers=
+            num_workers=4
         )
 
         if isinstance(results, str):
@@ -762,11 +783,15 @@ def process_submission(
         """
 
     except Exception as e:
-
+        total_time = time.time() - start_time
+        return f"Error ({total_time:.1f}s): {str(e)}"
     finally:
-
-
-
+        for temp_file in temp_files:
+            try:
+                if os.path.exists(temp_file):
+                    os.unlink(temp_file)
+            except Exception as e:
+                print(f"Warning: Failed to delete {temp_file}: {str(e)}")
 
 def filter_by_model_type(df, selected_types):
     """
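Taken together, the three process_submission hunks implement one pattern: register every temp path in temp_files the moment it exists, then delete everything in finally. A compact sketch of that pattern follows; evaluate_upload and src_path are hypothetical stand-ins, not code from the Space.

import os
import shutil
import tempfile


def evaluate_upload(src_path):
    # Track each temp path as soon as it is created, so the finally
    # block can clean up even when a later step raises.
    temp_files = []
    try:
        temp_fd, temp_csv = tempfile.mkstemp(suffix='.csv')
        temp_files.append(temp_csv)  # register before anything can fail
        os.close(temp_fd)
        shutil.copy2(src_path, temp_csv)
        return os.path.getsize(temp_csv)  # stand-in for compute_metrics(...)
    finally:
        for temp_file in temp_files:
            try:
                if os.path.exists(temp_file):
                    os.unlink(temp_file)
            except OSError as e:
                print(f"Warning: Failed to delete {temp_file}: {e}")

Appending the path before os.close and shutil.copy2 is the point of the ordering: if either call raises, the finally loop still knows about the file and removes it.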