jbilcke-hf HF Staff commited on
Commit
41a8716
·
1 Parent(s): 2264c6e

Revert to commit a9df757

Browse files
vms/ui/app_ui.py CHANGED
@@ -403,7 +403,6 @@ class AppUI:
403
  ]
404
  )
405
 
406
-
407
  # Button update timer for button components (every 1 second)
408
  button_timer = gr.Timer(value=1)
409
  button_outputs = [
 
403
  ]
404
  )
405
 
 
406
  # Button update timer for button components (every 1 second)
407
  button_timer = gr.Timer(value=1)
408
  button_outputs = [
vms/ui/models/tabs/training_tab.py CHANGED
@@ -88,8 +88,9 @@ class TrainingTab(BaseTab):
88
  gr.Markdown(model.model_display_name or "Unknown")
89
 
90
  with gr.Column(scale=2, min_width=20):
91
- progress_text = f"Step {model.current_step}/{model.total_steps} ({model.training_progress:.1f}%)"
92
  gr.Markdown(progress_text)
 
93
 
94
  with gr.Column(scale=2, min_width=20):
95
  with gr.Row():
 
88
  gr.Markdown(model.model_display_name or "Unknown")
89
 
90
  with gr.Column(scale=2, min_width=20):
91
+ progress_text = f"Step {model.current_step}/{model.total_steps}"
92
  gr.Markdown(progress_text)
93
+ gr.Progress(value=model.training_progress/100)
94
 
95
  with gr.Column(scale=2, min_width=20):
96
  with gr.Row():
vms/ui/project/services/training.py CHANGED
@@ -1823,9 +1823,12 @@ class TrainingService:
1823
  try:
1824
  checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
1825
  if not checkpoints:
1826
- return "No checkpoints available"
1827
 
1828
- return f"💽 Download checkpoints"
 
 
 
1829
  except Exception as e:
1830
  logger.warning(f"Error getting checkpoint info for button text: {e}")
1831
- return "No checkpoints available"
 
1823
  try:
1824
  checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
1825
  if not checkpoints:
1826
+ return "📥 Download checkpoints (not available)"
1827
 
1828
+ # Get the latest checkpoint by step number
1829
+ latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1]))
1830
+ step_num = int(latest_checkpoint.name.split("_")[-1])
1831
+ return f"📥 Download checkpoints (step {step_num})"
1832
  except Exception as e:
1833
  logger.warning(f"Error getting checkpoint info for button text: {e}")
1834
+ return "📥 Download checkpoints (not available)"
vms/ui/project/tabs/manage_tab.py CHANGED
@@ -25,6 +25,50 @@ class ManageTab(BaseTab):
25
  self.id = "manage_tab"
26
  self.title = "5️⃣ Storage"
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  def create(self, parent=None) -> gr.TabItem:
30
  """Create the Manage tab UI components"""
@@ -46,19 +90,19 @@ class ManageTab(BaseTab):
46
  gr.Markdown("📦 Training dataset download disabled for large datasets")
47
 
48
  self.components["download_model_btn"] = gr.DownloadButton(
49
- "🧠 Download LoRA weights",
50
  variant="secondary",
51
  size="lg"
52
  )
53
 
54
  self.components["download_checkpoint_btn"] = gr.DownloadButton(
55
- "💽 Download Checkpoints",
56
  variant="secondary",
57
  size="lg"
58
  )
59
 
