Shiyu Zhao committed
Commit df3974d · 1 Parent(s): d38a2a4

Update space

app.py CHANGED
@@ -36,113 +36,179 @@ except Exception as e:
 result_lock = Lock()
 
 def process_single_instance(args):
-    idx, eval_dict, qa_dataset, evaluator, eval_metrics = args
-    query, query_id, answer_ids, meta_info = qa_dataset[idx]
-
+    """Process a single instance with improved prediction handling"""
+    idx, eval_csv, qa_dataset, evaluator, eval_metrics, max_candidate_id = args
     try:
-        # Access prediction using dictionary instead of DataFrame
-        if query_id not in eval_dict:
-            raise IndexError(f'Error when processing query_id={query_id}, please make sure the predicted results exist for this query.')
+        query, query_id, answer_ids, meta_info = qa_dataset[idx]
 
-        pred_rank = eval_dict[query_id]
+        # Get predictions with better error handling
+        matching_preds = eval_csv[eval_csv['query_id'] == query_id]['pred_rank']
+        if len(matching_preds) == 0:
+            print(f"Warning: No prediction found for query_id {query_id}")
+            return None
+        elif len(matching_preds) > 1:
+            print(f"Warning: Multiple predictions found for query_id {query_id}, using first one")
+            pred_rank = matching_preds.iloc[0]
+        else:
+            pred_rank = matching_preds.iloc[0]
 
-        if isinstance(pred_rank, str):
-            try:
+        # Parse prediction
+        try:
+            if isinstance(pred_rank, str):
                 pred_rank = eval(pred_rank)
-            except SyntaxError as e:
-                raise ValueError(f'Failed to parse pred_rank as a list for query_id={query_id}: {e}')
-
+            elif isinstance(pred_rank, list):
+                pass
+            else:
+                print(f"Warning: Unexpected pred_rank type for query_id {query_id}: {type(pred_rank)}")
+                return None
+        except Exception as e:
+            print(f"Error parsing pred_rank for query_id {query_id}: {str(e)}")
+            return None
+
+        # Validate and filter predictions
         if not isinstance(pred_rank, list):
-            raise TypeError(f'Error when processing query_id={query_id}, expected pred_rank to be a list but got {type(pred_rank)}.')
-
-        pred_dict = {pred_rank[i]: -i for i in range(min(100, len(pred_rank)))}
+            print(f"Warning: pred_rank is not a list for query_id {query_id}")
+            return None
+
+        valid_ranks = [rank for rank in pred_rank if isinstance(rank, (int, np.integer)) and 0 <= rank < max_candidate_id]
+        if len(valid_ranks) == 0:
+            print(f"Warning: No valid predictions for query_id {query_id}")
+            return None
+
+        # Use only valid predictions
+        pred_dict = {valid_ranks[i]: -i for i in range(min(100, len(valid_ranks)))}
         answer_ids = torch.LongTensor(answer_ids)
-
-        # Evaluate metrics
         result = evaluator.evaluate(pred_dict, answer_ids, metrics=eval_metrics)
+
         result["idx"], result["query_id"] = idx, query_id
         return result
 
     except Exception as e:
-        raise RuntimeError(f'Error processing query_id={query_id}: {str(e)}')
-
-def compute_metrics(csv_path: str, dataset: str, split: str, num_threads: int = 4):
-    candidate_ids_dict = {
-        'amazon': [i for i in range(957192)],
-        'mag': [i for i in range(1172724, 1872968)],
-        'prime': [i for i in range(129375)]
-    }
+        print(f"Error processing idx {idx}: {str(e)}")
+        return None
 
+def compute_metrics(csv_path: str, dataset: str, split: str, num_workers: int = 4):
+    """Compute metrics with improved prediction handling"""
+    start_time = time.time()
     try:
-        # Read and validate CSV
+        print(f"\nStarting compute_metrics for {dataset} {split}")
+
+        # Load CSV and validate format
+        print("Loading and validating CSV file...")
         eval_csv = pd.read_csv(csv_path)
-        if 'query_id' not in eval_csv.columns:
-            raise ValueError('No `query_id` column found in the submitted csv.')
-        if 'pred_rank' not in eval_csv.columns:
-            raise ValueError('No `pred_rank` column found in the submitted csv.')
+        if 'query_id' not in eval_csv.columns or 'pred_rank' not in eval_csv.columns:
+            raise ValueError("CSV must contain 'query_id' and 'pred_rank' columns")
+
+        # Check for duplicate query_ids
+        duplicate_queries = eval_csv['query_id'].duplicated()
+        if duplicate_queries.any():
+            dup_count = duplicate_queries.sum()
+            print(f"Warning: Found {dup_count} duplicate query_ids in CSV")
+
+        # Keep only necessary columns
+        eval_csv = eval_csv[['query_id', 'pred_rank']]
+        print(f"CSV loaded, shape: {eval_csv.shape}")
+
+        # Get dataset-specific candidate size
+        candidate_size_dict = {
+            'amazon': 957192,
+            'mag': 700244, # 1872968 - 1172724
+            'prime': 129375
+        }
+
+        if dataset not in candidate_size_dict:
+            raise ValueError(f"Invalid dataset '{dataset}'")
+
+        max_candidate_id = candidate_size_dict[dataset]
+        print(f"Dataset {dataset} has {max_candidate_id} candidates")
 
-        # Convert DataFrame to dictionary for thread-safe access
-        eval_dict = dict(zip(eval_csv['query_id'], eval_csv['pred_rank']))
-
-        # Validate input parameters
-        if dataset not in candidate_ids_dict:
-            raise ValueError(f"Invalid dataset '{dataset}', expected one of {list(candidate_ids_dict.keys())}.")
-        if split not in ['test', 'test-0.1', 'human_generated_eval']:
-            raise ValueError(f"Invalid split '{split}', expected one of ['test', 'test-0.1', 'human_generated_eval'].")
-
-        # Initialize evaluator and metrics
         evaluator = Evaluator(candidate_ids_dict[dataset])
         eval_metrics = ['hit@1', 'hit@5', 'recall@20', 'mrr']
-
-        # Load dataset and get split indices
+
         qa_dataset = load_qa(dataset, human_generated_eval=split == 'human_generated_eval')
         split_idx = qa_dataset.get_idx_split()
         all_indices = split_idx[split].tolist()
-
-        # Thread-safe containers for results
+
+        print(f"Processing {len(all_indices)} instances...")
+
+        # Process in batches using ThreadPoolExecutor
+        batch_size = 100
         results_list = []
-        results_lock = Lock()
-
-        # Prepare args for each thread
-        args = [(idx, eval_dict, qa_dataset, evaluator, eval_metrics) for idx in all_indices]
-
-        failed_queries = [] # Track failed queries
-
-        # Process using threads
-        with ThreadPoolExecutor(max_workers=num_threads) as executor:
-            futures = [executor.submit(process_single_instance, arg) for arg in args]
-
-            for future in tqdm(as_completed(futures), total=len(futures)):
-                try:
-                    result = future.result()
-                    with results_lock:
-                        results_list.append(result)
-                except Exception as e:
-                    query_id = str(e).split('query_id=')[-1].split(':')[0]
-                    failed_queries.append(query_id)
-                    print(f"Error processing instance: {str(e)}")
-
-        if failed_queries:
-            print(f"\nFailed to process {len(failed_queries)} queries.")
-            print(f"First few failed query_ids: {failed_queries[:5]}")
-
-        if not results_list:
-            raise ValueError("No results were successfully processed")
+        progress_queue = Queue()
+        valid_results_count = 0
+        error_count = 0
+
+        def process_batch(batch_indices):
+            nonlocal valid_results_count, error_count
+            batch_results = []
+            with ThreadPoolExecutor(max_workers=num_workers) as executor:
+                futures = [
+                    executor.submit(process_single_instance,
+                                    (idx, eval_csv, qa_dataset, evaluator, eval_metrics, max_candidate_id))
+                    for idx in batch_indices
+                ]
+                for future in futures:
+                    try:
+                        result = future.result()
+                        if result is not None:
+                            batch_results.append(result)
+                            valid_results_count += 1
+                        else:
+                            error_count += 1
+                    except Exception as e:
+                        print(f"Error in batch processing: {str(e)}")
+                        error_count += 1
+                    progress_queue.put(1)
+            return batch_results
+
+        # Process batches with progress tracking
+        total_batches = (len(all_indices) + batch_size - 1) // batch_size
+        remaining_indices = len(all_indices)
+
+        def update_progress():
+            with tqdm(total=len(all_indices), desc="Processing instances") as pbar:
+                completed = 0
+                while completed < len(all_indices):
+                    progress_queue.get()
+                    completed += 1
+                    pbar.update(1)
+
+        # Start progress monitoring thread
+        progress_thread = threading.Thread(target=update_progress)
+        progress_thread.start()
+
+        # Process batches
+        for i in range(0, len(all_indices), batch_size):
+            batch_indices = all_indices[i:min(i + batch_size, len(all_indices))]
+            batch_results = process_batch(batch_indices)
+            results_list.extend(batch_results)
+            remaining_indices -= len(batch_indices)
+            print(f"\rBatch {i//batch_size + 1}/{total_batches} completed. "
+                  f"Valid: {valid_results_count}, Errors: {error_count}, Remaining: {remaining_indices}")
+
+        progress_thread.join()
 
         # Compute final metrics
+        if not results_list:
+            raise ValueError("No valid results were produced")
+
+        print(f"\nProcessing complete. Valid results: {valid_results_count}, Errors: {error_count}")
+
         results_df = pd.DataFrame(results_list)
         final_results = {
-            metric: np.mean(results_df[metric]) for metric in eval_metrics
+            metric: results_df[metric].mean()
+            for metric in eval_metrics
         }
-
+
+        elapsed_time = time.time() - start_time
+        print(f"\nMetrics computation completed in {elapsed_time:.2f} seconds")
         return final_results
 
-    except pd.errors.EmptyDataError:
-        return "Error: The CSV file is empty or could not be read. Please check the file and try again."
-    except FileNotFoundError:
-        return f"Error: The file {csv_path} could not be found. Please check the file path and try again."
     except Exception as error:
-        return f"{error}"
+        elapsed_time = time.time() - start_time
+        error_msg = f"Error in compute_metrics ({elapsed_time:.2f}s): {str(error)}"
+        print(error_msg)
+        return error_msg
 
 
 # Data dictionaries for leaderboard
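
Editor's note: the core of the updated process_single_instance is the conversion of each submission row into a {candidate_id: score} dictionary, where the candidate ranked at position i receives score -i, so the ranking can be recovered by sorting scores in descending order. Below is a minimal, self-contained sketch of that convention. The helper names are hypothetical, ast.literal_eval stands in for the eval() call used in the diff, and the toy metric function only illustrates the hit@k / recall@20 / MRR definitions named in eval_metrics; it is not the Space's Evaluator.

import ast

def pred_rank_to_score_dict(pred_rank_cell, max_candidate_id, top_k=100):
    # Parse a 'pred_rank' CSV cell (a stringified Python list) and keep only
    # candidate ids inside the dataset's valid range, as the diff above does.
    ranks = ast.literal_eval(pred_rank_cell) if isinstance(pred_rank_cell, str) else pred_rank_cell
    valid = [r for r in ranks if isinstance(r, int) and 0 <= r < max_candidate_id]
    # Scoring convention from the diff: the candidate ranked i-th gets score -i,
    # so sorting scores in descending order recovers the submitted ranking.
    return {valid[i]: -i for i in range(min(top_k, len(valid)))}

def toy_metrics(ranked_ids, answer_ids):
    # Toy per-query hit@k / recall@20 / MRR, only to illustrate the metric names
    # listed in eval_metrics; the Space relies on its own Evaluator instead.
    answers = set(answer_ids)
    first_hit = next((i for i, c in enumerate(ranked_ids) if c in answers), None)
    return {
        "hit@1": float(first_hit == 0),
        "hit@5": float(first_hit is not None and first_hit < 5),
        "recall@20": len(answers & set(ranked_ids[:20])) / max(len(answers), 1),
        "mrr": 0.0 if first_hit is None else 1.0 / (first_hit + 1),
    }

score_dict = pred_rank_to_score_dict("[3, 7, 42, 5]", max_candidate_id=957192)
ranked = sorted(score_dict, key=score_dict.get, reverse=True)  # best candidate first
print(toy_metrics(ranked, answer_ids=[7, 100]))
# -> {'hit@1': 0.0, 'hit@5': 1.0, 'recall@20': 0.5, 'mrr': 0.5}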
submissions/ance_test_abc/predictions_20241115_001153.csv DELETED
The diff for this file is too large to render. See raw diff
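
Editor's note: for reference, a minimal sketch of how the updated compute_metrics above might be invoked on a submission file. The CSV path is a placeholder; the dataset names ('amazon', 'mag', 'prime'), the splits ('test', 'test-0.1', 'human_generated_eval'), and the returned metrics are the ones appearing in the diff.

# Hypothetical call; 'my_predictions.csv' is a placeholder submission file with
# 'query_id' and 'pred_rank' columns, as required by the validation in the diff.
results = compute_metrics(
    csv_path="my_predictions.csv",
    dataset="amazon",                 # or 'mag', 'prime'
    split="human_generated_eval",     # or 'test', 'test-0.1'
    num_workers=4,
)
print(results)  # e.g. {'hit@1': ..., 'hit@5': ..., 'recall@20': ..., 'mrr': ...}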