Shiyu Zhao committed
Commit ca736d2 · 1 Parent(s): c4923ca

Update space

Files changed (1)
  1. app.py +57 -29
app.py CHANGED
@@ -14,6 +14,7 @@ from email.mime.text import MIMEText
 from huggingface_hub import HfApi
 import shutil
 import tempfile
+import time
 
 from stark_qa import load_qa
 from stark_qa.evaluator import Evaluator
@@ -30,31 +31,32 @@ except Exception as e:
 
 
 def process_single_instance(args):
+    """Process a single instance with progress tracking"""
     idx, eval_csv, qa_dataset, evaluator, eval_metrics = args
-    query, query_id, answer_ids, meta_info = qa_dataset[idx]
-
     try:
-        pred_rank = eval_csv[eval_csv['query_id'] == query_id]['pred_rank'].item()
-    except IndexError:
-        raise IndexError(f'Error when processing query_id={query_id}, please make sure the predicted results exist for this query.')
-    except Exception as e:
-        raise RuntimeError(f'Unexpected error occurred while fetching prediction rank for query_id={query_id}: {e}')
-
-    if isinstance(pred_rank, str):
+        query, query_id, answer_ids, meta_info = qa_dataset[idx]
+
+        # Print progress for debugging
+        print(f"Processing query_id: {query_id}")
+
         try:
+            pred_rank = eval_csv[eval_csv['query_id'] == query_id]['pred_rank'].item()
+        except Exception as e:
+            print(f"Error getting pred_rank for query_id {query_id}: {str(e)}")
+            raise
+
+        if isinstance(pred_rank, str):
             pred_rank = eval(pred_rank)
-        except SyntaxError as e:
-            raise ValueError(f'Failed to parse pred_rank as a list for query_id={query_id}: {e}')
-
-    if not isinstance(pred_rank, list):
-        raise TypeError(f'Error when processing query_id={query_id}, expected pred_rank to be a list but got {type(pred_rank)}.')
-
-    pred_dict = {pred_rank[i]: -i for i in range(min(100, len(pred_rank)))}
-    answer_ids = torch.LongTensor(answer_ids)
-    result = evaluator.evaluate(pred_dict, answer_ids, metrics=eval_metrics)
-
-    result["idx"], result["query_id"] = idx, query_id
-    return result
+
+        pred_dict = {pred_rank[i]: -i for i in range(min(100, len(pred_rank)))}
+        answer_ids = torch.LongTensor(answer_ids)
+        result = evaluator.evaluate(pred_dict, answer_ids, metrics=eval_metrics)
+
+        result["idx"], result["query_id"] = idx, query_id
+        return result
+    except Exception as e:
+        print(f"Error in process_single_instance for idx {idx}: {str(e)}")
+        raise
 
 
 def compute_metrics(csv_path: str, dataset: str, split: str, num_workers: int = 4):
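Side note on the scoring convention this change keeps: `pred_dict` maps each predicted candidate id to a score that decreases with its rank, so the evaluator can treat larger values as better. A minimal illustration (the ids below are made up):

```python
# Illustrative only: toy candidate ids, best-ranked first.
pred_rank = [42, 7, 13]
pred_dict = {pred_rank[i]: -i for i in range(min(100, len(pred_rank)))}
print(pred_dict)  # {42: 0, 7: -1, 13: -2}
```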
@@ -63,6 +65,7 @@ def compute_metrics(csv_path: str, dataset: str, split: str, num_workers: int =
         'mag': [i for i in range(1172724, 1872968)],
         'prime': [i for i in range(129375)]
     }
+    start_time = time.time()
     try:
         eval_csv = pd.read_csv(csv_path)
         if 'query_id' not in eval_csv.columns:
@@ -77,11 +80,14 @@ def compute_metrics(csv_path: str, dataset: str, split: str, num_workers: int =
         if split not in ['test', 'test-0.1', 'human_generated_eval']:
             raise ValueError(f"Invalid split '{split}', expected one of ['test', 'test-0.1', 'human_generated_eval'].")
 
+        print("Initializing evaluator...")
         evaluator = Evaluator(candidate_ids_dict[dataset])
         eval_metrics = ['hit@1', 'hit@5', 'recall@20', 'mrr']
+        print("Loading QA dataset...")
         qa_dataset = load_qa(dataset, human_generated_eval=split == 'human_generated_eval')
         split_idx = qa_dataset.get_idx_split()
         all_indices = split_idx[split].tolist()
+        print(f"Dataset loaded, processing {len(all_indices)} instances")
 
         results_list = []
         # query_ids = []
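The logging added here wraps the stark_qa access pattern both functions rely on: `load_qa` returns a dataset whose `get_idx_split()` gives per-split index tensors, and indexing the dataset yields a `(query, query_id, answer_ids, meta_info)` tuple. A rough sketch of that flow, mirroring only the calls visible in this diff (the dataset and split names are just the ones used above):

```python
# Sketch only -- mirrors the calls in this diff, not an official stark_qa example.
from stark_qa import load_qa

qa_dataset = load_qa("prime", human_generated_eval=False)
split_idx = qa_dataset.get_idx_split()
all_indices = split_idx["test"].tolist()

for idx in all_indices[:3]:
    query, query_id, answer_ids, meta_info = qa_dataset[idx]
    print(query_id, len(answer_ids))
```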
@@ -102,28 +108,50 @@ def compute_metrics(csv_path: str, dataset: str, split: str, num_workers: int =
         #     metric: np.mean(eval_csv[eval_csv['query_id'].isin(query_ids)][metric]) for metric in eval_metrics
         # }
         # return final_result
-        batch_size = 100
-        for i in range(0, len(all_indices), batch_size):
-            max_ind = min(i+batch_size, len(all_indices))
-            batch_indices = all_indices[i:max_ind]
+        batch_size = 50  # Smaller batch size for more frequent updates
+        total_batches = (len(all_indices) + batch_size - 1) // batch_size
+
+        for batch_num in range(total_batches):
+            batch_start = batch_num * batch_size
+            batch_end = min((batch_num + 1) * batch_size, len(all_indices))
+            batch_indices = all_indices[batch_start:batch_end]
+
+            print(f"\nProcessing batch {batch_num + 1}/{total_batches}")
+            print(f"Batch size: {len(batch_indices)}")
+
             args = [(idx, eval_csv, qa_dataset, evaluator, eval_metrics)
                     for idx in batch_indices]
 
             with ProcessPoolExecutor(max_workers=num_workers) as executor:
                 futures = [executor.submit(process_single_instance, arg)
                            for arg in args]
-                for future in as_completed(futures):
-                    results_list.append(future.result())
-
-        # Compute final metrics
+                for future in tqdm(as_completed(futures),
+                                   total=len(futures),
+                                   desc=f"Batch {batch_num + 1}"):
+                    try:
+                        result = future.result()
+                        results_list.append(result)
+                    except Exception as e:
+                        print(f"Error processing result: {str(e)}")
+                        raise
+
+        print("\nComputing final metrics...")
         results_df = pd.DataFrame(results_list)
         final_results = {
             metric: results_df[metric].mean()
             for metric in eval_metrics
         }
 
+        elapsed_time = time.time() - start_time
+        print(f"\nMetrics computation completed in {elapsed_time:.2f} seconds")
         return final_results
 
+    except Exception as error:
+        elapsed_time = time.time() - start_time
+        error_msg = f"Error in compute_metrics ({elapsed_time:.2f}s): {str(error)}"
+        print(error_msg)
+        return error_msg
+
     except pd.errors.EmptyDataError:
         return "Error: The CSV file is empty or could not be read. Please check the file and try again."
     except FileNotFoundError:
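For reference, the new batching arithmetic rounds up: with, say, 1,020 indices and `batch_size = 50`, `total_batches = (1020 + 49) // 50 = 21`, and the final batch holds the remaining 20 items. Below is a hedged driver for the updated `compute_metrics`; the CSV name and its two toy rows are hypothetical, and a real submission needs a prediction row for every query in the split:

```python
# Illustrative driver -- the file name and rows are hypothetical.
import pandas as pd

# The code above expects 'query_id' and 'pred_rank' columns, where 'pred_rank'
# may be a stringified list of candidate ids.
pd.DataFrame({
    "query_id": [0, 1],
    "pred_rank": ["[3, 1, 2]", "[5, 4]"],
}).to_csv("predictions.csv", index=False)

results = compute_metrics("predictions.csv", dataset="prime", split="test", num_workers=4)
print(results)  # dict of hit@1, hit@5, recall@20, mrr -- or an error string
```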