Spaces:
Running
Running
Commit
·
41a8716
1
Parent(s):
2264c6e
Revert to commit a9df757
Browse files
vms/ui/app_ui.py
CHANGED
@@ -403,7 +403,6 @@ class AppUI:
|
|
403 |
]
|
404 |
)
|
405 |
|
406 |
-
|
407 |
# Button update timer for button components (every 1 second)
|
408 |
button_timer = gr.Timer(value=1)
|
409 |
button_outputs = [
|
|
|
403 |
]
|
404 |
)
|
405 |
|
|
|
406 |
# Button update timer for button components (every 1 second)
|
407 |
button_timer = gr.Timer(value=1)
|
408 |
button_outputs = [
|
vms/ui/models/tabs/training_tab.py
CHANGED
@@ -88,8 +88,9 @@ class TrainingTab(BaseTab):
|
|
88 |
gr.Markdown(model.model_display_name or "Unknown")
|
89 |
|
90 |
with gr.Column(scale=2, min_width=20):
|
91 |
-
progress_text = f"Step {model.current_step}/{model.total_steps}
|
92 |
gr.Markdown(progress_text)
|
|
|
93 |
|
94 |
with gr.Column(scale=2, min_width=20):
|
95 |
with gr.Row():
|
|
|
88 |
gr.Markdown(model.model_display_name or "Unknown")
|
89 |
|
90 |
with gr.Column(scale=2, min_width=20):
|
91 |
+
progress_text = f"Step {model.current_step}/{model.total_steps}"
|
92 |
gr.Markdown(progress_text)
|
93 |
+
gr.Progress(value=model.training_progress/100)
|
94 |
|
95 |
with gr.Column(scale=2, min_width=20):
|
96 |
with gr.Row():
|
vms/ui/project/services/training.py
CHANGED
@@ -1823,9 +1823,12 @@ class TrainingService:
|
|
1823 |
try:
|
1824 |
checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
|
1825 |
if not checkpoints:
|
1826 |
-
return "
|
1827 |
|
1828 |
-
|
|
|
|
|
|
|
1829 |
except Exception as e:
|
1830 |
logger.warning(f"Error getting checkpoint info for button text: {e}")
|
1831 |
-
return "
|
|
|
1823 |
try:
|
1824 |
checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
|
1825 |
if not checkpoints:
|
1826 |
+
return "📥 Download checkpoints (not available)"
|
1827 |
|
1828 |
+
# Get the latest checkpoint by step number
|
1829 |
+
latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1]))
|
1830 |
+
step_num = int(latest_checkpoint.name.split("_")[-1])
|
1831 |
+
return f"📥 Download checkpoints (step {step_num})"
|
1832 |
except Exception as e:
|
1833 |
logger.warning(f"Error getting checkpoint info for button text: {e}")
|
1834 |
+
return "📥 Download checkpoints (not available)"
|
vms/ui/project/tabs/manage_tab.py
CHANGED
@@ -25,6 +25,50 @@ class ManageTab(BaseTab):
|
|
25 |
self.id = "manage_tab"
|
26 |
self.title = "5️⃣ Storage"
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
def create(self, parent=None) -> gr.TabItem:
|
30 |
"""Create the Manage tab UI components"""
|
@@ -46,19 +90,19 @@ class ManageTab(BaseTab):
|
|
46 |
gr.Markdown("📦 Training dataset download disabled for large datasets")
|
47 |
|
48 |
self.components["download_model_btn"] = gr.DownloadButton(
|
49 |
-
|
50 |
variant="secondary",
|
51 |
size="lg"
|
52 |
)
|
53 |
|
54 |
self.components["download_checkpoint_btn"] = gr.DownloadButton(
|
55 |
-
|
56 |
variant="secondary",
|
57 |
size="lg"
|
58 |
)
|
59 |
|
60 |
self.components["download_output_btn"] = gr.DownloadButton(
|
61 |
-
"📁 Download output
|
62 |
variant="secondary",
|
63 |
size="lg",
|
64 |
visible=False
|
|
|
25 |
self.id = "manage_tab"
|
26 |
self.title = "5️⃣ Storage"
|
27 |
|
28 |
+
def get_download_button_text(self) -> str:
|
29 |
+
"""Get the dynamic text for the download button based on current model state"""
|
30 |
+
try:
|
31 |
+
model_info = self.app.training.get_model_output_info()
|
32 |
+
if model_info["path"] and model_info["steps"]:
|
33 |
+
return f"🧠 Download weights ({model_info['steps']} steps)"
|
34 |
+
elif model_info["path"]:
|
35 |
+
return "🧠 Download weights (.safetensors)"
|
36 |
+
else:
|
37 |
+
return "🧠 Download weights (not available)"
|
38 |
+
except Exception as e:
|
39 |
+
logger.warning(f"Error getting model info for button text: {e}")
|
40 |
+
return "🧠 Download weights (.safetensors)"
|
41 |
+
|
42 |
+
def get_checkpoint_button_text(self) -> str:
|
43 |
+
"""Get the dynamic text for the download checkpoint button"""
|
44 |
+
try:
|
45 |
+
return self.app.training.get_checkpoint_button_text()
|
46 |
+
except Exception as e:
|
47 |
+
logger.warning(f"Error getting checkpoint button text: {e}")
|
48 |
+
return "📥 Download checkpoints (not available)"
|
49 |
+
|
50 |
+
def update_download_button_text(self) -> gr.update:
|
51 |
+
"""Update the download button text"""
|
52 |
+
return gr.update(value=self.get_download_button_text())
|
53 |
+
|
54 |
+
def update_checkpoint_button_text(self) -> gr.update:
|
55 |
+
"""Update the checkpoint button text"""
|
56 |
+
return gr.update(value=self.get_checkpoint_button_text())
|
57 |
+
|
58 |
+
def update_both_download_buttons(self) -> Tuple[gr.update, gr.update]:
|
59 |
+
"""Update both download button texts"""
|
60 |
+
return (
|
61 |
+
gr.update(value=self.get_download_button_text()),
|
62 |
+
gr.update(value=self.get_checkpoint_button_text())
|
63 |
+
)
|
64 |
+
|
65 |
+
def download_and_update_button(self):
|
66 |
+
"""Handle download and return updated button with current text"""
|
67 |
+
# Get the safetensors path for download
|
68 |
+
path = self.app.training.get_model_output_safetensors()
|
69 |
+
# For DownloadButton, we need to return the file path directly for download
|
70 |
+
# The button text will be updated on next render
|
71 |
+
return path
|
72 |
|
73 |
def create(self, parent=None) -> gr.TabItem:
|
74 |
"""Create the Manage tab UI components"""
|
|
|
90 |
gr.Markdown("📦 Training dataset download disabled for large datasets")
|
91 |
|
92 |
self.components["download_model_btn"] = gr.DownloadButton(
|
93 |
+
self.get_download_button_text(),
|
94 |
variant="secondary",
|
95 |
size="lg"
|
96 |
)
|
97 |
|
98 |
self.components["download_checkpoint_btn"] = gr.DownloadButton(
|
99 |
+
self.get_checkpoint_button_text(),
|
100 |
variant="secondary",
|
101 |
size="lg"
|
102 |
)
|
103 |
|
104 |
self.components["download_output_btn"] = gr.DownloadButton(
|
105 |
+
"📁 Download output directory (.zip)",
|
106 |
variant="secondary",
|
107 |
size="lg",
|
108 |
visible=False
|
vms/ui/project/tabs/train_tab.py
CHANGED
@@ -494,7 +494,12 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
494 |
save_iterations, repo_id, progress
|
495 |
)
|
496 |
|
497 |
-
|
|
|
|
|
|
|
|
|
|
|
498 |
|
499 |
def handle_resume_training(
|
500 |
self, model_type, model_version, training_type,
|
@@ -506,7 +511,10 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
506 |
checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
|
507 |
|
508 |
if not checkpoints:
|
509 |
-
|
|
|
|
|
|
|
510 |
|
511 |
self.app.training.append_log(f"Resuming training from latest checkpoint")
|
512 |
|
@@ -518,7 +526,12 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
518 |
resume_from_checkpoint="latest"
|
519 |
)
|
520 |
|
521 |
-
|
|
|
|
|
|
|
|
|
|
|
522 |
|
523 |
def handle_start_from_lora_training(
|
524 |
self, model_type, model_version, training_type,
|
@@ -529,22 +542,26 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
529 |
# Find the latest LoRA weights
|
530 |
lora_weights_path = self.app.output_path / "lora_weights"
|
531 |
|
|
|
|
|
|
|
|
|
532 |
if not lora_weights_path.exists():
|
533 |
-
return "No LoRA weights found", "Please train a model first or start a new training session"
|
534 |
|
535 |
# Find the latest LoRA checkpoint directory
|
536 |
lora_dirs = sorted([d for d in lora_weights_path.iterdir() if d.is_dir()],
|
537 |
key=lambda x: int(x.name), reverse=True)
|
538 |
|
539 |
if not lora_dirs:
|
540 |
-
return "No LoRA weight directories found", "Please train a model first or start a new training session"
|
541 |
|
542 |
latest_lora_dir = lora_dirs[0]
|
543 |
|
544 |
# Verify the LoRA weights file exists
|
545 |
lora_weights_file = latest_lora_dir / "pytorch_lora_weights.safetensors"
|
546 |
if not lora_weights_file.exists():
|
547 |
-
return f"LoRA weights file not found in {latest_lora_dir}", "Please check your LoRA weights directory"
|
548 |
|
549 |
# Clear checkpoints to start fresh (but keep LoRA weights)
|
550 |
for checkpoint in self.app.output_path.glob("finetrainers_step_*"):
|
@@ -565,7 +582,11 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
565 |
save_iterations, repo_id, progress,
|
566 |
)
|
567 |
|
568 |
-
|
|
|
|
|
|
|
|
|
569 |
|
570 |
def connect_events(self) -> None:
|
571 |
"""Connect event handlers to UI components"""
|
@@ -748,7 +769,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
748 |
],
|
749 |
outputs=[
|
750 |
self.components["status_box"],
|
751 |
-
self.components["log_box"]
|
|
|
|
|
752 |
]
|
753 |
)
|
754 |
|
@@ -768,7 +791,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
768 |
],
|
769 |
outputs=[
|
770 |
self.components["status_box"],
|
771 |
-
self.components["log_box"]
|
|
|
|
|
772 |
]
|
773 |
)
|
774 |
|
@@ -788,7 +813,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
788 |
],
|
789 |
outputs=[
|
790 |
self.components["status_box"],
|
791 |
-
self.components["log_box"]
|
|
|
|
|
792 |
]
|
793 |
)
|
794 |
|
@@ -804,7 +831,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
804 |
self.components["current_task_box"],
|
805 |
self.components["start_btn"],
|
806 |
self.components["stop_btn"],
|
807 |
-
third_btn
|
|
|
|
|
808 |
]
|
809 |
)
|
810 |
|
@@ -816,7 +845,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
816 |
self.components["current_task_box"],
|
817 |
self.components["start_btn"],
|
818 |
self.components["stop_btn"],
|
819 |
-
third_btn
|
|
|
|
|
820 |
]
|
821 |
)
|
822 |
|
@@ -1209,7 +1240,12 @@ Full finetune mode trains all parameters of the model, requiring more VRAM but p
|
|
1209 |
variant="stop"
|
1210 |
)
|
1211 |
|
1212 |
-
|
|
|
|
|
|
|
|
|
|
|
1213 |
|
1214 |
def update_training_ui(self, training_state: Dict[str, Any]):
|
1215 |
"""Update UI components based on training state"""
|
|
|
494 |
save_iterations, repo_id, progress
|
495 |
)
|
496 |
|
497 |
+
# Update download button texts
|
498 |
+
manage_tab = self.app.tabs["manage_tab"]
|
499 |
+
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
500 |
+
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
501 |
+
|
502 |
+
return status, logs, download_btn_text, checkpoint_btn_text
|
503 |
|
504 |
def handle_resume_training(
|
505 |
self, model_type, model_version, training_type,
|
|
|
511 |
checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
|
512 |
|
513 |
if not checkpoints:
|
514 |
+
manage_tab = self.app.tabs["manage_tab"]
|
515 |
+
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
516 |
+
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
517 |
+
return "No checkpoints found to resume from", "Please start a new training session instead", download_btn_text, checkpoint_btn_text
|
518 |
|
519 |
self.app.training.append_log(f"Resuming training from latest checkpoint")
|
520 |
|
|
|
526 |
resume_from_checkpoint="latest"
|
527 |
)
|
528 |
|
529 |
+
# Update download button texts
|
530 |
+
manage_tab = self.app.tabs["manage_tab"]
|
531 |
+
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
532 |
+
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
533 |
+
|
534 |
+
return status, logs, download_btn_text, checkpoint_btn_text
|
535 |
|
536 |
def handle_start_from_lora_training(
|
537 |
self, model_type, model_version, training_type,
|
|
|
542 |
# Find the latest LoRA weights
|
543 |
lora_weights_path = self.app.output_path / "lora_weights"
|
544 |
|
545 |
+
manage_tab = self.app.tabs["manage_tab"]
|
546 |
+
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
547 |
+
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
548 |
+
|
549 |
if not lora_weights_path.exists():
|
550 |
+
return "No LoRA weights found", "Please train a model first or start a new training session", download_btn_text, checkpoint_btn_text
|
551 |
|
552 |
# Find the latest LoRA checkpoint directory
|
553 |
lora_dirs = sorted([d for d in lora_weights_path.iterdir() if d.is_dir()],
|
554 |
key=lambda x: int(x.name), reverse=True)
|
555 |
|
556 |
if not lora_dirs:
|
557 |
+
return "No LoRA weight directories found", "Please train a model first or start a new training session", download_btn_text, checkpoint_btn_text
|
558 |
|
559 |
latest_lora_dir = lora_dirs[0]
|
560 |
|
561 |
# Verify the LoRA weights file exists
|
562 |
lora_weights_file = latest_lora_dir / "pytorch_lora_weights.safetensors"
|
563 |
if not lora_weights_file.exists():
|
564 |
+
return f"LoRA weights file not found in {latest_lora_dir}", "Please check your LoRA weights directory", download_btn_text, checkpoint_btn_text
|
565 |
|
566 |
# Clear checkpoints to start fresh (but keep LoRA weights)
|
567 |
for checkpoint in self.app.output_path.glob("finetrainers_step_*"):
|
|
|
582 |
save_iterations, repo_id, progress,
|
583 |
)
|
584 |
|
585 |
+
# Update download button texts
|
586 |
+
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
587 |
+
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
588 |
+
|
589 |
+
return status, logs, download_btn_text, checkpoint_btn_text
|
590 |
|
591 |
def connect_events(self) -> None:
|
592 |
"""Connect event handlers to UI components"""
|
|
|
769 |
],
|
770 |
outputs=[
|
771 |
self.components["status_box"],
|
772 |
+
self.components["log_box"],
|
773 |
+
self.app.tabs["manage_tab"].components["download_model_btn"],
|
774 |
+
self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
|
775 |
]
|
776 |
)
|
777 |
|
|
|
791 |
],
|
792 |
outputs=[
|
793 |
self.components["status_box"],
|
794 |
+
self.components["log_box"],
|
795 |
+
self.app.tabs["manage_tab"].components["download_model_btn"],
|
796 |
+
self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
|
797 |
]
|
798 |
)
|
799 |
|
|
|
813 |
],
|
814 |
outputs=[
|
815 |
self.components["status_box"],
|
816 |
+
self.components["log_box"],
|
817 |
+
self.app.tabs["manage_tab"].components["download_model_btn"],
|
818 |
+
self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
|
819 |
]
|
820 |
)
|
821 |
|
|
|
831 |
self.components["current_task_box"],
|
832 |
self.components["start_btn"],
|
833 |
self.components["stop_btn"],
|
834 |
+
third_btn,
|
835 |
+
self.app.tabs["manage_tab"].components["download_model_btn"],
|
836 |
+
self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
|
837 |
]
|
838 |
)
|
839 |
|
|
|
845 |
self.components["current_task_box"],
|
846 |
self.components["start_btn"],
|
847 |
self.components["stop_btn"],
|
848 |
+
third_btn,
|
849 |
+
self.app.tabs["manage_tab"].components["download_model_btn"],
|
850 |
+
self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
|
851 |
]
|
852 |
)
|
853 |
|
|
|
1240 |
variant="stop"
|
1241 |
)
|
1242 |
|
1243 |
+
# Update download button texts
|
1244 |
+
manage_tab = self.app.tabs["manage_tab"]
|
1245 |
+
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
1246 |
+
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
1247 |
+
|
1248 |
+
return start_btn, resume_btn, stop_btn, delete_checkpoints_btn, download_btn_text, checkpoint_btn_text
|
1249 |
|
1250 |
def update_training_ui(self, training_state: Dict[str, Any]):
|
1251 |
"""Update UI components based on training state"""
|