Spaces:
Running
Running
Shiyu Zhao
commited on
Commit
·
e050fd8
1
Parent(s):
bf00306
Update space
Browse files- app.py +163 -164
- utils/hub_storage.py +24 -29
app.py
CHANGED
@@ -189,23 +189,32 @@ def validate_github_url(url):
|
|
189 |
)
|
190 |
return bool(github_pattern.match(url))
|
191 |
|
192 |
-
def validate_csv(
|
193 |
-
"""Validate CSV file format and content"""
|
194 |
try:
|
195 |
-
df = pd.read_csv(
|
196 |
required_cols = ['query_id', 'pred_rank']
|
197 |
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
return True, "Valid CSV file"
|
|
|
|
|
|
|
209 |
except Exception as e:
|
210 |
return False, f"Error processing CSV: {str(e)}"
|
211 |
|
@@ -465,7 +474,7 @@ def save_submission(submission_data, csv_file):
|
|
465 |
def update_leaderboard_data(submission_data):
|
466 |
"""
|
467 |
Update leaderboard data with new submission results
|
468 |
-
Only
|
469 |
"""
|
470 |
global df_synthesized_full, df_synthesized_10, df_human_generated
|
471 |
|
@@ -477,26 +486,32 @@ def update_leaderboard_data(submission_data):
|
|
477 |
}
|
478 |
|
479 |
df_to_update = split_to_df[submission_data['Split']]
|
|
|
480 |
|
481 |
-
# Prepare new row data
|
482 |
new_row = {
|
483 |
-
'Method': submission_data['Method Name']
|
484 |
-
f'STARK-{submission_data["Dataset"].upper()}_Hit@1': submission_data['results']['hit@1'],
|
485 |
-
f'STARK-{submission_data["Dataset"].upper()}_Hit@5': submission_data['results']['hit@5'],
|
486 |
-
f'STARK-{submission_data["Dataset"].upper()}_R@20': submission_data['results']['recall@20'],
|
487 |
-
f'STARK-{submission_data["Dataset"].upper()}_MRR': submission_data['results']['mrr']
|
488 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
489 |
|
490 |
# Check if method already exists
|
491 |
method_mask = df_to_update['Method'] == submission_data['Method Name']
|
492 |
if method_mask.any():
|
493 |
-
# Update
|
494 |
for col in new_row:
|
495 |
df_to_update.loc[method_mask, col] = new_row[col]
|
496 |
else:
|
497 |
-
#
|
498 |
df_to_update.loc[len(df_to_update)] = new_row
|
499 |
|
|
|
|
|
500 |
# Function to get emails from meta_data
|
501 |
def get_emails_from_metadata(meta_data):
|
502 |
"""
|
@@ -601,47 +616,29 @@ def process_submission(
|
|
601 |
method_name, team_name, dataset, split, contact_email,
|
602 |
code_repo, csv_file, model_description, hardware, paper_link, model_type
|
603 |
):
|
604 |
-
"""Process and validate submission"""
|
605 |
-
temp_files = []
|
606 |
try:
|
607 |
-
#
|
608 |
if not all([method_name, team_name, dataset, split, contact_email, code_repo, csv_file, model_type]):
|
609 |
return "Error: Please fill in all required fields"
|
610 |
|
611 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
612 |
is_valid, message = validate_model_type(method_name, model_type)
|
613 |
if not is_valid:
|
614 |
return f"Error: {message}"
|
615 |
|
616 |
-
# Create
|
617 |
-
|
618 |
-
|
619 |
-
|
620 |
-
|
621 |
-
|
622 |
-
"Contact Email(s)": contact_email,
|
623 |
-
"Code Repository": code_repo,
|
624 |
-
"Model Description": model_description,
|
625 |
-
"Hardware": hardware,
|
626 |
-
"(Optional) Paper link": paper_link,
|
627 |
-
"Model Type": model_type
|
628 |
-
}
|
629 |
-
|
630 |
-
# Generate folder name and timestamp
|
631 |
-
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
632 |
-
folder_name = f"{sanitize_name(method_name)}_{sanitize_name(team_name)}"
|
633 |
-
|
634 |
-
# Process CSV file
|
635 |
-
temp_csv_path = None
|
636 |
-
if isinstance(csv_file, str):
|
637 |
-
temp_csv_path = csv_file
|
638 |
-
else:
|
639 |
-
temp_fd, temp_csv_path = tempfile.mkstemp(suffix='.csv')
|
640 |
-
temp_files.append(temp_csv_path)
|
641 |
-
os.close(temp_fd)
|
642 |
-
|
643 |
-
if hasattr(csv_file, 'name'):
|
644 |
-
shutil.copy2(csv_file.name, temp_csv_path)
|
645 |
else:
|
646 |
with open(temp_csv_path, 'wb') as temp_file:
|
647 |
if hasattr(csv_file, 'seek'):
|
@@ -651,125 +648,113 @@ def process_submission(
|
|
651 |
else:
|
652 |
temp_file.write(csv_file)
|
653 |
|
654 |
-
|
655 |
-
|
656 |
-
|
657 |
-
|
658 |
-
|
659 |
-
|
660 |
-
dataset
|
661 |
-
|
662 |
-
|
663 |
-
|
664 |
-
|
665 |
-
|
666 |
-
# send_error_notification(meta_data, results)
|
667 |
-
return f"Evaluation error: {results}"
|
668 |
-
|
669 |
-
# Process results
|
670 |
-
processed_results = {
|
671 |
-
"hit@1": round(results['hit@1'] * 100, 2),
|
672 |
-
"hit@5": round(results['hit@5'] * 100, 2),
|
673 |
-
"recall@20": round(results['recall@20'] * 100, 2),
|
674 |
-
"mrr": round(results['mrr'] * 100, 2)
|
675 |
-
}
|
676 |
-
|
677 |
-
# Save files to HuggingFace Hub
|
678 |
-
try:
|
679 |
-
# 1. Save CSV file
|
680 |
-
csv_filename = f"predictions_{timestamp}.csv"
|
681 |
-
csv_path_in_repo = f"submissions/{folder_name}/{csv_filename}"
|
682 |
-
hub_storage.save_to_hub(
|
683 |
-
file_content=temp_csv_path,
|
684 |
-
path_in_repo=csv_path_in_repo,
|
685 |
-
commit_message=f"Add submission: {method_name} by {team_name}"
|
686 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
687 |
|
688 |
-
#
|
|
|
|
|
|
|
689 |
submission_data = {
|
690 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
691 |
"results": processed_results,
|
692 |
-
"status": "
|
693 |
-
"submission_date": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
694 |
-
"csv_path": csv_path_in_repo
|
695 |
}
|
696 |
-
|
697 |
-
metadata_fd, temp_metadata_path = tempfile.mkstemp(suffix='.json')
|
698 |
-
temp_files.append(temp_metadata_path)
|
699 |
-
os.close(metadata_fd)
|
700 |
-
|
701 |
-
with open(temp_metadata_path, 'w') as f:
|
702 |
-
json.dump(submission_data, f, indent=4)
|
703 |
-
|
704 |
-
metadata_path = f"submissions/{folder_name}/metadata_{timestamp}.json"
|
705 |
-
hub_storage.save_to_hub(
|
706 |
-
file_content=temp_metadata_path,
|
707 |
-
path_in_repo=metadata_path,
|
708 |
-
commit_message=f"Add metadata: {method_name} by {team_name}"
|
709 |
-
)
|
710 |
|
711 |
-
#
|
712 |
-
|
713 |
-
|
714 |
-
|
715 |
-
|
716 |
-
|
717 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
718 |
|
719 |
-
|
720 |
-
|
721 |
-
|
|
|
|
|
722 |
|
723 |
-
|
724 |
-
|
725 |
|
726 |
-
|
727 |
-
|
728 |
-
|
729 |
-
path_in_repo=latest_path,
|
730 |
-
commit_message=f"Update latest submission info for {method_name}"
|
731 |
-
)
|
732 |
|
733 |
-
except Exception as e:
|
734 |
-
raise RuntimeError(f"Failed to save files to HuggingFace Hub: {str(e)}")
|
735 |
-
|
736 |
-
# Send confirmation email and update leaderboard data
|
737 |
-
# send_submission_confirmation(meta_data, processed_results)
|
738 |
-
update_leaderboard_data(submission_data)
|
739 |
-
|
740 |
-
# Return success message
|
741 |
-
return f"""
|
742 |
-
Submission successful!
|
743 |
-
|
744 |
-
Evaluation Results:
|
745 |
-
Hit@1: {processed_results['hit@1']:.2f}%
|
746 |
-
Hit@5: {processed_results['hit@5']:.2f}%
|
747 |
-
Recall@20: {processed_results['recall@20']:.2f}%
|
748 |
-
MRR: {processed_results['mrr']:.2f}%
|
749 |
-
|
750 |
-
Your submission has been saved and a confirmation email has been sent to {contact_email}.
|
751 |
-
Once approved, your results will appear in the leaderboard under: {method_name}
|
752 |
-
|
753 |
-
You can find your submission at:
|
754 |
-
https://huggingface.co/spaces/{REPO_ID}/tree/main/submissions/{folder_name}
|
755 |
-
|
756 |
-
Please refresh the page to see your submission in the leaderboard.
|
757 |
-
"""
|
758 |
-
|
759 |
except Exception as e:
|
760 |
-
|
761 |
-
|
762 |
-
return error_message
|
763 |
-
finally:
|
764 |
-
# Clean up temporary files
|
765 |
-
for temp_file in temp_files:
|
766 |
-
try:
|
767 |
-
if os.path.exists(temp_file):
|
768 |
-
os.unlink(temp_file)
|
769 |
-
except Exception as e:
|
770 |
-
print(f"Warning: Failed to delete temporary file {temp_file}: {str(e)}")
|
771 |
-
|
772 |
-
|
773 |
def filter_by_model_type(df, selected_types):
|
774 |
"""
|
775 |
Filter DataFrame by selected model types, including submitted models.
|
@@ -786,10 +771,24 @@ def filter_by_model_type(df, selected_types):
|
|
786 |
return df[df['Method'].isin(selected_models)]
|
787 |
|
788 |
def format_dataframe(df, dataset):
|
|
|
|
|
|
|
|
|
789 |
columns = ['Method'] + [col for col in df.columns if dataset in col]
|
790 |
filtered_df = df[columns].copy()
|
|
|
|
|
|
|
|
|
|
|
|
|
791 |
filtered_df.columns = [col.split('_')[-1] if '_' in col else col for col in filtered_df.columns]
|
792 |
-
|
|
|
|
|
|
|
|
|
793 |
return filtered_df
|
794 |
|
795 |
def update_tables(selected_types):
|
|
|
189 |
)
|
190 |
return bool(github_pattern.match(url))
|
191 |
|
192 |
+
def validate_csv(file_path):
|
193 |
+
"""Validate CSV file format and content with better error handling"""
|
194 |
try:
|
195 |
+
df = pd.read_csv(file_path)
|
196 |
required_cols = ['query_id', 'pred_rank']
|
197 |
|
198 |
+
# Check for required columns
|
199 |
+
missing_cols = [col for col in required_cols if col not in df.columns]
|
200 |
+
if missing_cols:
|
201 |
+
return False, f"Missing required columns: {', '.join(missing_cols)}"
|
202 |
+
|
203 |
+
# Validate first few rows to ensure proper format
|
204 |
+
for idx, row in df.head().iterrows():
|
205 |
+
try:
|
206 |
+
rank_list = eval(row['pred_rank']) if isinstance(row['pred_rank'], str) else row['pred_rank']
|
207 |
+
if not isinstance(rank_list, list):
|
208 |
+
return False, f"pred_rank must be a list (row {idx})"
|
209 |
+
if len(rank_list) < 20:
|
210 |
+
return False, f"pred_rank must contain at least 20 candidates (row {idx})"
|
211 |
+
except Exception as e:
|
212 |
+
return False, f"Invalid pred_rank format in row {idx}: {str(e)}"
|
213 |
+
|
214 |
return True, "Valid CSV file"
|
215 |
+
|
216 |
+
except pd.errors.EmptyDataError:
|
217 |
+
return False, "CSV file is empty"
|
218 |
except Exception as e:
|
219 |
return False, f"Error processing CSV: {str(e)}"
|
220 |
|
|
|
474 |
def update_leaderboard_data(submission_data):
|
475 |
"""
|
476 |
Update leaderboard data with new submission results
|
477 |
+
Only updates the specific dataset submitted, preventing empty rows
|
478 |
"""
|
479 |
global df_synthesized_full, df_synthesized_10, df_human_generated
|
480 |
|
|
|
486 |
}
|
487 |
|
488 |
df_to_update = split_to_df[submission_data['Split']]
|
489 |
+
dataset = submission_data['Dataset'].upper()
|
490 |
|
491 |
+
# Prepare new row data with only the relevant dataset columns
|
492 |
new_row = {
|
493 |
+
'Method': submission_data['Method Name']
|
|
|
|
|
|
|
|
|
494 |
}
|
495 |
+
# Only add metrics for the submitted dataset
|
496 |
+
new_row.update({
|
497 |
+
f'STARK-{dataset}_Hit@1': submission_data['results']['hit@1'],
|
498 |
+
f'STARK-{dataset}_Hit@5': submission_data['results']['hit@5'],
|
499 |
+
f'STARK-{dataset}_R@20': submission_data['results']['recall@20'],
|
500 |
+
f'STARK-{dataset}_MRR': submission_data['results']['mrr']
|
501 |
+
})
|
502 |
|
503 |
# Check if method already exists
|
504 |
method_mask = df_to_update['Method'] == submission_data['Method Name']
|
505 |
if method_mask.any():
|
506 |
+
# Update only the columns for the submitted dataset
|
507 |
for col in new_row:
|
508 |
df_to_update.loc[method_mask, col] = new_row[col]
|
509 |
else:
|
510 |
+
# For new methods, create a row with only the submitted dataset's values
|
511 |
df_to_update.loc[len(df_to_update)] = new_row
|
512 |
|
513 |
+
|
514 |
+
|
515 |
# Function to get emails from meta_data
|
516 |
def get_emails_from_metadata(meta_data):
|
517 |
"""
|
|
|
616 |
method_name, team_name, dataset, split, contact_email,
|
617 |
code_repo, csv_file, model_description, hardware, paper_link, model_type
|
618 |
):
|
619 |
+
"""Process and validate submission with better error handling and progress updates"""
|
|
|
620 |
try:
|
621 |
+
# 1. Initial validation with early returns
|
622 |
if not all([method_name, team_name, dataset, split, contact_email, code_repo, csv_file, model_type]):
|
623 |
return "Error: Please fill in all required fields"
|
624 |
|
625 |
+
if len(method_name) > 25:
|
626 |
+
return "Error: Method name must be 25 characters or less"
|
627 |
+
|
628 |
+
if len(team_name) > 25:
|
629 |
+
return "Error: Team name must be 25 characters or less"
|
630 |
+
|
631 |
+
# 2. Validate model type
|
632 |
is_valid, message = validate_model_type(method_name, model_type)
|
633 |
if not is_valid:
|
634 |
return f"Error: {message}"
|
635 |
|
636 |
+
# 3. Create temporary directory for processing
|
637 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
638 |
+
# Copy CSV file to temp directory
|
639 |
+
temp_csv_path = os.path.join(temp_dir, "submission.csv")
|
640 |
+
if isinstance(csv_file, str):
|
641 |
+
shutil.copy2(csv_file, temp_csv_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
642 |
else:
|
643 |
with open(temp_csv_path, 'wb') as temp_file:
|
644 |
if hasattr(csv_file, 'seek'):
|
|
|
648 |
else:
|
649 |
temp_file.write(csv_file)
|
650 |
|
651 |
+
# 4. Validate CSV format
|
652 |
+
is_valid_csv, csv_message = validate_csv(temp_csv_path)
|
653 |
+
if not is_valid_csv:
|
654 |
+
return f"Error validating CSV: {csv_message}"
|
655 |
+
|
656 |
+
# 5. Compute metrics with progress indication
|
657 |
+
print(f"Computing metrics for {dataset.lower()} dataset...")
|
658 |
+
results = compute_metrics(
|
659 |
+
csv_path=temp_csv_path,
|
660 |
+
dataset=dataset.lower(),
|
661 |
+
split=split,
|
662 |
+
num_workers=4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
663 |
)
|
664 |
+
|
665 |
+
if isinstance(results, str):
|
666 |
+
return f"Evaluation error: {results}"
|
667 |
+
|
668 |
+
# 6. Process results
|
669 |
+
processed_results = {
|
670 |
+
metric: round(value * 100, 2)
|
671 |
+
for metric, value in results.items()
|
672 |
+
}
|
673 |
|
674 |
+
# 7. Prepare submission data
|
675 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
676 |
+
folder_name = f"{sanitize_name(method_name)}_{sanitize_name(team_name)}"
|
677 |
+
|
678 |
submission_data = {
|
679 |
+
"Method Name": method_name,
|
680 |
+
"Team Name": team_name,
|
681 |
+
"Dataset": dataset,
|
682 |
+
"Split": split,
|
683 |
+
"Contact Email(s)": contact_email,
|
684 |
+
"Code Repository": code_repo,
|
685 |
+
"Model Description": model_description,
|
686 |
+
"Hardware": hardware,
|
687 |
+
"(Optional) Paper link": paper_link,
|
688 |
+
"Model Type": model_type,
|
689 |
"results": processed_results,
|
690 |
+
"status": "pending_review",
|
691 |
+
"submission_date": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
|
692 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
693 |
|
694 |
+
# 8. Save to HuggingFace Hub with error handling
|
695 |
+
try:
|
696 |
+
# Save CSV
|
697 |
+
csv_path_in_repo = f"submissions/{folder_name}/predictions_{timestamp}.csv"
|
698 |
+
hub_storage.save_to_hub(
|
699 |
+
file_content=temp_csv_path,
|
700 |
+
path_in_repo=csv_path_in_repo,
|
701 |
+
commit_message=f"Add submission CSV: {method_name} by {team_name}"
|
702 |
+
)
|
703 |
+
submission_data["csv_path"] = csv_path_in_repo
|
704 |
+
|
705 |
+
# Save metadata
|
706 |
+
metadata_path = f"submissions/{folder_name}/metadata_{timestamp}.json"
|
707 |
+
with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as tmp:
|
708 |
+
json.dump(submission_data, tmp, indent=4)
|
709 |
+
tmp.flush()
|
710 |
+
hub_storage.save_to_hub(
|
711 |
+
file_content=tmp.name,
|
712 |
+
path_in_repo=metadata_path,
|
713 |
+
commit_message=f"Add metadata: {method_name} by {team_name}"
|
714 |
+
)
|
715 |
+
|
716 |
+
# Update latest.json
|
717 |
+
latest_path = f"submissions/{folder_name}/latest.json"
|
718 |
+
latest_info = {
|
719 |
+
"latest_submission": timestamp,
|
720 |
+
"status": "pending_review",
|
721 |
+
"method_name": method_name
|
722 |
+
}
|
723 |
+
with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as tmp:
|
724 |
+
json.dump(latest_info, tmp, indent=4)
|
725 |
+
tmp.flush()
|
726 |
+
hub_storage.save_to_hub(
|
727 |
+
file_content=tmp.name,
|
728 |
+
path_in_repo=latest_path,
|
729 |
+
commit_message=f"Update latest submission info for {method_name}"
|
730 |
+
)
|
731 |
+
|
732 |
+
except Exception as e:
|
733 |
+
return f"Failed to save to HuggingFace Hub: {str(e)}"
|
734 |
+
|
735 |
+
# 9. Update leaderboard
|
736 |
+
update_leaderboard_data(submission_data)
|
737 |
+
|
738 |
+
# 10. Return success message
|
739 |
+
return f"""
|
740 |
+
Submission successful!
|
741 |
|
742 |
+
Evaluation Results:
|
743 |
+
Hit@1: {processed_results['hit@1']:.2f}%
|
744 |
+
Hit@5: {processed_results['hit@5']:.2f}%
|
745 |
+
Recall@20: {processed_results['recall@20']:.2f}%
|
746 |
+
MRR: {processed_results['mrr']:.2f}%
|
747 |
|
748 |
+
Your submission has been saved and will be reviewed.
|
749 |
+
Once approved, your results will appear in the leaderboard as: {method_name}
|
750 |
|
751 |
+
You can find your submission at:
|
752 |
+
https://huggingface.co/spaces/{REPO_ID}/tree/main/submissions/{folder_name}
|
753 |
+
"""
|
|
|
|
|
|
|
754 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
755 |
except Exception as e:
|
756 |
+
return f"Error processing submission: {str(e)}"
|
757 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
758 |
def filter_by_model_type(df, selected_types):
|
759 |
"""
|
760 |
Filter DataFrame by selected model types, including submitted models.
|
|
|
771 |
return df[df['Method'].isin(selected_models)]
|
772 |
|
773 |
def format_dataframe(df, dataset):
|
774 |
+
"""
|
775 |
+
Format DataFrame for display, removing rows with no data for the selected dataset
|
776 |
+
"""
|
777 |
+
# Select relevant columns
|
778 |
columns = ['Method'] + [col for col in df.columns if dataset in col]
|
779 |
filtered_df = df[columns].copy()
|
780 |
+
|
781 |
+
# Remove rows where all metric columns are empty/NaN for this dataset
|
782 |
+
metric_columns = [col for col in filtered_df.columns if col != 'Method']
|
783 |
+
filtered_df = filtered_df.dropna(subset=metric_columns, how='all')
|
784 |
+
|
785 |
+
# Rename columns to remove dataset prefix
|
786 |
filtered_df.columns = [col.split('_')[-1] if '_' in col else col for col in filtered_df.columns]
|
787 |
+
|
788 |
+
# Sort by MRR
|
789 |
+
if 'MRR' in filtered_df.columns:
|
790 |
+
filtered_df = filtered_df.sort_values('MRR', ascending=False)
|
791 |
+
|
792 |
return filtered_df
|
793 |
|
794 |
def update_tables(selected_types):
|
utils/hub_storage.py
CHANGED
@@ -6,36 +6,31 @@ class HubStorage:
|
|
6 |
def __init__(self, repo_id):
|
7 |
self.repo_id = repo_id
|
8 |
self.api = HfApi()
|
9 |
-
|
10 |
-
|
11 |
-
"""
|
12 |
-
Get content of a file from the repository
|
13 |
-
"""
|
14 |
try:
|
15 |
-
|
16 |
-
repo_id=self.repo_id,
|
17 |
-
repo_type="space",
|
18 |
-
filename=file_path,
|
19 |
-
text=True
|
20 |
-
)
|
21 |
-
return content
|
22 |
except Exception as e:
|
23 |
-
|
24 |
-
return None
|
25 |
|
26 |
def save_to_hub(self, file_content, path_in_repo, commit_message):
|
27 |
-
"""
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
6 |
def __init__(self, repo_id):
|
7 |
self.repo_id = repo_id
|
8 |
self.api = HfApi()
|
9 |
+
|
10 |
+
# Validate repository access
|
|
|
|
|
|
|
11 |
try:
|
12 |
+
self.api.repo_info(repo_id=repo_id, repo_type="space")
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
except Exception as e:
|
14 |
+
raise RuntimeError(f"Failed to access repository {repo_id}: {e}")
|
|
|
15 |
|
16 |
def save_to_hub(self, file_content, path_in_repo, commit_message):
|
17 |
+
"""Save file to HuggingFace Hub with retries and better error handling"""
|
18 |
+
max_retries = 3
|
19 |
+
retry_delay = 1 # seconds
|
20 |
+
|
21 |
+
for attempt in range(max_retries):
|
22 |
+
try:
|
23 |
+
self.api.upload_file(
|
24 |
+
repo_id=self.repo_id,
|
25 |
+
repo_type="space",
|
26 |
+
path_or_fileobj=file_content,
|
27 |
+
path_in_repo=path_in_repo,
|
28 |
+
commit_message=commit_message
|
29 |
+
)
|
30 |
+
return True
|
31 |
+
except Exception as e:
|
32 |
+
if attempt == max_retries - 1:
|
33 |
+
raise RuntimeError(f"Failed to save file after {max_retries} attempts: {e}")
|
34 |
+
print(f"Attempt {attempt + 1} failed, retrying in {retry_delay} seconds...")
|
35 |
+
time.sleep(retry_delay)
|
36 |
+
retry_delay *= 2 # Exponential backoff
|