Shiyu Zhao committed
Commit 255a7a4 · 1 Parent(s): 9698e43

Update space

Files changed (1)
  1. app.py +73 -103
app.py CHANGED
@@ -7,7 +7,7 @@ from datetime import datetime
 import json
 import torch
 from tqdm import tqdm
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor, as_completed
 import smtplib
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
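
For context on the import change: as_completed yields futures in completion order rather than submission order, which is what lets a progress bar advance as work actually finishes. A minimal sketch, not part of the commit:

    from concurrent.futures import ThreadPoolExecutor, as_completed
    import time

    def slow_square(x):
        time.sleep(0.1 * (x % 3))  # uneven task durations
        return x * x

    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(slow_square, i) for i in range(8)]
        for future in as_completed(futures):  # completion order, not submission order
            print(future.result())
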
@@ -41,50 +41,55 @@ except Exception as e:
 
 
 def process_single_instance(args):
-    """Process a single instance with improved prediction handling"""
+    """Process a single instance with improved validation and error handling"""
     idx, eval_csv, qa_dataset, evaluator, eval_metrics, max_candidate_id = args
     try:
+        # Get query data
         query, query_id, answer_ids, meta_info = qa_dataset[idx]
 
-        # Get predictions with better error handling
+        # Get predictions
         matching_preds = eval_csv[eval_csv['query_id'] == query_id]['pred_rank']
         if len(matching_preds) == 0:
             print(f"Warning: No prediction found for query_id {query_id}")
             return None
         elif len(matching_preds) > 1:
             print(f"Warning: Multiple predictions found for query_id {query_id}, using first one")
-            pred_rank = matching_preds.iloc[0]
-        else:
-            pred_rank = matching_preds.iloc[0]
+
+        pred_rank = matching_preds.iloc[0]
 
         # Parse prediction
-        try:
-            if isinstance(pred_rank, str):
+        if isinstance(pred_rank, str):
+            try:
                 pred_rank = eval(pred_rank)
-            elif isinstance(pred_rank, list):
-                pass
-            else:
-                print(f"Warning: Unexpected pred_rank type for query_id {query_id}: {type(pred_rank)}")
+            except Exception as e:
+                print(f"Error parsing pred_rank for query_id {query_id}: {str(e)}")
                 return None
-        except Exception as e:
-            print(f"Error parsing pred_rank for query_id {query_id}: {str(e)}")
-            return None
-
-        # Validate and filter predictions
+
+        # Validate prediction format
         if not isinstance(pred_rank, list):
             print(f"Warning: pred_rank is not a list for query_id {query_id}")
             return None
 
-        # valid_ranks = [rank for rank in pred_rank if isinstance(rank, (int, np.integer)) and 0 <= rank < max_candidate_id]
-        # if len(valid_ranks) == 0:
-        #     print(f"Warning: No valid predictions for query_id {query_id}")
-        #     return None
+        # Validate and filter prediction values
+        valid_pred_rank = []
+        for rank in pred_rank[:100]:  # Only use top 100 predictions
+            if isinstance(rank, (int, np.integer)) and 0 <= rank < max_candidate_id:
+                valid_pred_rank.append(rank)
+            else:
+                print(f"Warning: Invalid prediction {rank} for query_id {query_id}")
+
+        if not valid_pred_rank:
+            print(f"Warning: No valid predictions for query_id {query_id}")
+            return None
 
-        # Use only valid predictions
-        pred_dict = {pred_rank[i]: -i for i in range(min(100, len(pred_rank)))}
+        # Create prediction dictionary with valid predictions only
+        pred_dict = {rank: -i for i, rank in enumerate(valid_pred_rank)}
+
+        # Convert answer_ids to tensor
         answer_ids = torch.LongTensor(answer_ids)
-        result = evaluator.evaluate(pred_dict, answer_ids, metrics=eval_metrics)
 
+        # Evaluate
+        result = evaluator.evaluate(pred_dict, answer_ids, metrics=eval_metrics)
         result["idx"], result["query_id"] = idx, query_id
         return result
 
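
The parse-and-filter path added above can be exercised in isolation. A minimal sketch of the same logic with a hypothetical helper name; ast.literal_eval stands in for the commit's eval call (safer for untrusted CSV input), and plain int stands in for the (int, np.integer) check:

    import ast

    def parse_pred_rank(raw, max_candidate_id, top_k=100):
        """Parse a stringified rank list; keep only in-range integer ids."""
        pred_rank = ast.literal_eval(raw) if isinstance(raw, str) else raw
        if not isinstance(pred_rank, list):
            return None
        valid = [r for r in pred_rank[:top_k]
                 if isinstance(r, int) and 0 <= r < max_candidate_id]
        # Earlier ranks get higher scores, mirroring pred_dict = {rank: -i}
        return {rank: -i for i, rank in enumerate(valid)} or None

    print(parse_pred_rank("[5, 2, 999]", max_candidate_id=10))
    # {5: 0, 2: -1} -- 999 is filtered out as out of range
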
 
@@ -93,116 +98,81 @@ def process_single_instance(args):
         return None
 
 def compute_metrics(csv_path: str, dataset: str, split: str, num_workers: int = 4):
-    """Compute metrics with improved prediction handling"""
+    """Compute metrics with improved thread safety and error handling"""
     start_time = time.time()
+
+    # Dataset configuration
     candidate_ids_dict = {
         'amazon': [i for i in range(957192)],
         'mag': [i for i in range(1172724, 1872968)],
         'prime': [i for i in range(129375)]
     }
+
+    candidate_size_dict = {
+        'amazon': 957192,
+        'mag': 700244,  # 1872968 - 1172724
+        'prime': 129375
+    }
+
     try:
-        print(f"\nStarting compute_metrics for {dataset} {split}")
-
-        # Load CSV and validate format
-        print("Loading and validating CSV file...")
-        eval_csv = pd.read_csv(csv_path)
-        if 'query_id' not in eval_csv.columns or 'pred_rank' not in eval_csv.columns:
-            raise ValueError("CSV must contain 'query_id' and 'pred_rank' columns")
-
-        # Check for duplicate query_ids
-        duplicate_queries = eval_csv['query_id'].duplicated()
-        if duplicate_queries.any():
-            dup_count = duplicate_queries.sum()
-            print(f"Warning: Found {dup_count} duplicate query_ids in CSV")
-
-        # Keep only necessary columns
-        eval_csv = eval_csv[['query_id', 'pred_rank']]
-        print(f"CSV loaded, shape: {eval_csv.shape}")
-
-        # Get dataset-specific candidate size
-        candidate_size_dict = {
-            'amazon': 957192,
-            'mag': 700244,  # 1872968 - 1172724
-            'prime': 129375
-        }
-
-        if dataset not in candidate_size_dict:
+        # Input validation
+        if dataset not in candidate_ids_dict:
             raise ValueError(f"Invalid dataset '{dataset}'")
+        if split not in ['test', 'test-0.1', 'human_generated_eval']:
+            raise ValueError(f"Invalid split '{split}'")
 
+        # Load and validate CSV
+        print(f"\nLoading data for {dataset} {split}")
+        eval_csv = pd.read_csv(csv_path)
+        required_columns = ['query_id', 'pred_rank']
+        if not all(col in eval_csv.columns for col in required_columns):
+            raise ValueError(f"CSV must contain columns: {required_columns}")
+
+        eval_csv = eval_csv[required_columns]
         max_candidate_id = candidate_size_dict[dataset]
-        print(f"Dataset {dataset} has {max_candidate_id} candidates")
 
+        # Initialize components
         evaluator = Evaluator(candidate_ids_dict[dataset])
         eval_metrics = ['hit@1', 'hit@5', 'recall@20', 'mrr']
-
        qa_dataset = load_qa(dataset, human_generated_eval=split == 'human_generated_eval')
         split_idx = qa_dataset.get_idx_split()
         all_indices = split_idx[split].tolist()
 
-        print(f"Processing {len(all_indices)} instances...")
+        print(f"Processing {len(all_indices)} instances with {num_workers} threads")
 
-        # Process in batches using ThreadPoolExecutor
-        batch_size = 100
+        # Process instances
         results_list = []
-        progress_queue = Queue()
-        valid_results_count = 0
+        valid_count = 0
         error_count = 0
 
-        def process_batch(batch_indices):
-            nonlocal valid_results_count, error_count
-            batch_results = []
-            with ThreadPoolExecutor(max_workers=num_workers) as executor:
-                futures = [
-                    executor.submit(process_single_instance,
-                        (idx, eval_csv, qa_dataset, evaluator, eval_metrics, max_candidate_id))
-                    for idx in batch_indices
-                ]
-                for future in futures:
+        with ThreadPoolExecutor(max_workers=num_workers) as executor:
+            futures = [
+                executor.submit(
+                    process_single_instance,
+                    (idx, eval_csv, qa_dataset, evaluator, eval_metrics, max_candidate_id)
+                )
+                for idx in all_indices
+            ]
+
+            with tqdm(total=len(futures), desc="Processing") as pbar:
+                for future in as_completed(futures):
                     try:
                         result = future.result()
                         if result is not None:
-                            batch_results.append(result)
-                            valid_results_count += 1
+                            results_list.append(result)
+                            valid_count += 1
                         else:
                             error_count += 1
                     except Exception as e:
-                        print(f"Error in batch processing: {str(e)}")
+                        print(f"Error in future: {str(e)}")
                         error_count += 1
-                    progress_queue.put(1)
-            return batch_results
-
-        # Process batches with progress tracking
-        total_batches = (len(all_indices) + batch_size - 1) // batch_size
-        remaining_indices = len(all_indices)
-
-        def update_progress():
-            with tqdm(total=len(all_indices), desc="Processing instances") as pbar:
-                completed = 0
-                while completed < len(all_indices):
-                    progress_queue.get()
-                    completed += 1
                     pbar.update(1)
-
-        # Start progress monitoring thread
-        progress_thread = threading.Thread(target=update_progress)
-        progress_thread.start()
-
-        # Process batches
-        for i in range(0, len(all_indices), batch_size):
-            batch_indices = all_indices[i:min(i + batch_size, len(all_indices))]
-            batch_results = process_batch(batch_indices)
-            results_list.extend(batch_results)
-            remaining_indices -= len(batch_indices)
-            print(f"\rBatch {i//batch_size + 1}/{total_batches} completed. "
-                  f"Valid: {valid_results_count}, Errors: {error_count}, Remaining: {remaining_indices}")
-
-        progress_thread.join()
-
+
         # Compute final metrics
         if not results_list:
             raise ValueError("No valid results were produced")
 
-        print(f"\nProcessing complete. Valid results: {valid_results_count}, Errors: {error_count}")
+        print(f"\nProcessing complete. Valid: {valid_count}, Errors: {error_count}")
 
         results_df = pd.DataFrame(results_list)
         final_results = {
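
The rewrite above replaces the hand-rolled batching, progress Queue, and monitor thread with a single executor and a tqdm bar driven by as_completed. The shape of that pattern, reduced to a self-contained sketch with a placeholder task in place of process_single_instance:

    from concurrent.futures import ThreadPoolExecutor, as_completed
    from tqdm import tqdm

    def process(idx):
        # placeholder task; None mimics a skipped/invalid instance
        return {"idx": idx} if idx % 5 else None

    results, error_count = [], 0
    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(process, i) for i in range(50)]
        with tqdm(total=len(futures), desc="Processing") as pbar:
            for future in as_completed(futures):
                try:
                    result = future.result()
                    if result is not None:
                        results.append(result)
                    else:
                        error_count += 1
                except Exception as e:
                    print(f"Error in future: {e}")
                    error_count += 1
                pbar.update(1)
    print(f"Valid: {len(results)}, Errors: {error_count}")
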
@@ -211,9 +181,9 @@ def compute_metrics(csv_path: str, dataset: str, split: str, num_workers: int =
         }
 
         elapsed_time = time.time() - start_time
-        print(f"\nMetrics computation completed in {elapsed_time:.2f} seconds")
+        print(f"Completed in {elapsed_time:.2f} seconds")
         return final_results
-
+
     except Exception as error:
         elapsed_time = time.time() - start_time
         error_msg = f"Error in compute_metrics ({elapsed_time:.2f}s): {str(error)}"
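
Assuming the rest of app.py calls compute_metrics as its signature suggests, an invocation would look like this sketch (the CSV path is hypothetical; the dataset and split values are the ones the function validates):

    # predictions.csv is a hypothetical file with 'query_id' and 'pred_rank' columns
    results = compute_metrics(
        csv_path="predictions.csv",
        dataset="amazon",  # one of: 'amazon', 'mag', 'prime'
        split="test",      # one of: 'test', 'test-0.1', 'human_generated_eval'
        num_workers=4,
    )
    print(results)  # expected to cover hit@1, hit@5, recall@20, mrr
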
 