import functools
import gc
import os
import time
from dataclasses import dataclass

import torch
from diffusers.pipelines import DiffusionPipeline
from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor


@dataclass
class OffloadConfig:
    # high_cpu_memory: whether to use pinned host memory for offloading. This effectively prevents the increased offload latency caused by memory swapping.
    high_cpu_memory: bool = True
    # parameters_level: whether to enable parameter-level offload. This further reduces VRAM requirements but may increase latency.
    parameters_level: bool = False
    # compiler_transformer: whether to enable compilation optimization for the transformer.
    compiler_transformer: bool = False
    compiler_cache: str = "/tmp/compile_cache"


class HfHook:
    def __init__(self):
        device_id = os.environ.get("LOCAL_RANK", 0)
        self.execution_device = f"cuda:{device_id}"

    def detach_hook(self, module):
        pass


class Offload:
    def __init__(self) -> None:
        self.active_models = []
        self.active_models_ids = []
        self.active_subcaches = {}
        self.models = {}
        self.verboseLevel = 0
        self.models_to_quantize = []
        self.pinned_modules_data = {}
        self.blocks_of_modules = {}
        self.blocks_of_modules_sizes = {}
        self.compile = False
        self.device_mem_capacity = torch.cuda.get_device_properties(0).total_memory
        self.last_reserved_mem_check = 0
        self.loaded_blocks = {}
        self.prev_blocks_names = {}
        self.next_blocks_names = {}
        device_id = os.environ.get("LOCAL_RANK", 0)
        self.device_id = f"cuda:{device_id}"
        self.default_stream = torch.cuda.default_stream(self.device_id)  # torch.cuda.current_stream()
        self.transfer_stream = torch.cuda.Stream()
        self.async_transfers = False
        self.last_run_model = None

    def check_empty_cuda_cache(self):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @classmethod
    def offload(cls, pipeline: DiffusionPipeline, config: OffloadConfig = OffloadConfig()):
        """
        Enable offloading for multiple models in the pipeline, supporting video generation inference on consumer-grade GPUs.
        pipeline: the pipeline object
        config: offload strategy configuration
        """
        self = cls()
        self.pinned_modules_data = {}
        if config.parameters_level:
            # Per-model VRAM budgets (in bytes): models larger than their budget are split into blocks for parameter-level offload.
            model_budgets = {
                "transformer": 600 * 1024 * 1024,
                "text_encoder": 3 * 1024 * 1024 * 1024,
                "text_encoder_2": 3 * 1024 * 1024 * 1024,
            }
            self.async_transfers = True
        else:
            model_budgets = {}

        device_id = os.getenv("LOCAL_RANK", 0)
        torch.set_default_device(f"cuda:{device_id}")
        pipeline.hf_device_map = torch.device(f"cuda:{device_id}")
        pipe_or_dict_of_modules = pipeline.components
        if config.compiler_transformer:
            pipeline.transformer.to("cuda")
        models = {
            k: v
            for k, v in pipe_or_dict_of_modules.items()
            if isinstance(v, torch.nn.Module) and not (config.compiler_transformer and k == "transformer")
        }

        print_info = {k: type(v) for k, v in models.items()}
        print(f"offload models: {print_info}")
        if config.compiler_transformer:
            # Compiled-transformer mode: keep the transformer and VAE resident on the GPU and only move the text encoders on demand.
            pipeline.text_encoder.to("cpu")
            pipeline.text_encoder_2.to("cpu")
            torch.cuda.empty_cache()
            pipeline.transformer.to("cuda")
            pipeline.vae.to("cuda")

            def move_text_encoder_to_gpu(pipe):
                torch.cuda.empty_cache()
                pipe.text_encoder.to("cuda")
                pipe.text_encoder_2.to("cuda")

            def move_text_encoder_to_cpu(pipe):
                pipe.text_encoder.to("cpu")
                pipe.text_encoder_2.to("cpu")
                torch.cuda.empty_cache()

            setattr(pipeline, "text_encoder_to_cpu", functools.partial(move_text_encoder_to_cpu, pipeline))
            setattr(pipeline, "text_encoder_to_gpu", functools.partial(move_text_encoder_to_gpu, pipeline))

            for k, module in pipe_or_dict_of_modules.items():
                if isinstance(module, torch.nn.Module):
                    for submodule_name, submodule in module.named_modules():
                        if not hasattr(submodule, "_hf_hook"):
                            setattr(submodule, "_hf_hook", HfHook())
            return self

        sizeofbfloat16 = torch.bfloat16.itemsize
        modelPinned = config.high_cpu_memory  # pin models in RAM

        # Calculate the VRAM requirements of the computational modules to determine whether parameter-level offload is necessary.
        for model_name, curr_model in models.items():
            curr_model.to("cpu").eval()
            pinned_parameters_data = {}
            current_model_size = 0
            print(f"{model_name} move to pinned memory: {modelPinned}")
            for p in curr_model.parameters():
                if isinstance(p, AffineQuantizedTensor):
                    if not modelPinned and p.tensor_impl.scale.dtype == torch.float32:
                        p.tensor_impl.scale = p.tensor_impl.scale.to(torch.bfloat16)
                    current_model_size += torch.numel(p.tensor_impl.scale) * sizeofbfloat16
                    current_model_size += torch.numel(p.tensor_impl.float8_data) * sizeofbfloat16 / 2
                    if modelPinned:
                        p.tensor_impl.float8_data = p.tensor_impl.float8_data.pin_memory()
                        p.tensor_impl.scale = p.tensor_impl.scale.pin_memory()
                        pinned_parameters_data[p] = [p.tensor_impl.float8_data, p.tensor_impl.scale]
                else:
                    p.data = p.data.to(torch.bfloat16) if p.data.dtype == torch.float32 else p.data.to(p.data.dtype)
                    current_model_size += torch.numel(p.data) * p.data.element_size()
                    if modelPinned:
                        p.data = p.data.pin_memory()
                        pinned_parameters_data[p] = p.data

            for buffer in curr_model.buffers():
                buffer.data = (
                    buffer.data.to(torch.bfloat16)
                    if buffer.data.dtype == torch.float32
                    else buffer.data.to(buffer.data.dtype)
                )
                current_model_size += torch.numel(buffer.data) * buffer.data.element_size()
                if modelPinned:
                    buffer.data = buffer.data.pin_memory()

            if model_name not in self.models:
                self.models[model_name] = curr_model

            # A model that fits entirely within its budget does not need parameter-level offload.
            curr_model_budget = model_budgets.get(model_name, 0)
            if curr_model_budget > 0 and curr_model_budget > current_model_size:
                model_budgets[model_name] = 0

            if modelPinned:
                pinned_buffers_data = {b: b.data for b in curr_model.buffers()}
                pinned_parameters_data.update(pinned_buffers_data)
                self.pinned_modules_data[model_name] = pinned_parameters_data
            gc.collect()
            torch.cuda.empty_cache()

        # if config.compiler_transformer:
        #     module = pipeline.transformer
        #     print("wrap transformer forward")
        #     # gpu model wrap
        #     for submodule_name, submodule in module.named_modules():
        #         if not hasattr(submodule, "_hf_hook"):
        #             setattr(submodule, "_hf_hook", HfHook())
        #
        #     forward_method = getattr(module, "forward")
        #
        #     def wrap_unload_all(*args, **kwargs):
        #         self.unload_all("transformer")
        #         return forward_method(*args, **kwargs)
        #
        #     setattr(module, "forward", functools.update_wrapper(wrap_unload_all, forward_method))

        # wrap forward methods
        for model_name, curr_model in models.items():
            current_budget = model_budgets.get(model_name, 0)
            current_size = 0
            self.loaded_blocks[model_name] = None
            cur_blocks_prefix, prev_blocks_name, cur_blocks_name, cur_blocks_seq = None, None, None, -1

            for submodule_name, submodule in curr_model.named_modules():
                # create a fake accelerate hook so that the _execution_device property always returns "cuda"
                if not hasattr(submodule, "_hf_hook"):
                    setattr(submodule, "_hf_hook", HfHook())

                if not submodule_name:
                    continue

                # use parameter-level offload: group consecutive children of ModuleList/Sequential containers into blocks
                if current_budget > 0:
                    if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
                        if cur_blocks_prefix is None:
                            cur_blocks_prefix = submodule_name + "."
                        else:
                            if not submodule_name.startswith(cur_blocks_prefix):
                                cur_blocks_prefix = submodule_name + "."
                                cur_blocks_name, cur_blocks_seq = None, -1
                    else:
                        if cur_blocks_prefix is not None:
                            if submodule_name.startswith(cur_blocks_prefix):
                                num = int(submodule_name[len(cur_blocks_prefix) :].split(".")[0])
                                # Start a new block when the sequence index changes and the current block has exceeded its budget.
                                if num != cur_blocks_seq and (cur_blocks_name is None or current_size > current_budget):
                                    prev_blocks_name = cur_blocks_name
                                    cur_blocks_name = cur_blocks_prefix + str(num)
                                cur_blocks_seq = num
                            else:
                                cur_blocks_prefix = None
                                prev_blocks_name = None
                                cur_blocks_name = None
                                cur_blocks_seq = -1

                if hasattr(submodule, "forward"):
                    submodule_forward = getattr(submodule, "forward")
                    if not callable(submodule_forward):
                        print(f"Warning: the 'forward' attribute of '{submodule_name}' is not callable, skipping")
                        continue
                    if len(submodule_name.split(".")) == 1:
                        # Top-level submodule: hook it so its forward triggers model loading/unloading.
                        self.hook_me(submodule, curr_model, model_name, submodule_name, submodule_forward)
                    else:
                        self.hook_me_light(
                            submodule, model_name, cur_blocks_name, submodule_forward, context=submodule_name
                        )
                    current_size = self.add_module_to_blocks(model_name, cur_blocks_name, submodule, prev_blocks_name)

        gc.collect()
        torch.cuda.empty_cache()
        return self

    def add_module_to_blocks(self, model_name, blocks_name, submodule, prev_block_name):
        entry_name = model_name if blocks_name is None else model_name + "/" + blocks_name
        if entry_name in self.blocks_of_modules:
            blocks_params = self.blocks_of_modules[entry_name]
            blocks_params_size = self.blocks_of_modules_sizes[entry_name]
        else:
            blocks_params = []
            self.blocks_of_modules[entry_name] = blocks_params
            blocks_params_size = 0
            if blocks_name is not None:
                # Record the prev/next block chain so that the next block can be prefetched asynchronously.
                prev_entry_name = None if prev_block_name is None else model_name + "/" + prev_block_name
                self.prev_blocks_names[entry_name] = prev_entry_name
                if prev_block_name is not None:
                    self.next_blocks_names[prev_entry_name] = entry_name

        for p in submodule.parameters(recurse=False):
            blocks_params.append(p)
            if isinstance(p, AffineQuantizedTensor):
                blocks_params_size += p.tensor_impl.float8_data.nbytes
                blocks_params_size += p.tensor_impl.scale.nbytes
            else:
                blocks_params_size += p.data.nbytes

        for p in submodule.buffers(recurse=False):
            blocks_params.append(p)
            blocks_params_size += p.data.nbytes

        self.blocks_of_modules_sizes[entry_name] = blocks_params_size

        return blocks_params_size

    def can_model_be_cotenant(self, model_name):
        cotenants_map = {
            "text_encoder": ["vae", "text_encoder_2"],
            "text_encoder_2": ["vae", "text_encoder"],
        }
        potential_cotenants = cotenants_map.get(model_name, None)
        if potential_cotenants is None:
            return False
        for existing_cotenant in self.active_models_ids:
            if existing_cotenant not in potential_cotenants:
                return False
        return True

    @torch.compiler.disable()
    def gpu_load_blocks(self, model_name, blocks_name, async_load=False):
        if blocks_name is not None:
            self.loaded_blocks[model_name] = blocks_name

        def cpu_to_gpu(stream_to_use, blocks_params, record_for_stream=None):
            with torch.cuda.stream(stream_to_use):
                for p in blocks_params:
                    if isinstance(p, AffineQuantizedTensor):
                        p.tensor_impl.float8_data = p.tensor_impl.float8_data.cuda(
                            non_blocking=True, device=self.device_id
                        )
                        p.tensor_impl.scale = p.tensor_impl.scale.cuda(non_blocking=True, device=self.device_id)
                    else:
                        p.data = p.data.cuda(non_blocking=True, device=self.device_id)

                    if record_for_stream is not None:
                        if isinstance(p, AffineQuantizedTensor):
                            p.tensor_impl.float8_data.record_stream(record_for_stream)
                            p.tensor_impl.scale.record_stream(record_for_stream)
                        else:
                            p.data.record_stream(record_for_stream)

        entry_name = model_name if blocks_name is None else model_name + "/" + blocks_name
        if self.verboseLevel >= 2:
            model = self.models[model_name]
            model_class_name = model._get_name()
            print(f"Loading model {entry_name} ({model_class_name}) in GPU")

        if self.async_transfers and blocks_name is not None:
            first = self.prev_blocks_names[entry_name] is None
            next_blocks_entry = self.next_blocks_names[entry_name] if entry_name in self.next_blocks_names else None
            if first:
                cpu_to_gpu(torch.cuda.current_stream(), self.blocks_of_modules[entry_name])
            torch.cuda.synchronize()

            # Prefetch the next block on the transfer stream while the current block is being used.
            if next_blocks_entry is not None:
                cpu_to_gpu(self.transfer_stream, self.blocks_of_modules[next_blocks_entry])
        else:
            cpu_to_gpu(self.default_stream, self.blocks_of_modules[entry_name])
            torch.cuda.synchronize()

    @torch.compiler.disable()
    def gpu_unload_blocks(self, model_name, blocks_name):
        if blocks_name is not None:
            self.loaded_blocks[model_name] = None

        entry_name = model_name if blocks_name is None else model_name + "/" + blocks_name

        if self.verboseLevel >= 2:
            model = self.models[model_name]
            model_class_name = model._get_name()
            print(f"Unloading model {entry_name} ({model_class_name}) from GPU")

        blocks_params = self.blocks_of_modules[entry_name]

        if model_name in self.pinned_modules_data:
            # Restore references to the pinned CPU copies instead of copying the tensors back to host memory.
            pinned_parameters_data = self.pinned_modules_data[model_name]
            for p in blocks_params:
                if isinstance(p, AffineQuantizedTensor):
                    data = pinned_parameters_data[p]
                    p.tensor_impl.float8_data = data[0]
                    p.tensor_impl.scale = data[1]
                else:
                    p.data = pinned_parameters_data[p]
        else:
            for p in blocks_params:
                if isinstance(p, AffineQuantizedTensor):
                    p.tensor_impl.float8_data = p.tensor_impl.float8_data.cpu()
                    p.tensor_impl.scale = p.tensor_impl.scale.cpu()
                else:
                    p.data = p.data.cpu()

    @torch.compiler.disable()
    def gpu_load(self, model_name):
        model = self.models[model_name]
        self.active_models.append(model)
        self.active_models_ids.append(model_name)
        self.gpu_load_blocks(model_name, None)
        # torch.cuda.current_stream().synchronize()

    @torch.compiler.disable()
    def unload_all(self, model_name: str):
        if len(self.active_models_ids) == 0 and self.last_run_model == model_name:
            self.last_run_model = model_name
            return
        for active_model_name in self.active_models_ids:
            self.gpu_unload_blocks(active_model_name, None)
            loaded_block = self.loaded_blocks[active_model_name]
            if loaded_block is not None:
                self.gpu_unload_blocks(active_model_name, loaded_block)
                self.loaded_blocks[active_model_name] = None

        self.active_models = []
        self.active_models_ids = []
        self.active_subcaches = []
        torch.cuda.empty_cache()
        gc.collect()
        self.last_reserved_mem_check = time.time()
        self.last_run_model = model_name

    def move_args_to_gpu(self, *args, **kwargs):
        new_args = []
        new_kwargs = {}
        for arg in args:
            if torch.is_tensor(arg):
                if arg.dtype == torch.float32:
                    arg = arg.to(torch.bfloat16).cuda(non_blocking=True, device=self.device_id)
                else:
                    arg = arg.cuda(non_blocking=True, device=self.device_id)
            new_args.append(arg)

        for k in kwargs:
            arg = kwargs[k]
            if torch.is_tensor(arg):
                if arg.dtype == torch.float32:
                    arg = arg.to(torch.bfloat16).cuda(non_blocking=True, device=self.device_id)
                else:
                    arg = arg.cuda(non_blocking=True, device=self.device_id)
            new_kwargs[k] = arg

        return new_args, new_kwargs

    def ready_to_check_mem(self):
        if self.compile:
            return
        cur_clock = time.time()
        # Don't check on every call whether the CUDA cache can be emptied: querying the reserved memory value is a time-consuming operation.
        if (cur_clock - self.last_reserved_mem_check) < 0.200:
            return False
        self.last_reserved_mem_check = cur_clock
        return True

    def empty_cache_if_needed(self):
        mem_reserved = torch.cuda.memory_reserved()
        mem_threshold = 0.9 * self.device_mem_capacity
        if mem_reserved >= mem_threshold:
            mem_allocated = torch.cuda.memory_allocated()
            if mem_allocated <= 0.70 * mem_reserved:
                torch.cuda.empty_cache()
                tm = time.time()
                if self.verboseLevel >= 2:
                    print(f"Empty Cuda cache at {tm}")

    def any_param_or_buffer(self, target_module: torch.nn.Module):
        for _ in target_module.parameters(recurse=False):
            return True

        for _ in target_module.buffers(recurse=False):
            return True

        return False

    def hook_me_light(self, target_module, model_name, blocks_name, previous_method, context):
        anyParam = self.any_param_or_buffer(target_module)

        def check_empty_cuda_cache(module, *args, **kwargs):
            if self.ready_to_check_mem():
                self.empty_cache_if_needed()
            return previous_method(*args, **kwargs)

        def load_module_blocks(module, *args, **kwargs):
            if blocks_name is None:
                if self.ready_to_check_mem():
                    self.empty_cache_if_needed()
            else:
                loaded_block = self.loaded_blocks[model_name]
                if loaded_block is None or loaded_block != blocks_name:
                    if loaded_block is not None:
                        self.gpu_unload_blocks(model_name, loaded_block)
                        if self.ready_to_check_mem():
                            self.empty_cache_if_needed()
                    self.loaded_blocks[model_name] = blocks_name
                    self.gpu_load_blocks(model_name, blocks_name)
            return previous_method(*args, **kwargs)

        if hasattr(target_module, "_mm_id"):
            orig_model_name = getattr(target_module, "_mm_id")
            if self.verboseLevel >= 2:
                print(
                    f"Model '{model_name}' shares module '{target_module._get_name()}' with model '{orig_model_name}'"
                )
            assert not anyParam
            return

        setattr(target_module, "_mm_id", model_name)

        if blocks_name is not None and anyParam:
            setattr(
                target_module,
                "forward",
                functools.update_wrapper(functools.partial(load_module_blocks, target_module), previous_method),
            )
            # print(f"new cache: {blocks_name}")
        else:
            setattr(
                target_module,
                "forward",
                functools.update_wrapper(functools.partial(check_empty_cuda_cache, target_module), previous_method),
            )

    def hook_me(self, target_module, model, model_name, module_id, previous_method):
        def check_change_module(module, *args, **kwargs):
            performEmptyCacheTest = False
            if model_name not in self.active_models_ids:
                new_model_name = getattr(module, "_mm_id")
                # Load the incoming model, unloading the currently active models first unless they can share the GPU.
                if not self.can_model_be_cotenant(new_model_name):
                    self.unload_all(model_name)
                    performEmptyCacheTest = False
                self.gpu_load(new_model_name)
            args, kwargs = self.move_args_to_gpu(*args, **kwargs)
            if performEmptyCacheTest:
                self.empty_cache_if_needed()
            return previous_method(*args, **kwargs)

        if hasattr(target_module, "_mm_id"):
            return

        setattr(target_module, "_mm_id", model_name)

        setattr(
            target_module,
            "forward",
            functools.update_wrapper(functools.partial(check_change_module, target_module), previous_method),
        )

        if self.verboseLevel < 1:
            return

        if module_id is None or module_id == "":
            model_class_name = model._get_name()
            print(f"Hooked in model '{model_name}' ({model_class_name})")
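

# Usage sketch (illustrative, not part of the module above): the checkpoint path and the
# generation arguments are placeholders supplied by the caller; the component names
# ("transformer", "text_encoder", "text_encoder_2", "vae") are assumed to match the budgets
# and cotenant map defined in Offload.offload.
if __name__ == "__main__":
    # Hypothetical checkpoint location, e.g. set MODEL_PATH to a local directory or Hub id.
    model_path = os.environ.get("MODEL_PATH", "<path-or-hub-id>")
    pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)

    # Pin weights in host RAM and enable parameter-level offload to fit low-VRAM GPUs.
    Offload.offload(pipe, OffloadConfig(high_cpu_memory=True, parameters_level=True))

    # Call signature depends on the concrete pipeline; most diffusers pipelines accept `prompt`.
    output = pipe(prompt="a cat running on the grass")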