Shiyu Zhao committed · Commit 255a7a4 · Parent: 9698e43 · Update space
app.py CHANGED
@@ -7,7 +7,7 @@ from datetime import datetime
 import json
 import torch
 from tqdm import tqdm
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor, as_completed
 import smtplib
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
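Note: the only change in this hunk is adding `as_completed` to the import; the rewritten `compute_metrics` below uses it to consume futures in completion order rather than submission order. A minimal sketch of the pattern (illustrative only, not part of the commit):

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def work(x):
        return x * x

    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(work, i) for i in range(10)]
        for future in as_completed(futures):
            # Futures are yielded as soon as they finish, so slow tasks
            # do not block the handling of faster ones.
            print(future.result())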
@@ -41,50 +41,55 @@ except Exception as e:
 
 
 def process_single_instance(args):
-    """Process a single instance with improved …
+    """Process a single instance with improved validation and error handling"""
     idx, eval_csv, qa_dataset, evaluator, eval_metrics, max_candidate_id = args
     try:
+        # Get query data
         query, query_id, answer_ids, meta_info = qa_dataset[idx]
 
-        # Get predictions
+        # Get predictions
         matching_preds = eval_csv[eval_csv['query_id'] == query_id]['pred_rank']
         if len(matching_preds) == 0:
             print(f"Warning: No prediction found for query_id {query_id}")
             return None
         elif len(matching_preds) > 1:
             print(f"Warning: Multiple predictions found for query_id {query_id}, using first one")
-        …
-        …
-            pred_rank = matching_preds.iloc[0]
+
+        pred_rank = matching_preds.iloc[0]
 
         # Parse prediction
-        …
-        …
+        if isinstance(pred_rank, str):
+            try:
                 pred_rank = eval(pred_rank)
-        …
-        …
-        else:
-            print(f"Warning: Unexpected pred_rank type for query_id {query_id}: {type(pred_rank)}")
+            except Exception as e:
+                print(f"Error parsing pred_rank for query_id {query_id}: {str(e)}")
                 return None
-        …
-        …
-            return None
-        …
-        # Validate and filter predictions
+
+        # Validate prediction format
         if not isinstance(pred_rank, list):
             print(f"Warning: pred_rank is not a list for query_id {query_id}")
             return None
 
-        # …
-        …
-        # …
-        …
+        # Validate and filter prediction values
+        valid_pred_rank = []
+        for rank in pred_rank[:100]:  # Only use top 100 predictions
+            if isinstance(rank, (int, np.integer)) and 0 <= rank < max_candidate_id:
+                valid_pred_rank.append(rank)
+            else:
+                print(f"Warning: Invalid prediction {rank} for query_id {query_id}")
+
+        if not valid_pred_rank:
+            print(f"Warning: No valid predictions for query_id {query_id}")
+            return None
 
-        # …
-        pred_dict = {…
+        # Create prediction dictionary with valid predictions only
+        pred_dict = {rank: -i for i, rank in enumerate(valid_pred_rank)}
+
+        # Convert answer_ids to tensor
         answer_ids = torch.LongTensor(answer_ids)
-        result = evaluator.evaluate(pred_dict, answer_ids, metrics=eval_metrics)
 
+        # Evaluate
+        result = evaluator.evaluate(pred_dict, answer_ids, metrics=eval_metrics)
         result["idx"], result["query_id"] = idx, query_id
         return result
 
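Note: the new `pred_dict = {rank: -i for i, rank in enumerate(valid_pred_rank)}` gives the candidate at position i the score -i, so earlier-ranked candidates receive strictly higher scores, which is the ordering `evaluator.evaluate` ranks on. The parsing path still calls `eval()` on a CSV field; for a stringified rank list such as "[3, 17, 42]", the standard-library `ast.literal_eval` would be a safer drop-in, since it only accepts Python literals. A sketch of that substitution (an assumption, not part of the commit):

    import ast

    def parse_pred_rank(raw):
        """Safely parse a stringified rank list like "[3, 17, 42]"."""
        try:
            parsed = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            return None  # malformed or non-literal input
        return parsed if isinstance(parsed, list) else None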
@@ -93,116 +98,81 @@ def process_single_instance(args):
         return None
 
 def compute_metrics(csv_path: str, dataset: str, split: str, num_workers: int = 4):
-    """Compute metrics with improved …
+    """Compute metrics with improved thread safety and error handling"""
     start_time = time.time()
+
+    # Dataset configuration
    candidate_ids_dict = {
         'amazon': [i for i in range(957192)],
         'mag': [i for i in range(1172724, 1872968)],
         'prime': [i for i in range(129375)]
     }
+
+    candidate_size_dict = {
+        'amazon': 957192,
+        'mag': 700244,  # 1872968 - 1172724
+        'prime': 129375
+    }
+
     try:
-        …
-        # Load CSV and validate format
-        print("Loading and validating CSV file...")
-        eval_csv = pd.read_csv(csv_path)
-        if 'query_id' not in eval_csv.columns or 'pred_rank' not in eval_csv.columns:
-            raise ValueError("CSV must contain 'query_id' and 'pred_rank' columns")
-
-        # Check for duplicate query_ids
-        duplicate_queries = eval_csv['query_id'].duplicated()
-        if duplicate_queries.any():
-            dup_count = duplicate_queries.sum()
-            print(f"Warning: Found {dup_count} duplicate query_ids in CSV")
-
-        # Keep only necessary columns
-        eval_csv = eval_csv[['query_id', 'pred_rank']]
-        print(f"CSV loaded, shape: {eval_csv.shape}")
-
-        # Get dataset-specific candidate size
-        candidate_size_dict = {
-            'amazon': 957192,
-            'mag': 700244,  # 1872968 - 1172724
-            'prime': 129375
-        }
-
-        if dataset not in candidate_size_dict:
+        # Input validation
+        if dataset not in candidate_ids_dict:
             raise ValueError(f"Invalid dataset '{dataset}'")
+        if split not in ['test', 'test-0.1', 'human_generated_eval']:
+            raise ValueError(f"Invalid split '{split}'")
 
+        # Load and validate CSV
+        print(f"\nLoading data for {dataset} {split}")
+        eval_csv = pd.read_csv(csv_path)
+        required_columns = ['query_id', 'pred_rank']
+        if not all(col in eval_csv.columns for col in required_columns):
+            raise ValueError(f"CSV must contain columns: {required_columns}")
+
+        eval_csv = eval_csv[required_columns]
         max_candidate_id = candidate_size_dict[dataset]
-        print(f"Dataset {dataset} has {max_candidate_id} candidates")
 
+        # Initialize components
         evaluator = Evaluator(candidate_ids_dict[dataset])
         eval_metrics = ['hit@1', 'hit@5', 'recall@20', 'mrr']
-
         qa_dataset = load_qa(dataset, human_generated_eval=split == 'human_generated_eval')
         split_idx = qa_dataset.get_idx_split()
         all_indices = split_idx[split].tolist()
 
-        print(f"Processing {len(all_indices)} instances …
+        print(f"Processing {len(all_indices)} instances with {num_workers} threads")
 
-        # Process …
-        batch_size = 100
+        # Process instances
         results_list = []
-
-        valid_results_count = 0
+        valid_count = 0
         error_count = 0
 
-        …
+        with ThreadPoolExecutor(max_workers=num_workers) as executor:
+            futures = [
+                executor.submit(
+                    process_single_instance,
+                    (idx, eval_csv, qa_dataset, evaluator, eval_metrics, max_candidate_id)
+                )
+                for idx in all_indices
+            ]
+
+            with tqdm(total=len(futures), desc="Processing") as pbar:
+                for future in as_completed(futures):
                     try:
                         result = future.result()
                         if result is not None:
-                            …
-                            …
+                            results_list.append(result)
+                            valid_count += 1
                         else:
                             error_count += 1
                     except Exception as e:
-                        print(f"Error in …
                         error_count += 1
-                    progress_queue.put(1)
-            return batch_results
-
-        # Process batches with progress tracking
-        total_batches = (len(all_indices) + batch_size - 1) // batch_size
-        remaining_indices = len(all_indices)
-
-        def update_progress():
-            with tqdm(total=len(all_indices), desc="Processing instances") as pbar:
-                completed = 0
-                while completed < len(all_indices):
-                    progress_queue.get()
-                    completed += 1
+                        print(f"Error in future: {str(e)}")
+                        error_count += 1
                     pbar.update(1)
-
-        # Start progress monitoring thread
-        progress_thread = threading.Thread(target=update_progress)
-        progress_thread.start()
-
-        # Process batches
-        for i in range(0, len(all_indices), batch_size):
-            batch_indices = all_indices[i:min(i + batch_size, len(all_indices))]
-            batch_results = process_batch(batch_indices)
-            results_list.extend(batch_results)
-            remaining_indices -= len(batch_indices)
-            print(f"\rBatch {i//batch_size + 1}/{total_batches} completed. "
-                  f"Valid: {valid_results_count}, Errors: {error_count}, Remaining: {remaining_indices}")
-
-        progress_thread.join()
-
+
         # Compute final metrics
         if not results_list:
             raise ValueError("No valid results were produced")
 
-        print(f"\nProcessing complete. Valid …
+        print(f"\nProcessing complete. Valid: {valid_count}, Errors: {error_count}")
 
         results_df = pd.DataFrame(results_list)
         final_results = {
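Note: each worker still filters the shared DataFrame with a boolean mask (`eval_csv[eval_csv['query_id'] == query_id]`), a full O(n) scan per query. Building a one-time index before submitting the futures would make each lookup O(1). A sketch, not in the commit (`drop_duplicates` keeps the first row, matching the "using first one" warning path):

    # Build once in compute_metrics, before the executor is created.
    pred_by_query = (
        eval_csv.drop_duplicates('query_id')
                .set_index('query_id')['pred_rank']
                .to_dict()
    )

    # Inside process_single_instance, this replaces the boolean-mask filter.
    pred_rank = pred_by_query.get(query_id)
    if pred_rank is None:
        print(f"Warning: No prediction found for query_id {query_id}")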
@@ -211,9 +181,9 @@ def compute_metrics(csv_path: str, dataset: str, split: str, num_workers: int =
         }
 
         elapsed_time = time.time() - start_time
-        print(f"…
+        print(f"Completed in {elapsed_time:.2f} seconds")
         return final_results
-
+
     except Exception as error:
         elapsed_time = time.time() - start_time
         error_msg = f"Error in compute_metrics ({elapsed_time:.2f}s): {str(error)}"
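Note: the body of `final_results` (new lines 179–180) falls outside the diff context, so its exact contents are not visible here. A typical aggregation over the per-query rows, shown purely as an assumption, would average each metric column:

    # Hypothetical aggregation; the commit's actual final_results body is not shown.
    final_results = {
        metric: float(results_df[metric].mean())
        for metric in eval_metrics  # ['hit@1', 'hit@5', 'recall@20', 'mrr']
    }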