Shiyu Zhao committed · commit df3974d · 1 parent: d38a2a4

Update space

Files changed:
- app.py (+144 -78)
- submissions/ance_test_abc/predictions_20241115_001153.csv (deleted)

app.py CHANGED
@@ -36,113 +36,179 @@ except Exception as e:
 36   result_lock = Lock()
 37
 38   def process_single_instance(args):
 39 -
 40 -
 41 -
 42       try:
 43 -
 44 -         if query_id not in eval_dict:
 45 -             raise IndexError(f'Error when processing query_id={query_id}, please make sure the predicted results exist for this query.')
 46
 47 -
 48
 49 -
 50 -
 51           pred_rank = eval(pred_rank)
 52 -
 53 -
 54 -
 55           if not isinstance(pred_rank, list):
 56 -
 57 -
 58 -
 59           answer_ids = torch.LongTensor(answer_ids)
 60 -
 61 -         # Evaluate metrics
 62           result = evaluator.evaluate(pred_dict, answer_ids, metrics=eval_metrics)
 63           result["idx"], result["query_id"] = idx, query_id
 64           return result
 65
 66       except Exception as e:
 67 -
 68 -
 69 - def compute_metrics(csv_path: str, dataset: str, split: str, num_threads: int = 4):
 70 -     candidate_ids_dict = {
 71 -         'amazon': [i for i in range(957192)],
 72 -         'mag': [i for i in range(1172724, 1872968)],
 73 -         'prime': [i for i in range(129375)]
 74 -     }
 75
 76       try:
 77 -
 78           eval_csv = pd.read_csv(csv_path)
 79 -         if 'query_id' not in eval_csv.columns:
 80 -             raise ValueError(
 81 -
 82 -
 83
 84 -         # Convert DataFrame to dictionary for thread-safe access
 85 -         eval_dict = dict(zip(eval_csv['query_id'], eval_csv['pred_rank']))
 86 -
 87 -         # Validate input parameters
 88 -         if dataset not in candidate_ids_dict:
 89 -             raise ValueError(f"Invalid dataset '{dataset}', expected one of {list(candidate_ids_dict.keys())}.")
 90 -         if split not in ['test', 'test-0.1', 'human_generated_eval']:
 91 -             raise ValueError(f"Invalid split '{split}', expected one of ['test', 'test-0.1', 'human_generated_eval'].")
 92 -
 93 -         # Initialize evaluator and metrics
 94           evaluator = Evaluator(candidate_ids_dict[dataset])
 95           eval_metrics = ['hit@1', 'hit@5', 'recall@20', 'mrr']
 96 -
 97 -         # Load dataset and get split indices
 98           qa_dataset = load_qa(dataset, human_generated_eval=split == 'human_generated_eval')
 99           split_idx = qa_dataset.get_idx_split()
100           all_indices = split_idx[split].tolist()
101 -
102 -
103           results_list = []
104 -
  ⋮       (removed lines 105-129 are not rendered in this view)
130 -
131
132           # Compute final metrics
133           results_df = pd.DataFrame(results_list)
134           final_results = {
135 -             metric:
136           }
137 -
138           return final_results
139
140 -     except pd.errors.EmptyDataError:
141 -         return "Error: The CSV file is empty or could not be read. Please check the file and try again."
142 -     except FileNotFoundError:
143 -         return f"Error: The file {csv_path} could not be found. Please check the file path and try again."
144       except Exception as error:
145 -
146
147
148   # Data dictionaries for leaderboard

 36   result_lock = Lock()
 37
 38   def process_single_instance(args):
 39 +     """Process a single instance with improved prediction handling"""
 40 +     idx, eval_csv, qa_dataset, evaluator, eval_metrics, max_candidate_id = args
 41       try:
 42 +         query, query_id, answer_ids, meta_info = qa_dataset[idx]
 43
 44 +         # Get predictions with better error handling
 45 +         matching_preds = eval_csv[eval_csv['query_id'] == query_id]['pred_rank']
 46 +         if len(matching_preds) == 0:
 47 +             print(f"Warning: No prediction found for query_id {query_id}")
 48 +             return None
 49 +         elif len(matching_preds) > 1:
 50 +             print(f"Warning: Multiple predictions found for query_id {query_id}, using first one")
 51 +             pred_rank = matching_preds.iloc[0]
 52 +         else:
 53 +             pred_rank = matching_preds.iloc[0]
 54
 55 +         # Parse prediction
 56 +         try:
 57 +             if isinstance(pred_rank, str):
 58                   pred_rank = eval(pred_rank)
 59 +             elif isinstance(pred_rank, list):
 60 +                 pass
 61 +             else:
 62 +                 print(f"Warning: Unexpected pred_rank type for query_id {query_id}: {type(pred_rank)}")
 63 +                 return None
 64 +         except Exception as e:
 65 +             print(f"Error parsing pred_rank for query_id {query_id}: {str(e)}")
 66 +             return None
 67 +
 68 +         # Validate and filter predictions
 69           if not isinstance(pred_rank, list):
 70 +             print(f"Warning: pred_rank is not a list for query_id {query_id}")
 71 +             return None
 72 +
 73 +         valid_ranks = [rank for rank in pred_rank if isinstance(rank, (int, np.integer)) and 0 <= rank < max_candidate_id]
 74 +         if len(valid_ranks) == 0:
 75 +             print(f"Warning: No valid predictions for query_id {query_id}")
 76 +             return None
 77 +
 78 +         # Use only valid predictions
 79 +         pred_dict = {valid_ranks[i]: -i for i in range(min(100, len(valid_ranks)))}
 80           answer_ids = torch.LongTensor(answer_ids)
 81           result = evaluator.evaluate(pred_dict, answer_ids, metrics=eval_metrics)
 82 +
 83           result["idx"], result["query_id"] = idx, query_id
 84           return result
 85
 86       except Exception as e:
 87 +         print(f"Error processing idx {idx}: {str(e)}")
 88 +         return None
 89
 90 + def compute_metrics(csv_path: str, dataset: str, split: str, num_workers: int = 4):
 91 +     """Compute metrics with improved prediction handling"""
 92 +     start_time = time.time()
 93       try:
 94 +         print(f"\nStarting compute_metrics for {dataset} {split}")
 95 +
 96 +         # Load CSV and validate format
 97 +         print("Loading and validating CSV file...")
 98           eval_csv = pd.read_csv(csv_path)
 99 +         if 'query_id' not in eval_csv.columns or 'pred_rank' not in eval_csv.columns:
100 +             raise ValueError("CSV must contain 'query_id' and 'pred_rank' columns")
101 +
102 +         # Check for duplicate query_ids
103 +         duplicate_queries = eval_csv['query_id'].duplicated()
104 +         if duplicate_queries.any():
105 +             dup_count = duplicate_queries.sum()
106 +             print(f"Warning: Found {dup_count} duplicate query_ids in CSV")
107 +
108 +         # Keep only necessary columns
109 +         eval_csv = eval_csv[['query_id', 'pred_rank']]
110 +         print(f"CSV loaded, shape: {eval_csv.shape}")
111 +
112 +         # Get dataset-specific candidate size
113 +         candidate_size_dict = {
114 +             'amazon': 957192,
115 +             'mag': 700244,  # 1872968 - 1172724
116 +             'prime': 129375
117 +         }
118 +
119 +         if dataset not in candidate_size_dict:
120 +             raise ValueError(f"Invalid dataset '{dataset}'")
121 +
122 +         max_candidate_id = candidate_size_dict[dataset]
123 +         print(f"Dataset {dataset} has {max_candidate_id} candidates")
124
125           evaluator = Evaluator(candidate_ids_dict[dataset])
126           eval_metrics = ['hit@1', 'hit@5', 'recall@20', 'mrr']
127 +
128           qa_dataset = load_qa(dataset, human_generated_eval=split == 'human_generated_eval')
129           split_idx = qa_dataset.get_idx_split()
130           all_indices = split_idx[split].tolist()
131 +
132 +         print(f"Processing {len(all_indices)} instances...")
133 +
134 +         # Process in batches using ThreadPoolExecutor
135 +         batch_size = 100
136           results_list = []
137 +         progress_queue = Queue()
138 +         valid_results_count = 0
139 +         error_count = 0
140 +
141 +         def process_batch(batch_indices):
142 +             nonlocal valid_results_count, error_count
143 +             batch_results = []
144 +             with ThreadPoolExecutor(max_workers=num_workers) as executor:
145 +                 futures = [
146 +                     executor.submit(process_single_instance,
147 +                                     (idx, eval_csv, qa_dataset, evaluator, eval_metrics, max_candidate_id))
148 +                     for idx in batch_indices
149 +                 ]
150 +                 for future in futures:
151 +                     try:
152 +                         result = future.result()
153 +                         if result is not None:
154 +                             batch_results.append(result)
155 +                             valid_results_count += 1
156 +                         else:
157 +                             error_count += 1
158 +                     except Exception as e:
159 +                         print(f"Error in batch processing: {str(e)}")
160 +                         error_count += 1
161 +                     progress_queue.put(1)
162 +             return batch_results
163 +
164 +         # Process batches with progress tracking
165 +         total_batches = (len(all_indices) + batch_size - 1) // batch_size
166 +         remaining_indices = len(all_indices)
167 +
168 +         def update_progress():
169 +             with tqdm(total=len(all_indices), desc="Processing instances") as pbar:
170 +                 completed = 0
171 +                 while completed < len(all_indices):
172 +                     progress_queue.get()
173 +                     completed += 1
174 +                     pbar.update(1)
175 +
176 +         # Start progress monitoring thread
177 +         progress_thread = threading.Thread(target=update_progress)
178 +         progress_thread.start()
179 +
180 +         # Process batches
181 +         for i in range(0, len(all_indices), batch_size):
182 +             batch_indices = all_indices[i:min(i + batch_size, len(all_indices))]
183 +             batch_results = process_batch(batch_indices)
184 +             results_list.extend(batch_results)
185 +             remaining_indices -= len(batch_indices)
186 +             print(f"\rBatch {i//batch_size + 1}/{total_batches} completed. "
187 +                   f"Valid: {valid_results_count}, Errors: {error_count}, Remaining: {remaining_indices}")
188 +
189 +         progress_thread.join()
190
191           # Compute final metrics
192 +         if not results_list:
193 +             raise ValueError("No valid results were produced")
194 +
195 +         print(f"\nProcessing complete. Valid results: {valid_results_count}, Errors: {error_count}")
196 +
197           results_df = pd.DataFrame(results_list)
198           final_results = {
199 +             metric: results_df[metric].mean()
200 +             for metric in eval_metrics
201           }
202 +
203 +         elapsed_time = time.time() - start_time
204 +         print(f"\nMetrics computation completed in {elapsed_time:.2f} seconds")
205           return final_results
206
207       except Exception as error:
208 +         elapsed_time = time.time() - start_time
209 +         error_msg = f"Error in compute_metrics ({elapsed_time:.2f}s): {str(error)}"
210 +         print(error_msg)
211 +         return error_msg
212
213
214   # Data dictionaries for leaderboard
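For reference, a minimal sketch (not part of the commit) of the submission format the updated compute_metrics expects: a CSV with query_id and pred_rank columns, where pred_rank holds a Python-style list of candidate IDs ordered best first. Per the diff above, a string pred_rank is parsed with eval, non-integer or out-of-range IDs are filtered out, and only the first 100 valid IDs are scored. The file name, query IDs, and candidate IDs below are made up for illustration, and the commented-out call assumes app.py and its dependencies (Evaluator, load_qa, the QA data) can be imported locally.

import pandas as pd

rows = [
    # query_id from the evaluation split; pred_rank = candidate IDs ranked best-first, stored as a string
    {"query_id": 0, "pred_rank": str([101, 57, 923, 4018, 12])},
    {"query_id": 1, "pred_rank": str([7, 7700, 31, 2, 956000])},
]
pd.DataFrame(rows).to_csv("my_predictions.csv", index=False)

# If app.py is importable in a local environment with its dependencies installed:
# from app import compute_metrics
# print(compute_metrics("my_predictions.csv", dataset="amazon", split="test"))
# Expected result: a dict with the mean hit@1, hit@5, recall@20 and mrr over the split.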
submissions/ance_test_abc/predictions_20241115_001153.csv DELETED
The diff for this file is too large to render. See raw diff.
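The deleted file follows the submissions/<submission_name>/predictions_<timestamp>.csv layout visible in its path. Below is a small sketch of writing a predictions file into that layout; the helper name and the timestamp format are assumptions inferred from the deleted path, not taken from the app code.

import os
from datetime import datetime
import pandas as pd

def save_submission(df: pd.DataFrame, submission_name: str, root: str = "submissions") -> str:
    # df is expected to carry the 'query_id' and 'pred_rank' columns checked by compute_metrics
    out_dir = os.path.join(root, submission_name)
    os.makedirs(out_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")  # e.g. 20241115_001153
    out_path = os.path.join(out_dir, f"predictions_{timestamp}.csv")
    df.to_csv(out_path, index=False)
    return out_path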