60
  self.components["download_output_btn"] = gr.DownloadButton(
61
- "📁 Download output/ (.zip)",
62
  variant="secondary",
63
  size="lg",
64
  visible=False
 
25
  self.id = "manage_tab"
26
  self.title = "5️⃣ Storage"
27
 
28
+ def get_download_button_text(self) -> str:
29
+ """Get the dynamic text for the download button based on current model state"""
30
+ try:
31
+ model_info = self.app.training.get_model_output_info()
32
+ if model_info["path"] and model_info["steps"]:
33
+ return f"🧠 Download weights ({model_info['steps']} steps)"
34
+ elif model_info["path"]:
35
+ return "🧠 Download weights (.safetensors)"
36
+ else:
37
+ return "🧠 Download weights (not available)"
38
+ except Exception as e:
39
+ logger.warning(f"Error getting model info for button text: {e}")
40
+ return "🧠 Download weights (.safetensors)"
41
+
42
+ def get_checkpoint_button_text(self) -> str:
43
+ """Get the dynamic text for the download checkpoint button"""
44
+ try:
45
+ return self.app.training.get_checkpoint_button_text()
46
+ except Exception as e:
47
+ logger.warning(f"Error getting checkpoint button text: {e}")
48
+ return "📥 Download checkpoints (not available)"
49
+
50
+ def update_download_button_text(self) -> gr.update:
51
+ """Update the download button text"""
52
+ return gr.update(value=self.get_download_button_text())
53
+
54
+ def update_checkpoint_button_text(self) -> gr.update:
55
+ """Update the checkpoint button text"""
56
+ return gr.update(value=self.get_checkpoint_button_text())
57
+
58
+ def update_both_download_buttons(self) -> Tuple[gr.update, gr.update]:
59
+ """Update both download button texts"""
60
+ return (
61
+ gr.update(value=self.get_download_button_text()),
62
+ gr.update(value=self.get_checkpoint_button_text())
63
+ )
64
+
65
+ def download_and_update_button(self):
66
+ """Handle download and return updated button with current text"""
67
+ # Get the safetensors path for download
68
+ path = self.app.training.get_model_output_safetensors()
69
+ # For DownloadButton, we need to return the file path directly for download
70
+ # The button text will be updated on next render
71
+ return path
72
 
73
  def create(self, parent=None) -> gr.TabItem:
74
  """Create the Manage tab UI components"""
 
90
  gr.Markdown("📦 Training dataset download disabled for large datasets")
91
 
92
  self.components["download_model_btn"] = gr.DownloadButton(
93
+ self.get_download_button_text(),
94
  variant="secondary",
95
  size="lg"
96
  )
97
 
98
  self.components["download_checkpoint_btn"] = gr.DownloadButton(
99
+ self.get_checkpoint_button_text(),
100
  variant="secondary",
101
  size="lg"
102
  )
103
 
104
  self.components["download_output_btn"] = gr.DownloadButton(
105
+ "📁 Download output directory (.zip)",
106
  variant="secondary",
107
  size="lg",
108
  visible=False
vms/ui/project/tabs/train_tab.py CHANGED
@@ -494,7 +494,12 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
494
  save_iterations, repo_id, progress
495
  )
496
 
497
- return status, logs
 
 
 
 
 
498
 
499
  def handle_resume_training(
500
  self, model_type, model_version, training_type,
@@ -506,7 +511,10 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
506
  checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
507
 
508
  if not checkpoints:
509
- return "No checkpoints found to resume from", "Please start a new training session instead"
 
 
 
510
 
511
  self.app.training.append_log(f"Resuming training from latest checkpoint")
512
 
@@ -518,7 +526,12 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
518
  resume_from_checkpoint="latest"
519
  )
520
 
521
- return status, logs
 
 
 
 
 
522
 
523
  def handle_start_from_lora_training(
524
  self, model_type, model_version, training_type,
@@ -529,22 +542,26 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
529
  # Find the latest LoRA weights
530
  lora_weights_path = self.app.output_path / "lora_weights"
531
 
 
 
 
 
532
  if not lora_weights_path.exists():
533
- return "No LoRA weights found", "Please train a model first or start a new training session"
534
 
535
  # Find the latest LoRA checkpoint directory
536
  lora_dirs = sorted([d for d in lora_weights_path.iterdir() if d.is_dir()],
537
  key=lambda x: int(x.name), reverse=True)
538
 
539
  if not lora_dirs:
540
- return "No LoRA weight directories found", "Please train a model first or start a new training session"
541
 
542
  latest_lora_dir = lora_dirs[0]
543
 
544
  # Verify the LoRA weights file exists
545
  lora_weights_file = latest_lora_dir / "pytorch_lora_weights.safetensors"
546
  if not lora_weights_file.exists():
547
- return f"LoRA weights file not found in {latest_lora_dir}", "Please check your LoRA weights directory"
548
 
549
  # Clear checkpoints to start fresh (but keep LoRA weights)
550
  for checkpoint in self.app.output_path.glob("finetrainers_step_*"):
@@ -565,7 +582,11 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
565
  save_iterations, repo_id, progress,
566
  )
567
 
568
- return status, logs
 
 
 
 
569
 
570
  def connect_events(self) -> None:
571
  """Connect event handlers to UI components"""
@@ -748,7 +769,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
748
  ],
749
  outputs=[
750
  self.components["status_box"],
751
- self.components["log_box"]
 
 
752
  ]
753
  )
754
 
@@ -768,7 +791,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
768
  ],
769
  outputs=[
770
  self.components["status_box"],
771
- self.components["log_box"]
 
 
772
  ]
773
  )
774
 
@@ -788,7 +813,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
788
  ],
789
  outputs=[
790
  self.components["status_box"],
791
- self.components["log_box"]
 
 
792
  ]
793
  )
794
 
@@ -804,7 +831,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
804
  self.components["current_task_box"],
805
  self.components["start_btn"],
806
  self.components["stop_btn"],
807
- third_btn
 
 
808
  ]
809
  )
810
 
@@ -816,7 +845,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
816
  self.components["current_task_box"],
817
  self.components["start_btn"],
818
  self.components["stop_btn"],
819
- third_btn
 
 
820
  ]
821
  )
822
 
@@ -1209,7 +1240,12 @@ Full finetune mode trains all parameters of the model, requiring more VRAM but p
1209
  variant="stop"
1210
  )
1211
 
1212
- return start_btn, resume_btn, stop_btn, delete_checkpoints_btn
 
 
 
 
 
1213
 
1214
  def update_training_ui(self, training_state: Dict[str, Any]):
