Spaces:
Running
Running
Commit
·
5a793ee
1
Parent(s):
aeb51a1
small fix (not tested yet)
Browse files
vms/ui/project/tabs/manage_tab.py
CHANGED
@@ -6,7 +6,7 @@ import gradio as gr
|
|
6 |
import logging
|
7 |
import shutil
|
8 |
from pathlib import Path
|
9 |
-
from typing import Dict, Any, List, Optional
|
10 |
from gradio_modal import Modal
|
11 |
|
12 |
from vms.utils import BaseTab, validate_model_repo
|
@@ -51,6 +51,17 @@ class ManageTab(BaseTab):
|
|
51 |
"""Update the download button text"""
|
52 |
return gr.update(value=self.get_download_button_text())
|
53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
def download_and_update_button(self):
|
55 |
"""Handle download and return updated button with current text"""
|
56 |
# Get the safetensors path for download
|
|
|
6 |
import logging
|
7 |
import shutil
|
8 |
from pathlib import Path
|
9 |
+
from typing import Dict, Any, List, Optional, Tuple
|
10 |
from gradio_modal import Modal
|
11 |
|
12 |
from vms.utils import BaseTab, validate_model_repo
|
|
|
51 |
"""Update the download button text"""
|
52 |
return gr.update(value=self.get_download_button_text())
|
53 |
|
54 |
+
def update_checkpoint_button_text(self) -> gr.update:
|
55 |
+
"""Update the checkpoint button text"""
|
56 |
+
return gr.update(value=self.get_checkpoint_button_text())
|
57 |
+
|
58 |
+
def update_both_download_buttons(self) -> Tuple[gr.update, gr.update]:
|
59 |
+
"""Update both download button texts"""
|
60 |
+
return (
|
61 |
+
gr.update(value=self.get_download_button_text()),
|
62 |
+
gr.update(value=self.get_checkpoint_button_text())
|
63 |
+
)
|
64 |
+
|
65 |
def download_and_update_button(self):
|
66 |
"""Handle download and return updated button with current text"""
|
67 |
# Get the safetensors path for download
|
vms/ui/project/tabs/train_tab.py
CHANGED
@@ -341,9 +341,12 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
341 |
## ⚗️ Train your model on your dataset
|
342 |
- **🚀 Start new training**: Begins training from scratch (clears previous checkpoints)
|
343 |
- **🛸 Start from latest checkpoint**: Continues training from the most recent checkpoint
|
344 |
-
- **🔄 Start over using latest LoRA weights**: Start fresh training but use existing LoRA weights as initialization
|
345 |
""")
|
346 |
-
|
|
|
|
|
|
|
|
|
347 |
with gr.Row():
|
348 |
# Check for existing checkpoints to determine button text
|
349 |
checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
|
@@ -485,11 +488,18 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
485 |
self.app.training.append_log("Cleared previous checkpoints for new training session")
|
486 |
|
487 |
# Start training normally
|
488 |
-
|
489 |
model_type, model_version, training_type,
|
490 |
lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
|
491 |
save_iterations, repo_id, progress
|
492 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
493 |
|
494 |
def handle_resume_training(
|
495 |
self, model_type, model_version, training_type,
|
@@ -501,17 +511,27 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
501 |
checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
|
502 |
|
503 |
if not checkpoints:
|
504 |
-
|
|
|
|
|
|
|
505 |
|
506 |
self.app.training.append_log(f"Resuming training from latest checkpoint")
|
507 |
|
508 |
# Start training with the checkpoint
|
509 |
-
|
510 |
model_type, model_version, training_type,
|
511 |
lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
|
512 |
save_iterations, repo_id, progress,
|
513 |
resume_from_checkpoint="latest"
|
514 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
515 |
|
516 |
def handle_start_from_lora_training(
|
517 |
self, model_type, model_version, training_type,
|
@@ -522,22 +542,26 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
522 |
# Find the latest LoRA weights
|
523 |
lora_weights_path = self.app.output_path / "lora_weights"
|
524 |
|
|
|
|
|
|
|
|
|
525 |
if not lora_weights_path.exists():
|
526 |
-
return "No LoRA weights found", "Please train a model first or start a new training session"
|
527 |
|
528 |
# Find the latest LoRA checkpoint directory
|
529 |
lora_dirs = sorted([d for d in lora_weights_path.iterdir() if d.is_dir()],
|
530 |
key=lambda x: int(x.name), reverse=True)
|
531 |
|
532 |
if not lora_dirs:
|
533 |
-
return "No LoRA weight directories found", "Please train a model first or start a new training session"
|
534 |
|
535 |
latest_lora_dir = lora_dirs[0]
|
536 |
|
537 |
# Verify the LoRA weights file exists
|
538 |
lora_weights_file = latest_lora_dir / "pytorch_lora_weights.safetensors"
|
539 |
if not lora_weights_file.exists():
|
540 |
-
return f"LoRA weights file not found in {latest_lora_dir}", "Please check your LoRA weights directory"
|
541 |
|
542 |
# Clear checkpoints to start fresh (but keep LoRA weights)
|
543 |
for checkpoint in self.app.output_path.glob("finetrainers_step_*"):
|
@@ -552,11 +576,17 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
552 |
self.app.training.append_log(f"Starting training from LoRA weights: {latest_lora_dir}")
|
553 |
|
554 |
# Start training with the LoRA weights
|
555 |
-
|
556 |
model_type, model_version, training_type,
|
557 |
lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
|
558 |
save_iterations, repo_id, progress,
|
559 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
560 |
|
561 |
def connect_events(self) -> None:
|
562 |
"""Connect event handlers to UI components"""
|
@@ -739,7 +769,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
739 |
],
|
740 |
outputs=[
|
741 |
self.components["status_box"],
|
742 |
-
self.components["log_box"]
|
|
|
|
|
743 |
]
|
744 |
)
|
745 |
|
@@ -759,7 +791,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
759 |
],
|
760 |
outputs=[
|
761 |
self.components["status_box"],
|
762 |
-
self.components["log_box"]
|
|
|
|
|
763 |
]
|
764 |
)
|
765 |
|
@@ -779,7 +813,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
779 |
],
|
780 |
outputs=[
|
781 |
self.components["status_box"],
|
782 |
-
self.components["log_box"]
|
|
|
|
|
783 |
]
|
784 |
)
|
785 |
|
@@ -795,7 +831,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
795 |
self.components["current_task_box"],
|
796 |
self.components["start_btn"],
|
797 |
self.components["stop_btn"],
|
798 |
-
third_btn
|
|
|
|
|
799 |
]
|
800 |
)
|
801 |
|
@@ -807,7 +845,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
807 |
self.components["current_task_box"],
|
808 |
self.components["start_btn"],
|
809 |
self.components["stop_btn"],
|
810 |
-
third_btn
|
|
|
|
|
811 |
]
|
812 |
)
|
813 |
|
@@ -1200,7 +1240,12 @@ Full finetune mode trains all parameters of the model, requiring more VRAM but p
|
|
1200 |
variant="stop"
|
1201 |
)
|
1202 |
|
1203 |
-
|
|
|
|
|
|
|
|
|
|
|
1204 |
|
1205 |
def update_training_ui(self, training_state: Dict[str, Any]):
|
1206 |
"""Update UI components based on training state"""
|
|
|
341 |
## ⚗️ Train your model on your dataset
|
342 |
- **🚀 Start new training**: Begins training from scratch (clears previous checkpoints)
|
343 |
- **🛸 Start from latest checkpoint**: Continues training from the most recent checkpoint
|
|
|
344 |
""")
|
345 |
+
|
346 |
+
#Finetrainers doesn't support recovery of a training session using a LoRA,
|
347 |
+
#so this feature doesn't work, I've disabled the line/documentation:
|
348 |
+
#- **🔄 Start over using latest LoRA weights**: Start fresh training but use existing LoRA weights as initialization
|
349 |
+
|
350 |
with gr.Row():
|
351 |
# Check for existing checkpoints to determine button text
|
352 |
checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
|
|
|
488 |
self.app.training.append_log("Cleared previous checkpoints for new training session")
|
489 |
|
490 |
# Start training normally
|
491 |
+
status, logs = self.handle_training_start(
|
492 |
model_type, model_version, training_type,
|
493 |
lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
|
494 |
save_iterations, repo_id, progress
|
495 |
)
|
496 |
+
|
497 |
+
# Update download button texts
|
498 |
+
manage_tab = self.app.tabs["manage_tab"]
|
499 |
+
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
500 |
+
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
501 |
+
|
502 |
+
return status, logs, download_btn_text, checkpoint_btn_text
|
503 |
|
504 |
def handle_resume_training(
|
505 |
self, model_type, model_version, training_type,
|
|
|
511 |
checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
|
512 |
|
513 |
if not checkpoints:
|
514 |
+
manage_tab = self.app.tabs["manage_tab"]
|
515 |
+
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
516 |
+
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
517 |
+
return "No checkpoints found to resume from", "Please start a new training session instead", download_btn_text, checkpoint_btn_text
|
518 |
|
519 |
self.app.training.append_log(f"Resuming training from latest checkpoint")
|
520 |
|
521 |
# Start training with the checkpoint
|
522 |
+
status, logs = self.handle_training_start(
|
523 |
model_type, model_version, training_type,
|
524 |
lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
|
525 |
save_iterations, repo_id, progress,
|
526 |
resume_from_checkpoint="latest"
|
527 |
)
|
528 |
+
|
529 |
+
# Update download button texts
|
530 |
+
manage_tab = self.app.tabs["manage_tab"]
|
531 |
+
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
532 |
+
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
533 |
+
|
534 |
+
return status, logs, download_btn_text, checkpoint_btn_text
|
535 |
|
536 |
def handle_start_from_lora_training(
|
537 |
self, model_type, model_version, training_type,
|
|
|
542 |
# Find the latest LoRA weights
|
543 |
lora_weights_path = self.app.output_path / "lora_weights"
|
544 |
|
545 |
+
manage_tab = self.app.tabs["manage_tab"]
|
546 |
+
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
547 |
+
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
548 |
+
|
549 |
if not lora_weights_path.exists():
|
550 |
+
return "No LoRA weights found", "Please train a model first or start a new training session", download_btn_text, checkpoint_btn_text
|
551 |
|
552 |
# Find the latest LoRA checkpoint directory
|
553 |
lora_dirs = sorted([d for d in lora_weights_path.iterdir() if d.is_dir()],
|
554 |
key=lambda x: int(x.name), reverse=True)
|
555 |
|
556 |
if not lora_dirs:
|
557 |
+
return "No LoRA weight directories found", "Please train a model first or start a new training session", download_btn_text, checkpoint_btn_text
|
558 |
|
559 |
latest_lora_dir = lora_dirs[0]
|
560 |
|
561 |
# Verify the LoRA weights file exists
|
562 |
lora_weights_file = latest_lora_dir / "pytorch_lora_weights.safetensors"
|
563 |
if not lora_weights_file.exists():
|
564 |
+
return f"LoRA weights file not found in {latest_lora_dir}", "Please check your LoRA weights directory", download_btn_text, checkpoint_btn_text
|
565 |
|
566 |
# Clear checkpoints to start fresh (but keep LoRA weights)
|
567 |
for checkpoint in self.app.output_path.glob("finetrainers_step_*"):
|
|
|
576 |
self.app.training.append_log(f"Starting training from LoRA weights: {latest_lora_dir}")
|
577 |
|
578 |
# Start training with the LoRA weights
|
579 |
+
status, logs = self.handle_training_start(
|
580 |
model_type, model_version, training_type,
|
581 |
lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
|
582 |
save_iterations, repo_id, progress,
|
583 |
)
|
584 |
+
|
585 |
+
# Update download button texts
|
586 |
+
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
587 |
+
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
588 |
+
|
589 |
+
return status, logs, download_btn_text, checkpoint_btn_text
|
590 |
|
591 |
def connect_events(self) -> None:
|
592 |
"""Connect event handlers to UI components"""
|
|
|
769 |
],
|
770 |
outputs=[
|
771 |
self.components["status_box"],
|
772 |
+
self.components["log_box"],
|
773 |
+
self.app.tabs["manage_tab"].components["download_model_btn"],
|
774 |
+
self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
|
775 |
]
|
776 |
)
|
777 |
|
|
|
791 |
],
|
792 |
outputs=[
|
793 |
self.components["status_box"],
|
794 |
+
self.components["log_box"],
|
795 |
+
self.app.tabs["manage_tab"].components["download_model_btn"],
|
796 |
+
self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
|
797 |
]
|
798 |
)
|
799 |
|
|
|
813 |
],
|
814 |
outputs=[
|
815 |
self.components["status_box"],
|
816 |
+
self.components["log_box"],
|
817 |
+
self.app.tabs["manage_tab"].components["download_model_btn"],
|
818 |
+
self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
|
819 |
]
|
820 |
)
|
821 |
|
|
|
831 |
self.components["current_task_box"],
|
832 |
self.components["start_btn"],
|
833 |
self.components["stop_btn"],
|
834 |
+
third_btn,
|
835 |
+
self.app.tabs["manage_tab"].components["download_model_btn"],
|
836 |
+
self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
|
837 |
]
|
838 |
)
|
839 |
|
|
|
845 |
self.components["current_task_box"],
|
846 |
self.components["start_btn"],
|
847 |
self.components["stop_btn"],
|
848 |
+
third_btn,
|
849 |
+
self.app.tabs["manage_tab"].components["download_model_btn"],
|
850 |
+
self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
|
851 |
]
|
852 |
)
|
853 |
|
|
|
1240 |
variant="stop"
|
1241 |
)
|
1242 |
|
1243 |
+
# Update download button texts
|
1244 |
+
manage_tab = self.app.tabs["manage_tab"]
|
1245 |
+
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
1246 |
+
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
1247 |
+
|
1248 |
+
return start_btn, resume_btn, stop_btn, delete_checkpoints_btn, download_btn_text, checkpoint_btn_text
|
1249 |
|
1250 |
def update_training_ui(self, training_state: Dict[str, Any]):
|
1251 |
"""Update UI components based on training state"""
|