jbilcke-hf HF Staff commited on
Commit
914fbd3
·
1 Parent(s): 5a11e7d

Fix model type display issue when internal model type name is used

Browse files
vms/ui/models/services/models_service.py CHANGED
@@ -55,14 +55,24 @@ class Model:
55
  status = ui_state.get('project_status', 'draft')
56
 
57
  # Get model type from UI state
58
- model_type_display = ui_state.get('model_type', '')
59
 
60
- # Map display name to internal name
 
61
  for display_name, internal_name in MODEL_TYPES.items():
62
- if display_name == model_type_display:
63
  model_type = internal_name
64
  model_display_name = display_name
 
65
  break
 
 
 
 
 
 
 
 
66
  except Exception as e:
67
  logger.error(f"Error loading UI state for model {model_id}: {str(e)}")
68
 
 
55
  status = ui_state.get('project_status', 'draft')
56
 
57
  # Get model type from UI state
58
+ model_type_value = ui_state.get('model_type', '')
59
 
60
+ # First check if model_type_value is a display name
61
+ display_name_found = False
62
  for display_name, internal_name in MODEL_TYPES.items():
63
+ if display_name == model_type_value:
64
  model_type = internal_name
65
  model_display_name = display_name
66
+ display_name_found = True
67
  break
68
+
69
+ # If not a display name, check if it's an internal name
70
+ if not display_name_found:
71
+ for display_name, internal_name in MODEL_TYPES.items():
72
+ if internal_name == model_type_value:
73
+ model_type = internal_name
74
+ model_display_name = display_name
75
+ break
76
  except Exception as e:
77
  logger.error(f"Error loading UI state for model {model_id}: {str(e)}")
78
 
vms/utils/finetrainers_utils.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  from pathlib import Path
3
  import logging
4
  import shutil
 
5
  from typing import Any, Optional, Dict, List, Union, Tuple
6
 
7
  from ..config import (
@@ -11,7 +12,8 @@ from ..config import (
11
  DEFAULT_VALIDATION_HEIGHT,
12
  DEFAULT_VALIDATION_WIDTH,
13
  DEFAULT_VALIDATION_NB_FRAMES,
14
- DEFAULT_VALIDATION_FRAMERATE
 
15
  )
16
  from .utils import get_video_fps, extract_scene_info, make_archive, is_image_file, is_video_file
17
 
@@ -39,6 +41,26 @@ def prepare_finetrainers_dataset(training_path=None, training_videos_path=None)
39
  Returns:
40
  Tuple of (videos_file_path, prompts_file_path)
41
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  # Verifies the videos subdirectory
44
  training_videos_path.mkdir(exist_ok=True)
@@ -67,6 +89,20 @@ def prepare_finetrainers_dataset(training_path=None, training_videos_path=None)
67
  relative_path = f"videos/{file.name}"
68
  media_files.append(relative_path)
69
  captions.append(caption)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  # Write files if we have content
72
  if media_files and captions:
@@ -102,6 +138,24 @@ def copy_files_to_training_dir(prompt_prefix: str, training_videos_path=None) ->
102
 
103
  gr.Info("Copying assets to the training dataset..")
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  # Find files needing captions
106
  video_files = list(STAGING_PATH.glob("*.mp4"))
107
  image_files = [f for f in STAGING_PATH.glob("*") if is_image_file(f)]
@@ -155,7 +209,9 @@ def copy_files_to_training_dir(prompt_prefix: str, training_videos_path=None) ->
155
  print(f"failed to copy one of the pairs: {e}")
156
  pass
157
 
158
- prepare_finetrainers_dataset()
 
 
159
 
160
  gr.Info(f"Successfully generated the training dataset ({nb_copied_pairs} pairs)")
161
 
 
2
  from pathlib import Path
3
  import logging
4
  import shutil
5
+ import json
6
  from typing import Any, Optional, Dict, List, Union, Tuple
7
 
8
  from ..config import (
 
12
  DEFAULT_VALIDATION_HEIGHT,
13
  DEFAULT_VALIDATION_WIDTH,
14
  DEFAULT_VALIDATION_NB_FRAMES,
15
+ DEFAULT_VALIDATION_FRAMERATE,
16
+ load_global_config, get_project_paths
17
  )
18
  from .utils import get_video_fps, extract_scene_info, make_archive, is_image_file, is_video_file
19
 
 
41
  Returns:
42
  Tuple of (videos_file_path, prompts_file_path)
43
  """
44
+ # Get project ID from global config if paths not provided
45
+ if training_path is None or training_videos_path is None:
46
+ config = load_global_config()
47
+ project_id = config.get("latest_model_project_id")
48
+
49
+ if not project_id:
50
+ logger.error("No active project found in global config")
51
+ return None, None
52
+
53
+ # Get paths for this project
54
+ project_training_path, project_videos_path, _, _ = get_project_paths(project_id)
55
+
56
+ # Use provided paths or defaults
57
+ training_path = training_path or project_training_path
58
+ training_videos_path = training_videos_path or project_videos_path
59
+
60
+ # Validate paths
61
+ if training_path is None or training_videos_path is None:
62
+ logger.error("Could not determine training paths")
63
+ return None, None
64
 
65
  # Verifies the videos subdirectory
66
  training_videos_path.mkdir(exist_ok=True)
 
89
  relative_path = f"videos/{file.name}"
90
  media_files.append(relative_path)
91
  captions.append(caption)
92
+
93
+ # Also include image files if present (for image conditioning)
94
+ for idx, file in enumerate(sorted(training_videos_path.glob("*"))):
95
+ if is_image_file(file):
96
+ caption_file = file.with_suffix('.txt')
97
+ if caption_file.exists():
98
+ # Normalize caption to single line
99
+ caption = caption_file.read_text().strip()
100
+ caption = ' '.join(caption.split())
101
+
102
+ # Use relative path from training root
103
+ relative_path = f"videos/{file.name}"
104
+ media_files.append(relative_path)
105
+ captions.append(caption)
106
 
107
  # Write files if we have content
108
  if media_files and captions:
 
138
 
139
  gr.Info("Copying assets to the training dataset..")
140
 
141
+ # Get project ID from global config
142
+ config = load_global_config()
143
+ project_id = config.get("latest_model_project_id")
144
+
145
+ if not project_id:
146
+ logger.error("No active project found in global config")
147
+ raise ValueError("No active project found. Please create or select a project first.")
148
+
149
+ # Get paths for this project if not provided
150
+ if training_videos_path is None:
151
+ _, training_videos_path, _, _ = get_project_paths(project_id)
152
+
153
+ if training_videos_path is None:
154
+ logger.error("Could not determine training videos path")
155
+ raise ValueError("Training videos path is not set or could not be determined")
156
+
157
+ logger.info(f"Using training videos path: {training_videos_path}")
158
+
159
  # Find files needing captions
160
  video_files = list(STAGING_PATH.glob("*.mp4"))
161
  image_files = [f for f in STAGING_PATH.glob("*") if is_image_file(f)]
 
209
  print(f"failed to copy one of the pairs: {e}")
210
  pass
211
 
212
+ # Get training_path for prepare_finetrainers_dataset
213
+ training_path, _, _, _ = get_project_paths(project_id)
214
+ prepare_finetrainers_dataset(training_path, training_videos_path)
215
 
216
  gr.Info(f"Successfully generated the training dataset ({nb_copied_pairs} pairs)")
217