"""
Model loading utilities for Image Tagger application.
"""

import os
import json
import torch
import platform
import traceback
import importlib.util


def is_windows():
    """Check if the system is Windows."""
    return platform.system() == "Windows"


class DummyDataset:
    """Minimal dataset class for inference.

    Holds only the tag lookup tables read from ``metadata.json``; it is passed
    to the model constructors as ``dataset``. No image data is loaded.
    """

    def __init__(self, metadata):
        self.total_tags = metadata['total_tags']
        # JSON object keys are always strings; convert them back to int indices.
        self.idx_to_tag = {int(k): v for k, v in metadata['idx_to_tag'].items()}
        self.tag_to_category = metadata['tag_to_category']

    def get_tag_info(self, idx):
        """Return ``(tag, category)`` for the tag index *idx*.

        Unknown indices map to ``"unknown_<idx>"``; tags without a category
        entry map to ``"general"``.
        """
        tag = self.idx_to_tag.get(idx, f"unknown_{idx}")
        category = self.tag_to_category.get(tag, "general")
        return tag, category


def _read_json(path):
    """Read and parse the JSON file at *path* as UTF-8.

    JSON text is UTF-8 by specification. Passing the encoding explicitly
    avoids mis-decoding non-ASCII tag names on platforms whose default locale
    encoding differs (e.g. cp1252 on Windows).
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_model_code(model_dir):
    """
    Load the model code module from the model directory.

    Note: this executes ``model_code.py`` from *model_dir*; only use trusted
    model directories.

    Args:
        model_dir: Path to the model directory

    Returns:
        Imported model code module

    Raises:
        FileNotFoundError: if ``model_code.py`` does not exist in *model_dir*.
        ImportError: if the module cannot be imported or lacks ``ImageTagger``
            or ``FlashAttention``.
    """
    model_code_path = os.path.join(model_dir, "model_code.py")
    if not os.path.exists(model_code_path):
        raise FileNotFoundError(f"model_code.py not found at {model_code_path}")

    # Import the model code dynamically
    spec = importlib.util.spec_from_file_location("model_code", model_code_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not create an import spec for {model_code_path}")
    model_code = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(model_code)

    # Check if required classes exist
    if not hasattr(model_code, 'ImageTagger') or not hasattr(model_code, 'FlashAttention'):
        raise ImportError("Required classes not found in model_code.py")

    return model_code


def check_flash_attention():
    """
    Check if Flash Attention is properly installed.

    Returns:
        bool: True if ``flash_attn`` imports, exposes ``flash_attn_func``, and
        that function does not come from a ``flash_attn_fallback`` module.
    """
    try:
        import flash_attn
        if hasattr(flash_attn, 'flash_attn_func'):
            module_path = flash_attn.flash_attn_func.__module__
            return 'flash_attn_fallback' not in module_path
    except Exception:
        # Best-effort probe: any failure (not installed, broken install,
        # unexpected attributes) means Flash Attention is treated as unavailable.
        pass
    return False


def estimate_model_memory_usage(model, device):
    """
    Estimate the memory usage of a model in MB.

    Sums the storage of all parameters and buffers (element count times
    element size). *device* is not used; it is kept for interface
    compatibility.
    """
    param_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
    buffer_bytes = sum(b.nelement() * b.element_size() for b in model.buffers())
    return (param_bytes + buffer_bytes) / (1024 * 1024)  # bytes -> MB


def _first_existing(model_dir, filenames):
    """Return the path of the first of *filenames* that exists in *model_dir*, else None."""
    for filename in filenames:
        path = os.path.join(model_dir, filename)
        if os.path.exists(path):
            return path
    return None


def _resolve_model_files(model_dir, model_type):
    """Return ``(model_path, model_info_path)`` for *model_type* in *model_dir*.

    Raises:
        FileNotFoundError: if no full-model weights file exists. For
            ``"initial_only"``, the existence of the fallback weights file is
            checked later by the caller.
    """
    if model_type == "initial_only":
        # Try both naming conventions ("_initial_only" preferred, else "_initial").
        model_path = (_first_existing(model_dir, ["model_initial_only.pt"])
                      or os.path.join(model_dir, "model_initial.pt"))
        model_info_path = (_first_existing(model_dir, ["model_info_initial_only.json"])
                           or os.path.join(model_dir, "model_info_initial.json"))
        return model_path, model_info_path

    # Any other model_type is treated as the full model.
    model_filenames = ["model_refined.pt", "model.pt", "model_full.pt"]
    model_path = _first_existing(model_dir, model_filenames)
    if model_path is None:
        raise FileNotFoundError(
            f"No model file found in {model_dir}. Looked for: {', '.join(model_filenames)}"
        )
    return model_path, os.path.join(model_dir, "model_info.json")


def _build_model(model_code, model_type, metadata, dataset, tag_context_size, num_heads):
    """Instantiate the untrained (``pretrained=False``) model class for *model_type*."""
    if model_type == "initial_only":
        if hasattr(model_code, 'InitialOnlyImageTagger'):
            # Create the lightweight model
            return model_code.InitialOnlyImageTagger(
                total_tags=metadata['total_tags'],
                dataset=dataset,
                pretrained=False,
            )
        # Fall back to ImageTagger if the specific class isn't available
        print("InitialOnlyImageTagger class not found. Using ImageTagger as fallback.")
    return model_code.ImageTagger(
        total_tags=metadata['total_tags'],
        dataset=dataset,
        pretrained=False,
        tag_context_size=tag_context_size,
        num_heads=num_heads,
    )


def _load_weights(model, model_path, device):
    """Load the checkpoint at *model_path* into *model*: strict first, then lenient."""
    # SECURITY NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)
    try:
        model.load_state_dict(state_dict, strict=True)
        print("✓ Model loaded with strict=True")
    except RuntimeError as e:
        # load_state_dict reports missing/unexpected keys as RuntimeError.
        print(f"Warning: Strict loading failed: {str(e)}")
        print("Attempting to load with strict=False...")
        result = model.load_state_dict(state_dict, strict=False)
        # strict=False silently skips mismatched keys; surface them so a
        # partially-loaded model does not go unnoticed.
        if result.missing_keys:
            print(f"Warning: {len(result.missing_keys)} model keys were not found in the checkpoint")
        if result.unexpected_keys:
            print(f"Warning: {len(result.unexpected_keys)} checkpoint keys were ignored")
        print("✓ Model loaded with strict=False")


def load_exported_model(model_dir, model_type="full"):
    """
    Load the exported model and metadata with correct precision.

    Args:
        model_dir: Directory containing the model files
        model_type: "full" or "initial_only" (any other value is treated as "full")

    Returns:
        model, thresholds, metadata. The model is in eval mode on CUDA if
        available, otherwise on CPU, with float16 parameters.

    Raises:
        FileNotFoundError: if the weights, ``metadata.json``,
            ``thresholds.json`` or ``model_code.py`` are missing.
        ImportError: if ``model_code.py`` lacks the required classes.
        Any error raised while building or loading the model is printed with
        a traceback and re-raised.
    """
    print(f"Loading {model_type} model from {model_dir}")

    # Make sure we have the absolute path to the model directory
    model_dir = os.path.abspath(model_dir)
    print(f"Absolute model path: {model_dir}")

    metadata_path = os.path.join(model_dir, "metadata.json")
    thresholds_path = os.path.join(model_dir, "thresholds.json")
    print(f"Looking for thresholds at: {thresholds_path}")

    # Check platform and Flash Attention status
    windows_system = is_windows()
    flash_attn_installed = check_flash_attention()

    # Add a specific warning for Windows users trying to use the full model without Flash Attention
    # NOTE(review): the first two lines of this message read inconsistently
    # ("will not work" vs "may produce less accurate results"); confirm wording.
    if windows_system and model_type == "full" and not flash_attn_installed:
        print("Note: On Windows without Flash Attention, the full model will not work")
        print(" which may produce less accurate results.")
        print(" Consider using the 'initial_only' model for better performance on Windows.")

    # Determine file paths based on model type
    model_path, model_info_path = _resolve_model_files(model_dir, model_type)

    # Check for required files
    for file_path in (metadata_path, thresholds_path, model_path):
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Required file {file_path} not found")

    metadata = _read_json(metadata_path)
    model_code = load_model_code(model_dir)
    dummy_dataset = DummyDataset(metadata)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load model info (architecture hyperparameters), falling back to defaults
    if os.path.exists(model_info_path):
        model_info = _read_json(model_info_path)
        print("Loaded model info:", model_info)
        tag_context_size = model_info.get('tag_context_size', 256)
        num_heads = model_info.get('num_heads', 16)
    else:
        print("Model info not found, using defaults")
        tag_context_size = 256
        num_heads = 16

    try:
        model = _build_model(
            model_code, model_type, metadata, dummy_dataset, tag_context_size, num_heads
        )
        _load_weights(model, model_path, device)

        # Ensure model is in half precision to match training conditions.
        # NOTE(review): when device is CPU, float16 inference is slow or
        # unsupported for some ops in some PyTorch versions; confirm intended.
        model = model.to(device=device, dtype=torch.float16)
        model.eval()

        param_dtype = next(model.parameters()).dtype
        print(f"Model loaded successfully on {device} with precision {param_dtype}")
        print(f"Model memory usage: {estimate_model_memory_usage(model, device):.2f} MB")
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        traceback.print_exc()
        raise

    thresholds = _read_json(thresholds_path)

    return model, thresholds, metadata