Spaces:
Running
Running
Shiyu Zhao
committed on
Commit
·
647ad4c
1
Parent(s):
5dfb93b
Update space
Browse files
app.py
CHANGED
@@ -42,7 +42,7 @@ except Exception as e:
|
|
42 |
|
43 |
def process_single_instance(args):
|
44 |
"""Process a single instance with improved validation and error handling"""
|
45 |
-
idx, eval_csv, qa_dataset, evaluator, eval_metrics
|
46 |
try:
|
47 |
# Get query data
|
48 |
query, query_id, answer_ids, meta_info = qa_dataset[idx]
|
@@ -70,26 +70,22 @@ def process_single_instance(args):
|
|
70 |
print(f"Warning: pred_rank is not a list for query_id {query_id}")
|
71 |
return None
|
72 |
|
73 |
-
# Validate and filter prediction values
|
74 |
-
valid_pred_rank = []
|
75 |
-
for rank in pred_rank[:100]: # Only use top 100 predictions
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
|
81 |
-
if not valid_pred_rank:
|
82 |
-
|
83 |
-
|
84 |
|
85 |
-
|
86 |
-
pred_dict = {rank: -i for i, rank in enumerate(valid_pred_rank)}
|
87 |
-
|
88 |
-
# Convert answer_ids to tensor
|
89 |
answer_ids = torch.LongTensor(answer_ids)
|
90 |
-
|
91 |
-
# Evaluate
|
92 |
result = evaluator.evaluate(pred_dict, answer_ids, metrics=eval_metrics)
|
|
|
93 |
result["idx"], result["query_id"] = idx, query_id
|
94 |
return result
|
95 |
|
@@ -108,12 +104,6 @@ def compute_metrics(csv_path: str, dataset: str, split: str, num_workers: int =
|
|
108 |
'prime': [i for i in range(129375)]
|
109 |
}
|
110 |
|
111 |
-
candidate_size_dict = {
|
112 |
-
'amazon': 957192,
|
113 |
-
'mag': 700244, # 1872968 - 1172724
|
114 |
-
'prime': 129375
|
115 |
-
}
|
116 |
-
|
117 |
try:
|
118 |
# Input validation
|
119 |
if dataset not in candidate_ids_dict:
|
@@ -129,7 +119,6 @@ def compute_metrics(csv_path: str, dataset: str, split: str, num_workers: int =
|
|
129 |
raise ValueError(f"CSV must contain columns: {required_columns}")
|
130 |
|
131 |
eval_csv = eval_csv[required_columns]
|
132 |
-
max_candidate_id = candidate_size_dict[dataset]
|
133 |
|
134 |
# Initialize components
|
135 |
evaluator = Evaluator(candidate_ids_dict[dataset])
|
@@ -149,7 +138,7 @@ def compute_metrics(csv_path: str, dataset: str, split: str, num_workers: int =
|
|
149 |
futures = [
|
150 |
executor.submit(
|
151 |
process_single_instance,
|
152 |
-
(idx, eval_csv, qa_dataset, evaluator, eval_metrics
|
153 |
)
|
154 |
for idx in all_indices
|
155 |
]
|
|
|
42 |
|
43 |
def process_single_instance(args):
|
44 |
"""Process a single instance with improved validation and error handling"""
|
45 |
+
idx, eval_csv, qa_dataset, evaluator, eval_metrics = args
|
46 |
try:
|
47 |
# Get query data
|
48 |
query, query_id, answer_ids, meta_info = qa_dataset[idx]
|
|
|
70 |
print(f"Warning: pred_rank is not a list for query_id {query_id}")
|
71 |
return None
|
72 |
|
73 |
+
# # Validate and filter prediction values
|
74 |
+
# valid_pred_rank = []
|
75 |
+
# for rank in pred_rank[:100]: # Only use top 100 predictions
|
76 |
+
# if isinstance(rank, (int, np.integer)) and 0 <= rank < max_candidate_id:
|
77 |
+
# valid_pred_rank.append(rank)
|
78 |
+
# else:
|
79 |
+
# print(f"Warning: Invalid prediction {rank} for query_id {query_id}")
|
80 |
|
81 |
+
# if not valid_pred_rank:
|
82 |
+
# print(f"Warning: No valid predictions for query_id {query_id}")
|
83 |
+
# return None
|
84 |
|
85 |
+
pred_dict = {pred_rank[i]: -i for i in range(min(100, len(pred_rank)))}
|
|
|
|
|
|
|
86 |
answer_ids = torch.LongTensor(answer_ids)
|
|
|
|
|
87 |
result = evaluator.evaluate(pred_dict, answer_ids, metrics=eval_metrics)
|
88 |
+
|
89 |
result["idx"], result["query_id"] = idx, query_id
|
90 |
return result
|
91 |
|
|
|
104 |
'prime': [i for i in range(129375)]
|
105 |
}
|
106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
try:
|
108 |
# Input validation
|
109 |
if dataset not in candidate_ids_dict:
|
|
|
119 |
raise ValueError(f"CSV must contain columns: {required_columns}")
|
120 |
|
121 |
eval_csv = eval_csv[required_columns]
|
|
|
122 |
|
123 |
# Initialize components
|
124 |
evaluator = Evaluator(candidate_ids_dict[dataset])
|
|
|
138 |
futures = [
|
139 |
executor.submit(
|
140 |
process_single_instance,
|
141 |
+
(idx, eval_csv, qa_dataset, evaluator, eval_metrics)
|
142 |
)
|
143 |
for idx in all_indices
|
144 |
]
|
submissions/abc_abc/latest.json
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"latest_submission": "20241115_004044",
|
3 |
-
"status": "approved",
|
4 |
-
"method_name": "abc",
|
5 |
-
"team_name": "abc"
|
6 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
submissions/abc_abc/metadata_20241115_004044.json
DELETED
@@ -1,21 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"Method Name": "abc",
|
3 |
-
"Team Name": "abc",
|
4 |
-
"Dataset": "prime",
|
5 |
-
"Split": "human_generated_eval",
|
6 |
-
"Contact Email(s)": "a@s.edu",
|
7 |
-
"Code Repository": "https://github.com/",
|
8 |
-
"Model Description": "abc",
|
9 |
-
"Hardware": "abc",
|
10 |
-
"(Optional) Paper link": "",
|
11 |
-
"Model Type": "Others",
|
12 |
-
"results": {
|
13 |
-
"hit@1": 0.0,
|
14 |
-
"hit@5": 0.0,
|
15 |
-
"recall@20": 0.0,
|
16 |
-
"mrr": 0.03
|
17 |
-
},
|
18 |
-
"status": "approved",
|
19 |
-
"submission_date": "2024-11-15 00:40:49",
|
20 |
-
"csv_path": "submissions/abc_abc/predictions_20241115_004044.csv"
|
21 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
submissions/abc_abc/predictions_20241115_004044.csv
DELETED
The diff for this file is too large to render.
See raw diff
|
|