1215
  """Update UI components based on training state"""
 
494
  save_iterations, repo_id, progress
495
  )
496
 
497
+ # Update download button texts
498
+ manage_tab = self.app.tabs["manage_tab"]
499
+ download_btn_text = gr.update(value=manage_tab.get_download_button_text())
500
+ checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
501
+
502
+ return status, logs, download_btn_text, checkpoint_btn_text
503
 
504
  def handle_resume_training(
505
  self, model_type, model_version, training_type,
 
511
  checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
512
 
513
  if not checkpoints:
514
+ manage_tab = self.app.tabs["manage_tab"]
515
+ download_btn_text = gr.update(value=manage_tab.get_download_button_text())
516
+ checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
517
+ return "No checkpoints found to resume from", "Please start a new training session instead", download_btn_text, checkpoint_btn_text
518
 
519
  self.app.training.append_log(f"Resuming training from latest checkpoint")
520
 
 
526
  resume_from_checkpoint="latest"
527
  )
528
 
529
+ # Update download button texts
530
+ manage_tab = self.app.tabs["manage_tab"]
531
+ download_btn_text = gr.update(value=manage_tab.get_download_button_text())
532
+ checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
533
+
534
+ return status, logs, download_btn_text, checkpoint_btn_text
535
 
536
  def handle_start_from_lora_training(
537
  self, model_type, model_version, training_type,
 
542
  # Find the latest LoRA weights
543
  lora_weights_path = self.app.output_path / "lora_weights"
544
 
545
+ manage_tab = self.app.tabs["manage_tab"]
546
+ download_btn_text = gr.update(value=manage_tab.get_download_button_text())
547
+ checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
548
+
549
  if not lora_weights_path.exists():
550
+ return "No LoRA weights found", "Please train a model first or start a new training session", download_btn_text, checkpoint_btn_text
551
 
552
  # Find the latest LoRA checkpoint directory
553
  lora_dirs = sorted([d for d in lora_weights_path.iterdir() if d.is_dir()],
554
  key=lambda x: int(x.name), reverse=True)
555
 
556
  if not lora_dirs:
557
+ return "No LoRA weight directories found", "Please train a model first or start a new training session", download_btn_text, checkpoint_btn_text
558
 
559
  latest_lora_dir = lora_dirs[0]
560
 
561
  # Verify the LoRA weights file exists
562
  lora_weights_file = latest_lora_dir / "pytorch_lora_weights.safetensors"
563
  if not lora_weights_file.exists():
564
+ return f"LoRA weights file not found in {latest_lora_dir}", "Please check your LoRA weights directory", download_btn_text, checkpoint_btn_text
565
 
566
  # Clear checkpoints to start fresh (but keep LoRA weights)
567
  for checkpoint in self.app.output_path.glob("finetrainers_step_*"):
 
582
  save_iterations, repo_id, progress,
583
  )
584
 
585
+ # Update download button texts
586
+ download_btn_text = gr.update(value=manage_tab.get_download_button_text())
587
+ checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
588
+
589
+ return status, logs, download_btn_text, checkpoint_btn_text
590
 
591
  def connect_events(self) -> None:
592
  """Connect event handlers to UI components"""
 
769
  ],
770
  outputs=[
771
  self.components["status_box"],
772
+ self.components["log_box"],
773
+ self.app.tabs["manage_tab"].components["download_model_btn"],
774
+ self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
775
  ]
776
  )
777
 
 
791
  ],
792
  outputs=[
793
  self.components["status_box"],
794
+ self.components["log_box"],
795
+ self.app.tabs["manage_tab"].components["download_model_btn"],
796
+ self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
797
  ]
798
  )
799
 
 
813
  ],
814
  outputs=[
815
  self.components["status_box"],
816
+ self.components["log_box"],
817
+ self.app.tabs["manage_tab"].components["download_model_btn"],
818
+ self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
819
  ]
820
  )
821
 
 
831
  self.components["current_task_box"],
832
  self.components["start_btn"],
833
  self.components["stop_btn"],
834
+ third_btn,
835
+ self.app.tabs["manage_tab"].components["download_model_btn"],
836
+ self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
837
  ]
838
  )
839
 
 
845
  self.components["current_task_box"],
846
  self.components["start_btn"],
847
  self.components["stop_btn"],
848
+ third_btn,
849
+ self.app.tabs["manage_tab"].components["download_model_btn"],
850
+ self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
851
  ]
852
  )
853
 
 
1240
  variant="stop"
1241
  )
1242
 
1243
+ # Update download button texts
1244
+ manage_tab = self.app.tabs["manage_tab"]
1245
+ download_btn_text = gr.update(value=manage_tab.get_download_button_text())
1246
+ checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
1247
+
1248
+ return start_btn, resume_btn, stop_btn, delete_checkpoints_btn, download_btn_text, checkpoint_btn_text
1249
 
1250
  def update_training_ui(self, training_state: Dict[str, Any]):
1251
  """Update UI components based on training state"""