Shiyu Zhao committed on
Commit
d6d7173
·
1 Parent(s): ca736d2

Update space

Browse files
Files changed (1) hide show
  1. app.py +65 -40
app.py CHANGED
@@ -15,6 +15,9 @@ from huggingface_hub import HfApi
15
  import shutil
16
  import tempfile
17
  import time
 
 
 
18
 
19
  from stark_qa import load_qa
20
  from stark_qa.evaluator import Evaluator
@@ -80,16 +83,16 @@ def compute_metrics(csv_path: str, dataset: str, split: str, num_workers: int =
80
  if split not in ['test', 'test-0.1', 'human_generated_eval']:
81
  raise ValueError(f"Invalid split '{split}', expected one of ['test', 'test-0.1', 'human_generated_eval'].")
82
 
83
- print("Initializing evaluator...")
84
  evaluator = Evaluator(candidate_ids_dict[dataset])
85
  eval_metrics = ['hit@1', 'hit@5', 'recall@20', 'mrr']
86
- print("Loading QA dataset...")
87
  qa_dataset = load_qa(dataset, human_generated_eval=split == 'human_generated_eval')
88
  split_idx = qa_dataset.get_idx_split()
89
  all_indices = split_idx[split].tolist()
90
  print(f"Dataset loaded, processing {len(all_indices)} instances")
91
 
92
- results_list = []
93
  # query_ids = []
94
 
95
  # # Prepare args for each worker
@@ -108,34 +111,56 @@ def compute_metrics(csv_path: str, dataset: str, split: str, num_workers: int =
108
  # metric: np.mean(eval_csv[eval_csv['query_id'].isin(query_ids)][metric]) for metric in eval_metrics
109
  # }
110
  # return final_result
111
- batch_size = 50 # Smaller batch size for more frequent updates
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  total_batches = (len(all_indices) + batch_size - 1) // batch_size
 
113
 
114
- for batch_num in range(total_batches):
115
- batch_start = batch_num * batch_size
116
- batch_end = min((batch_num + 1) * batch_size, len(all_indices))
117
- batch_indices = all_indices[batch_start:batch_end]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
- print(f"\nProcessing batch {batch_num + 1}/{total_batches}")
120
- print(f"Batch size: {len(batch_indices)}")
121
-
122
- args = [(idx, eval_csv, qa_dataset, evaluator, eval_metrics)
123
- for idx in batch_indices]
124
-
125
- with ProcessPoolExecutor(max_workers=num_workers) as executor:
126
- futures = [executor.submit(process_single_instance, arg)
127
- for arg in args]
128
- for future in tqdm(as_completed(futures),
129
- total=len(futures),
130
- desc=f"Batch {batch_num + 1}"):
131
- try:
132
- result = future.result()
133
- results_list.append(result)
134
- except Exception as e:
135
- print(f"Error processing result: {str(e)}")
136
- raise
137
-
138
- print("\nComputing final metrics...")
139
  results_df = pd.DataFrame(results_list)
140
  final_results = {
141
  metric: results_df[metric].mean()
@@ -151,13 +176,7 @@ def compute_metrics(csv_path: str, dataset: str, split: str, num_workers: int =
151
  error_msg = f"Error in compute_metrics ({elapsed_time:.2f}s): {str(error)}"
152
  print(error_msg)
153
  return error_msg
154
-
155
- except pd.errors.EmptyDataError:
156
- return "Error: The CSV file is empty or could not be read. Please check the file and try again."
157
- except FileNotFoundError:
158
- return f"Error: The file {csv_path} could not be found. Please check the file path and try again."
159
- except Exception as error:
160
- return f"{error}"
161
 
162
 
163
  # Data dictionaries for leaderboard
@@ -666,6 +685,7 @@ def process_submission(
666
  code_repo, csv_file, model_description, hardware, paper_link, model_type
667
  ):
668
  """Process submission with progress updates"""
 
669
  try:
670
  # 1. Initial validation
671
  yield "Validating submission details..."
@@ -680,6 +700,7 @@ def process_submission(
680
  else:
681
  try:
682
  temp_fd, temp_csv_path = tempfile.mkstemp(suffix='.csv')
 
683
  os.close(temp_fd)
684
  shutil.copy2(csv_file.name, temp_csv_path)
685
  except Exception as e:
@@ -700,7 +721,7 @@ def process_submission(
700
  csv_path=temp_csv_path,
701
  dataset=dataset.lower(),
702
  split=split,
703
- num_workers=2 # Reduced from 4 to 2
704
  )
705
 
706
  if isinstance(results, str):
@@ -762,11 +783,15 @@ def process_submission(
762
  """
763
 
764
  except Exception as e:
765
- return f"Error: {str(e)}"
 
766
  finally:
767
- # Cleanup
768
- if temp_csv_path and os.path.exists(temp_csv_path):
769
- os.unlink(temp_csv_path)
 
 
 
770
 
771
  def filter_by_model_type(df, selected_types):
772
  """
 
15
  import shutil
16
  import tempfile
17
  import time
18
+ from concurrent.futures import ThreadPoolExecutor
19
+ from queue import Queue
20
+ import threading
21
 
22
  from stark_qa import load_qa
23
  from stark_qa.evaluator import Evaluator
 
83
  if split not in ['test', 'test-0.1', 'human_generated_eval']:
84
  raise ValueError(f"Invalid split '{split}', expected one of ['test', 'test-0.1', 'human_generated_eval'].")
85
 
86
+ # print("Initializing evaluator...")
87
  evaluator = Evaluator(candidate_ids_dict[dataset])
88
  eval_metrics = ['hit@1', 'hit@5', 'recall@20', 'mrr']
89
+ # print("Loading QA dataset...")
90
  qa_dataset = load_qa(dataset, human_generated_eval=split == 'human_generated_eval')
91
  split_idx = qa_dataset.get_idx_split()
92
  all_indices = split_idx[split].tolist()
93
  print(f"Dataset loaded, processing {len(all_indices)} instances")
94
 
95
+ # results_list = []
96
  # query_ids = []
97
 
98
  # # Prepare args for each worker
 
111
  # metric: np.mean(eval_csv[eval_csv['query_id'].isin(query_ids)][metric]) for metric in eval_metrics
112
  # }
113
  # return final_result
114
+
115
+ batch_size = 100
116
+ results_list = []
117
+ progress_queue = Queue()
118
+
119
+ def process_batch(batch_indices):
120
+ batch_results = []
121
+ with ThreadPoolExecutor(max_workers=num_workers) as executor:
122
+ futures = [
123
+ executor.submit(process_single_instance,
124
+ (idx, eval_csv, qa_dataset, evaluator, eval_metrics))
125
+ for idx in batch_indices
126
+ ]
127
+ for future in futures:
128
+ result = future.result()
129
+ if result is not None:
130
+ batch_results.append(result)
131
+ progress_queue.put(1)
132
+ return batch_results
133
+
134
+ # Process batches
135
  total_batches = (len(all_indices) + batch_size - 1) // batch_size
136
+ remaining_indices = len(all_indices)
137
 
138
+ def update_progress():
139
+ with tqdm(total=len(all_indices), desc="Processing instances") as pbar:
140
+ completed = 0
141
+ while completed < len(all_indices):
142
+ progress_queue.get()
143
+ completed += 1
144
+ pbar.update(1)
145
+
146
+ # Start progress monitoring thread
147
+ progress_thread = threading.Thread(target=update_progress)
148
+ progress_thread.start()
149
+
150
+ # Process batches
151
+ for i in range(0, len(all_indices), batch_size):
152
+ batch_indices = all_indices[i:min(i + batch_size, len(all_indices))]
153
+ batch_results = process_batch(batch_indices)
154
+ results_list.extend(batch_results)
155
+ remaining_indices -= len(batch_indices)
156
+ print(f"\rBatch {i//batch_size + 1}/{total_batches} completed. Remaining: {remaining_indices}")
157
+
158
+ progress_thread.join()
159
+
160
+ # Compute final metrics
161
+ if not results_list:
162
+ raise ValueError("No valid results were produced")
163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  results_df = pd.DataFrame(results_list)
165
  final_results = {
166
  metric: results_df[metric].mean()
 
176
  error_msg = f"Error in compute_metrics ({elapsed_time:.2f}s): {str(error)}"
177
  print(error_msg)
178
  return error_msg
179
+
 
 
 
 
 
 
180
 
181
 
182
  # Data dictionaries for leaderboard
 
685
  code_repo, csv_file, model_description, hardware, paper_link, model_type
686
  ):
687
  """Process submission with progress updates"""
688
+ temp_files = []
689
  try:
690
  # 1. Initial validation
691
  yield "Validating submission details..."
 
700
  else:
701
  try:
702
  temp_fd, temp_csv_path = tempfile.mkstemp(suffix='.csv')
703
+ temp_files.append(temp_csv_path)
704
  os.close(temp_fd)
705
  shutil.copy2(csv_file.name, temp_csv_path)
706
  except Exception as e:
 
721
  csv_path=temp_csv_path,
722
  dataset=dataset.lower(),
723
  split=split,
724
+ num_workers=4
725
  )
726
 
727
  if isinstance(results, str):
 
783
  """
784
 
785
  except Exception as e:
786
+ total_time = time.time() - start_time
787
+ return f"Error ({total_time:.1f}s): {str(e)}"
788
  finally:
789
+ for temp_file in temp_files:
790
+ try:
791
+ if os.path.exists(temp_file):
792
+ os.unlink(temp_file)
793
+ except Exception as e:
794
+ print(f"Warning: Failed to delete {temp_file}: {str(e)}")
795
 
796
  def filter_by_model_type(df, selected_types):
797
  """