Shiyu Zhao committed
Commit 743ad0c · 1 Parent(s): d6d7173

Update space

Files changed (1)
app.py +70 -107
app.py CHANGED
@@ -15,10 +15,10 @@ from huggingface_hub import HfApi
 import shutil
 import tempfile
 import time
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from queue import Queue
 import threading
-
+from threading import Lock
 from stark_qa import load_qa
 from stark_qa.evaluator import Evaluator
 
@@ -32,150 +32,113 @@ try:
 except Exception as e:
     raise RuntimeError(f"Failed to initialize HuggingFace Hub storage: {e}")
 
+# Global lock for thread-safe operations
+result_lock = Lock()
 
 def process_single_instance(args):
-    """Process a single instance with progress tracking"""
     idx, eval_csv, qa_dataset, evaluator, eval_metrics = args
+    query, query_id, answer_ids, meta_info = qa_dataset[idx]
     try:
-        query, query_id, answer_ids, meta_info = qa_dataset[idx]
-
-        # Print progress for debugging
-        print(f"Processing query_id: {query_id}")
-
+        # Using loc instead of direct boolean indexing for thread safety
+        with result_lock:
+            matching_rows = eval_csv.loc[eval_csv['query_id'] == query_id]
+            if matching_rows.empty:
+                raise IndexError(f'Error when processing query_id={query_id}, please make sure the predicted results exist for this query.')
+            pred_rank = matching_rows['pred_rank'].iloc[0]
+    except IndexError:
+        raise IndexError(f'Error when processing query_id={query_id}, please make sure the predicted results exist for this query.')
+    except Exception as e:
+        raise RuntimeError(f'Unexpected error occurred while fetching prediction rank for query_id={query_id}: {e}')
+
+    if isinstance(pred_rank, str):
         try:
-            pred_rank = eval_csv[eval_csv['query_id'] == query_id]['pred_rank'].item()
-        except Exception as e:
-            print(f"Error getting pred_rank for query_id {query_id}: {str(e)}")
-            raise
-
-        if isinstance(pred_rank, str):
             pred_rank = eval(pred_rank)
-
-        pred_dict = {pred_rank[i]: -i for i in range(min(100, len(pred_rank)))}
-        answer_ids = torch.LongTensor(answer_ids)
-        result = evaluator.evaluate(pred_dict, answer_ids, metrics=eval_metrics)
-
-        result["idx"], result["query_id"] = idx, query_id
-        return result
-    except Exception as e:
-        print(f"Error in process_single_instance for idx {idx}: {str(e)}")
-        raise
+        except SyntaxError as e:
+            raise ValueError(f'Failed to parse pred_rank as a list for query_id={query_id}: {e}')
+
+    if not isinstance(pred_rank, list):
+        raise TypeError(f'Error when processing query_id={query_id}, expected pred_rank to be a list but got {type(pred_rank)}.')
+
+    pred_dict = {pred_rank[i]: -i for i in range(min(100, len(pred_rank)))}
+    answer_ids = torch.LongTensor(answer_ids)
+
+    # Evaluate metrics
+    result = evaluator.evaluate(pred_dict, answer_ids, metrics=eval_metrics)
+    result["idx"], result["query_id"] = idx, query_id
+    return result
 
-
-def compute_metrics(csv_path: str, dataset: str, split: str, num_workers: int = 4):
+def compute_metrics(csv_path: str, dataset: str, split: str, num_threads: int = 4):
     candidate_ids_dict = {
         'amazon': [i for i in range(957192)],
         'mag': [i for i in range(1172724, 1872968)],
         'prime': [i for i in range(129375)]
     }
-    start_time = time.time()
+
     try:
+        # Read and validate CSV
         eval_csv = pd.read_csv(csv_path)
         if 'query_id' not in eval_csv.columns:
             raise ValueError('No `query_id` column found in the submitted csv.')
         if 'pred_rank' not in eval_csv.columns:
             raise ValueError('No `pred_rank` column found in the submitted csv.')
-
+
+        # Filter required columns
         eval_csv = eval_csv[['query_id', 'pred_rank']]
 
+        # Validate input parameters
         if dataset not in candidate_ids_dict:
             raise ValueError(f"Invalid dataset '{dataset}', expected one of {list(candidate_ids_dict.keys())}.")
         if split not in ['test', 'test-0.1', 'human_generated_eval']:
             raise ValueError(f"Invalid split '{split}', expected one of ['test', 'test-0.1', 'human_generated_eval'].")
 
-        # print("Initializing evaluator...")
+        # Initialize evaluator and metrics
         evaluator = Evaluator(candidate_ids_dict[dataset])
         eval_metrics = ['hit@1', 'hit@5', 'recall@20', 'mrr']
-        # print("Loading QA dataset...")
+
+        # Load dataset and get split indices
         qa_dataset = load_qa(dataset, human_generated_eval=split == 'human_generated_eval')
         split_idx = qa_dataset.get_idx_split()
         all_indices = split_idx[split].tolist()
-        print(f"Dataset loaded, processing {len(all_indices)} instances")
-
-        # results_list = []
-        # query_ids = []
-
-        # # Prepare args for each worker
-        # args = [(idx, eval_csv, qa_dataset, evaluator, eval_metrics) for idx in all_indices]
-
-        # with ProcessPoolExecutor(max_workers=num_workers) as executor:
-        #     futures = [executor.submit(process_single_instance, arg) for arg in args]
-        #     for future in tqdm(as_completed(futures), total=len(futures)):
-        #         result = future.result()  # This will raise an error if the worker encountered one
-        #         results_list.append(result)
-        #         query_ids.append(result['query_id'])
-
-        # # Concatenate results and compute final metrics
-        # eval_csv = pd.concat([eval_csv, pd.DataFrame(results_list)], ignore_index=True)
-        # final_results = {
-        #     metric: np.mean(eval_csv[eval_csv['query_id'].isin(query_ids)][metric]) for metric in eval_metrics
-        # }
-        # return final_result
-
-        batch_size = 100
-        results_list = []
-        progress_queue = Queue()
-
-        def process_batch(batch_indices):
-            batch_results = []
-            with ThreadPoolExecutor(max_workers=num_workers) as executor:
-                futures = [
-                    executor.submit(process_single_instance,
-                                    (idx, eval_csv, qa_dataset, evaluator, eval_metrics))
-                    for idx in batch_indices
-                ]
-                for future in futures:
-                    result = future.result()
-                    if result is not None:
-                        batch_results.append(result)
-                    progress_queue.put(1)
-            return batch_results
-
-        # Process batches
-        total_batches = (len(all_indices) + batch_size - 1) // batch_size
-        remaining_indices = len(all_indices)
-
-        def update_progress():
-            with tqdm(total=len(all_indices), desc="Processing instances") as pbar:
-                completed = 0
-                while completed < len(all_indices):
-                    progress_queue.get()
-                    completed += 1
-                    pbar.update(1)
-
-        # Start progress monitoring thread
-        progress_thread = threading.Thread(target=update_progress)
-        progress_thread.start()
-
-        # Process batches
-        for i in range(0, len(all_indices), batch_size):
-            batch_indices = all_indices[i:min(i + batch_size, len(all_indices))]
-            batch_results = process_batch(batch_indices)
-            results_list.extend(batch_results)
-            remaining_indices -= len(batch_indices)
-            print(f"\rBatch {i//batch_size + 1}/{total_batches} completed. Remaining: {remaining_indices}")
-
-        progress_thread.join()
-
-        # Compute final metrics
-        if not results_list:
-            raise ValueError("No valid results were produced")
-
-        results_df = pd.DataFrame(results_list)
-        final_results = {
-            metric: results_df[metric].mean()
-            for metric in eval_metrics
-        }
-
-        elapsed_time = time.time() - start_time
-        print(f"\nMetrics computation completed in {elapsed_time:.2f} seconds")
+        # Thread-safe containers
+        results_list = []
+        query_ids = []
+        results_lock = Lock()
+
+        # Prepare args for each thread
+        args = [(idx, eval_csv, qa_dataset, evaluator, eval_metrics) for idx in all_indices]
+
+        # Process using threads
+        with ThreadPoolExecutor(max_workers=num_threads) as executor:
+            futures = [executor.submit(process_single_instance, arg) for arg in args]
+
+            for future in tqdm(as_completed(futures), total=len(futures)):
+                try:
+                    result = future.result()
+                    with results_lock:
+                        results_list.append(result)
+                        query_ids.append(result['query_id'])
+                except Exception as e:
+                    print(f"Error processing instance: {str(e)}")
+
+        # Concatenate results and compute final metrics
+        with result_lock:
+            results_df = pd.DataFrame(results_list)
+            eval_csv = pd.concat([eval_csv, results_df], ignore_index=True)
+
+            final_results = {
+                metric: np.mean(eval_csv[eval_csv['query_id'].isin(query_ids)][metric])
+                for metric in eval_metrics
+            }
+
         return final_results
 
+    except pd.errors.EmptyDataError:
+        return "Error: The CSV file is empty or could not be read. Please check the file and try again."
+    except FileNotFoundError:
+        return f"Error: The file {csv_path} could not be found. Please check the file path and try again."
     except Exception as error:
-        elapsed_time = time.time() - start_time
-        error_msg = f"Error in compute_metrics ({elapsed_time:.2f}s): {str(error)}"
-        print(error_msg)
-        return error_msg
+        return f"{error}"
 
 
 
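For reference, a minimal sketch of how the compute_metrics entry point updated in this commit might be exercised outside the Space; the submission path example_submission.csv, the chosen dataset/split values, and the __main__ guard are hypothetical and not part of the commit, and the call assumes app.py's own imports (pandas, numpy, torch, tqdm, stark_qa) are installed.

# Hypothetical local check of the evaluation path touched by this commit.
# "example_submission.csv" is an assumed file name; it must provide the
# `query_id` and `pred_rank` columns that compute_metrics validates.
if __name__ == "__main__":
    results = compute_metrics(
        csv_path="example_submission.csv",  # assumed CSV with query_id / pred_rank
        dataset="amazon",                   # one of: amazon, mag, prime
        split="human_generated_eval",       # one of: test, test-0.1, human_generated_eval
        num_threads=4,                      # renamed from num_workers in this commit
    )
    # On success this is a dict mapping each metric (hit@1, hit@5, recall@20, mrr)
    # to its mean over the split; on failure it is an error string.
    print(results)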