1inkusFace committed on
Commit e972c48 · verified · 1 Parent(s): 747a256

Update skyreelsinfer/offload.py
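
The only functional change in this update is a small cache-clearing helper added near the top of the Offload class (new lines 56-58); the rest of the file is carried over unchanged. For quick reference, the added method is:

    def check_empty_cuda_cache(self):  # Now a method of Offload
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

It returns PyTorch's unused cached allocator memory to the driver whenever a CUDA device is available.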

Files changed (1)
  1. skyreelsinfer/offload.py +519 -515
skyreelsinfer/offload.py CHANGED
@@ -1,515 +1,519 @@
1
- import functools
2
- import gc
3
- import os
4
- import time
5
- from dataclasses import dataclass
6
-
7
- import torch
8
- from diffusers.pipelines import DiffusionPipeline
9
- from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
10
-
11
-
12
- @dataclass
13
- class OffloadConfig:
14
- # high_cpu_memory: Whether to use pinned memory for offload optimization. This can effectively prevent increased model offload latency caused by memory swapping.
15
- high_cpu_memory: bool = True
16
- # parameters_level: Whether to enable parameter-level offload. This further reduces VRAM requirements but may result in increased latency.
17
- parameters_level: bool = False
18
- # compiler_transformer: Whether to enable compilation optimization for the transformer.
19
- compiler_transformer: bool = False
20
- compiler_cache: str = "/tmp/compile_cache"
21
-
22
-
23
- class HfHook:
24
- def __init__(self):
25
- device_id = os.environ.get("LOCAL_RANK", 0)
26
- self.execution_device = f"cuda:{device_id}"
27
-
28
- def detach_hook(self, module):
29
- pass
30
-
31
-
32
- class Offload:
33
- def __init__(self) -> None:
34
- self.active_models = []
35
- self.active_models_ids = []
36
- self.active_subcaches = {}
37
- self.models = {}
38
- self.verboseLevel = 0
39
- self.models_to_quantize = []
40
- self.pinned_modules_data = {}
41
- self.blocks_of_modules = {}
42
- self.blocks_of_modules_sizes = {}
43
- self.compile = False
44
- self.device_mem_capacity = torch.cuda.get_device_properties(0).total_memory
45
- self.last_reserved_mem_check = 0
46
- self.loaded_blocks = {}
47
- self.prev_blocks_names = {}
48
- self.next_blocks_names = {}
49
- device_id = os.environ.get("LOCAL_RANK", 0)
50
- self.device_id = f"cuda:{device_id}"
51
- self.default_stream = torch.cuda.default_stream(self.device_id) # torch.cuda.current_stream()
52
- self.transfer_stream = torch.cuda.Stream()
53
- self.async_transfers = False
54
- self.last_run_model = None
55
-
56
- @classmethod
57
- def offload(cls, pipeline: DiffusionPipeline, config: OffloadConfig = OffloadConfig()):
58
- """
59
- Enable offloading for multiple models in the pipeline, supporting video generation inference on consumer-grade GPUs.
60
- pipeline: the pipeline object
61
- config: offload strategy configuration
62
- """
63
- self = cls()
64
- self.pinned_modules_data = {}
65
- if config.parameters_level:
66
- model_budgets = {
67
- "transformer": 600 * 1024 * 1024,
68
- "text_encoder": 3 * 1024 * 1024 * 1024,
69
- "text_encoder_2": 3 * 1024 * 1024 * 1024,
70
- }
71
- self.async_transfers = True
72
- else:
73
- model_budgets = {}
74
-
75
- device_id = os.getenv("LOCAL_RANK", 0)
76
- torch.set_default_device(f"cuda:{device_id}")
77
- pipeline.hf_device_map = torch.device(f"cuda:{device_id}")
78
- pipe_or_dict_of_modules = pipeline.components
79
- if config.compiler_transformer:
80
- pipeline.transformer.to("cuda")
81
- models = {
82
- k: v
83
- for k, v in pipe_or_dict_of_modules.items()
84
- if isinstance(v, torch.nn.Module) and not (config.compiler_transformer and k == "transformer")
85
- }
86
- print_info = {k: type(v) for k, v in models.items()}
87
- print(f"offload models: {print_info}")
88
- if config.compiler_transformer:
89
- pipeline.text_encoder.to("cpu")
90
- pipeline.text_encoder_2.to("cpu")
91
- torch.cuda.empty_cache()
92
- pipeline.transformer.to("cuda")
93
- pipeline.vae.to("cuda")
94
-
95
- def move_text_encoder_to_gpu(pipe):
96
- torch.cuda.empty_cache()
97
- pipe.text_encoder.to("cuda")
98
- pipe.text_encoder_2.to("cuda")
99
-
100
- def move_text_encoder_to_cpu(pipe):
101
- pipe.text_encoder.to("cpu")
102
- pipe.text_encoder_2.to("cpu")
103
- torch.cuda.empty_cache()
104
-
105
- setattr(pipeline, "text_encoder_to_cpu", functools.partial(move_text_encoder_to_cpu, pipeline))
106
- setattr(pipeline, "text_encoder_to_gpu", functools.partial(move_text_encoder_to_gpu, pipeline))
107
-
108
- for k, module in pipe_or_dict_of_modules.items():
109
- if isinstance(module, torch.nn.Module):
110
- for submodule_name, submodule in module.named_modules():
111
- if not hasattr(submodule, "_hf_hook"):
112
- setattr(submodule, "_hf_hook", HfHook())
113
- return self
114
-
115
- sizeofbfloat16 = torch.bfloat16.itemsize
116
- modelPinned = config.high_cpu_memory
117
- # Pin models in RAM
118
- # Calculate the VRAM requirements of the computational modules to determine whether parameters-level offload is necessary.
119
- for model_name, curr_model in models.items():
120
- curr_model.to("cpu").eval()
121
- pinned_parameters_data = {}
122
- current_model_size = 0
123
- print(f"{model_name} move to pinned memory:{modelPinned}")
124
- for p in curr_model.parameters():
125
- if isinstance(p, AffineQuantizedTensor):
126
- if not modelPinned and p.tensor_impl.scale.dtype == torch.float32:
127
- p.tensor_impl.scale = p.tensor_impl.scale.to(torch.bfloat16)
128
- current_model_size += torch.numel(p.tensor_impl.scale) * sizeofbfloat16
129
- current_model_size += torch.numel(p.tensor_impl.float8_data) * sizeofbfloat16 / 2
130
- if modelPinned:
131
- p.tensor_impl.float8_data = p.tensor_impl.float8_data.pin_memory()
132
- p.tensor_impl.scale = p.tensor_impl.scale.pin_memory()
133
- pinned_parameters_data[p] = [p.tensor_impl.float8_data, p.tensor_impl.scale]
134
- else:
135
- p.data = p.data.to(torch.bfloat16) if p.data.dtype == torch.float32 else p.data.to(p.data.dtype)
136
- current_model_size += torch.numel(p.data) * p.data.element_size()
137
- if modelPinned:
138
- p.data = p.data.pin_memory()
139
- pinned_parameters_data[p] = p.data
140
-
141
- for buffer in curr_model.buffers():
142
- buffer.data = (
143
- buffer.data.to(torch.bfloat16)
144
- if buffer.data.dtype == torch.float32
145
- else buffer.data.to(buffer.data.dtype)
146
- )
147
- current_model_size += torch.numel(buffer.data) * buffer.data.element_size()
148
- if modelPinned:
149
- buffer.data = buffer.data.pin_memory()
150
-
151
- if model_name not in self.models:
152
- self.models[model_name] = curr_model
153
-
154
- curr_model_budget = model_budgets.get(model_name, 0)
155
- if curr_model_budget > 0 and curr_model_budget > current_model_size:
156
- model_budgets[model_name] = 0
157
-
158
- if modelPinned:
159
- pinned_buffers_data = {b: b.data for b in curr_model.buffers()}
160
- pinned_parameters_data.update(pinned_buffers_data)
161
- self.pinned_modules_data[model_name] = pinned_parameters_data
162
- gc.collect()
163
- torch.cuda.empty_cache()
164
-
165
- # if config.compiler_transformer:
166
- # module = pipeline.transformer
167
- # print("wrap transformer forward")
168
- # # gpu model wrap
169
- # for submodule_name, submodule in module.named_modules():
170
- # if not hasattr(submodule, "_hf_hook"):
171
- # setattr(submodule, "_hf_hook", HfHook())
172
- #
173
- # forward_method = getattr(module, "forward")
174
- #
175
- # def wrap_unload_all(*args, **kwargs):
176
- # self.unload_all("transformer")
177
- # return forward_method(*args, **kwargs)
178
- #
179
- # setattr(module, "forward", functools.update_wrapper(wrap_unload_all, forward_method))
180
-
181
- # wrap forward methods
182
- for model_name, curr_model in models.items():
183
- current_budget = model_budgets.get(model_name, 0)
184
- current_size = 0
185
- self.loaded_blocks[model_name] = None
186
- cur_blocks_prefix, prev_blocks_name, cur_blocks_name, cur_blocks_seq = None, None, None, -1
187
-
188
- for submodule_name, submodule in curr_model.named_modules():
189
- # attach a fake accelerate hook so that the _execution_device property always returns "cuda"
190
- if not hasattr(submodule, "_hf_hook"):
191
- setattr(submodule, "_hf_hook", HfHook())
192
-
193
- if not submodule_name:
194
- continue
195
-
196
- # use parameters-level offload
197
- if current_budget > 0:
198
- if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
199
- if cur_blocks_prefix == None:
200
- cur_blocks_prefix = submodule_name + "."
201
- else:
202
- if not submodule_name.startswith(cur_blocks_prefix):
203
- cur_blocks_prefix = submodule_name + "."
204
- cur_blocks_name, cur_blocks_seq = None, -1
205
- else:
206
- if cur_blocks_prefix is not None:
207
- if submodule_name.startswith(cur_blocks_prefix):
208
- num = int(submodule_name[len(cur_blocks_prefix) :].split(".")[0])
209
- if num != cur_blocks_seq and (cur_blocks_name == None or current_size > current_budget):
210
- prev_blocks_name = cur_blocks_name
211
- cur_blocks_name = cur_blocks_prefix + str(num)
212
- cur_blocks_seq = num
213
- else:
214
- cur_blocks_prefix = None
215
- prev_blocks_name = None
216
- cur_blocks_name = None
217
- cur_blocks_seq = -1
218
-
219
- if hasattr(submodule, "forward"):
220
- submodule_forward = getattr(submodule, "forward")
221
- if not callable(submodule_forward):
222
- print("***")
223
- continue
224
- if len(submodule_name.split(".")) == 1:
225
- self.hook_me(submodule, curr_model, model_name, submodule_name, submodule_forward)
226
- else:
227
- self.hook_me_light(
228
- submodule, model_name, cur_blocks_name, submodule_forward, context=submodule_name
229
- )
230
- current_size = self.add_module_to_blocks(model_name, cur_blocks_name, submodule, prev_blocks_name)
231
-
232
- gc.collect()
233
- torch.cuda.empty_cache()
234
- return self
235
-
236
- def add_module_to_blocks(self, model_name, blocks_name, submodule, prev_block_name):
237
-
238
- entry_name = model_name if blocks_name is None else model_name + "/" + blocks_name
239
- if entry_name in self.blocks_of_modules:
240
- blocks_params = self.blocks_of_modules[entry_name]
241
- blocks_params_size = self.blocks_of_modules_sizes[entry_name]
242
- else:
243
- blocks_params = []
244
- self.blocks_of_modules[entry_name] = blocks_params
245
- blocks_params_size = 0
246
- if blocks_name != None:
247
- prev_entry_name = None if prev_block_name == None else model_name + "/" + prev_block_name
248
- self.prev_blocks_names[entry_name] = prev_entry_name
249
- if not prev_block_name == None:
250
- self.next_blocks_names[prev_entry_name] = entry_name
251
-
252
- for p in submodule.parameters(recurse=False):
253
- blocks_params.append(p)
254
- if isinstance(p, AffineQuantizedTensor):
255
- blocks_params_size += p.tensor_impl.float8_data.nbytes
256
- blocks_params_size += p.tensor_impl.scale.nbytes
257
- else:
258
- blocks_params_size += p.data.nbytes
259
-
260
- for p in submodule.buffers(recurse=False):
261
- blocks_params.append(p)
262
- blocks_params_size += p.data.nbytes
263
-
264
- self.blocks_of_modules_sizes[entry_name] = blocks_params_size
265
-
266
- return blocks_params_size
267
-
268
- def can_model_be_cotenant(self, model_name):
269
- cotenants_map = {
270
- "text_encoder": ["vae", "text_encoder_2"],
271
- "text_encoder_2": ["vae", "text_encoder"],
272
- }
273
- potential_cotenants = cotenants_map.get(model_name, None)
274
- if potential_cotenants is None:
275
- return False
276
- for existing_cotenant in self.active_models_ids:
277
- if existing_cotenant not in potential_cotenants:
278
- return False
279
- return True
280
-
281
- @torch.compiler.disable()
282
- def gpu_load_blocks(self, model_name, blocks_name, async_load=False):
283
- if blocks_name != None:
284
- self.loaded_blocks[model_name] = blocks_name
285
-
286
- def cpu_to_gpu(stream_to_use, blocks_params, record_for_stream=None):
287
- with torch.cuda.stream(stream_to_use):
288
- for p in blocks_params:
289
- if isinstance(p, AffineQuantizedTensor):
290
- p.tensor_impl.float8_data = p.tensor_impl.float8_data.cuda(
291
- non_blocking=True, device=self.device_id
292
- )
293
- p.tensor_impl.scale = p.tensor_impl.scale.cuda(non_blocking=True, device=self.device_id)
294
- else:
295
- p.data = p.data.cuda(non_blocking=True, device=self.device_id)
296
-
297
- if record_for_stream != None:
298
- if isinstance(p, AffineQuantizedTensor):
299
- p.tensor_impl.float8_data.record_stream(record_for_stream)
300
- p.tensor_impl.scale.record_stream(record_for_stream)
301
- else:
302
- p.data.record_stream(record_for_stream)
303
-
304
- entry_name = model_name if blocks_name is None else model_name + "/" + blocks_name
305
- if self.verboseLevel >= 2:
306
- model = self.models[model_name]
307
- model_class = model._get_name()
308
- print(f"Loading model {entry_name} ({model_class}) in GPU")
309
-
310
- if self.async_transfers and blocks_name != None:
311
- first = self.prev_blocks_names[entry_name] == None
312
- next_blocks_entry = self.next_blocks_names[entry_name] if entry_name in self.next_blocks_names else None
313
- if first:
314
- cpu_to_gpu(torch.cuda.current_stream(), self.blocks_of_modules[entry_name])
315
- torch.cuda.synchronize()
316
-
317
- if next_blocks_entry != None:
318
- cpu_to_gpu(self.transfer_stream, self.blocks_of_modules[next_blocks_entry])
319
-
320
- else:
321
- cpu_to_gpu(self.default_stream, self.blocks_of_modules[entry_name])
322
- torch.cuda.synchronize()
323
-
324
- @torch.compiler.disable()
325
- def gpu_unload_blocks(self, model_name, blocks_name):
326
- if blocks_name != None:
327
- self.loaded_blocks[model_name] = None
328
-
329
- blocks_name = model_name if blocks_name is None else model_name + "/" + blocks_name
330
-
331
- if self.verboseLevel >= 2:
332
- model = self.models[model_name]
333
- model_class = model._get_name()
334
- print(f"Unloading model {blocks_name} ({model_class}) from GPU")
335
-
336
- blocks_params = self.blocks_of_modules[blocks_name]
337
-
338
- if model_name in self.pinned_modules_data:
339
- pinned_parameters_data = self.pinned_modules_data[model_name]
340
- for p in blocks_params:
341
- if isinstance(p, AffineQuantizedTensor):
342
- data = pinned_parameters_data[p]
343
- p.tensor_impl.float8_data = data[0]
344
- p.tensor_impl.scale = data[1]
345
- else:
346
- p.data = pinned_parameters_data[p]
347
- else:
348
- for p in blocks_params:
349
- if isinstance(p, AffineQuantizedTensor):
350
- p.tensor_impl.float8_data = p.tensor_impl.float8_data.cpu()
351
- p.tensor_impl.scale = p.tensor_impl.scale.cpu()
352
- else:
353
- p.data = p.data.cpu()
354
-
355
- @torch.compiler.disable()
356
- def gpu_load(self, model_name):
357
- model = self.models[model_name]
358
- self.active_models.append(model)
359
- self.active_models_ids.append(model_name)
360
-
361
- self.gpu_load_blocks(model_name, None)
362
-
363
- # torch.cuda.current_stream().synchronize()
364
-
365
- @torch.compiler.disable()
366
- def unload_all(self, model_name: str):
367
- if len(self.active_models_ids) == 0 and self.last_run_model == model_name:
368
- self.last_run_model = model_name
369
- return
370
- for model_name in self.active_models_ids:
371
- self.gpu_unload_blocks(model_name, None)
372
- loaded_block = self.loaded_blocks[model_name]
373
- if loaded_block != None:
374
- self.gpu_unload_blocks(model_name, loaded_block)
375
- self.loaded_blocks[model_name] = None
376
-
377
- self.active_models = []
378
- self.active_models_ids = []
379
- self.active_subcaches = []
380
- torch.cuda.empty_cache()
381
- gc.collect()
382
- self.last_reserved_mem_check = time.time()
383
- self.last_run_model = model_name
384
-
385
- def move_args_to_gpu(self, *args, **kwargs):
386
- new_args = []
387
- new_kwargs = {}
388
- for arg in args:
389
- if torch.is_tensor(arg):
390
- if arg.dtype == torch.float32:
391
- arg = arg.to(torch.bfloat16).cuda(non_blocking=True, device=self.device_id)
392
- else:
393
- arg = arg.cuda(non_blocking=True, device=self.device_id)
394
- new_args.append(arg)
395
-
396
- for k in kwargs:
397
- arg = kwargs[k]
398
- if torch.is_tensor(arg):
399
- if arg.dtype == torch.float32:
400
- arg = arg.to(torch.bfloat16).cuda(non_blocking=True, device=self.device_id)
401
- else:
402
- arg = arg.cuda(non_blocking=True, device=self.device_id)
403
- new_kwargs[k] = arg
404
-
405
- return new_args, new_kwargs
406
-
407
- def ready_to_check_mem(self):
408
- if self.compile:
409
- return False
410
- cur_clock = time.time()
411
- # don't check at every call whether the CUDA cache can be emptied, as querying the reserved memory value is a time-consuming operation
412
- if (cur_clock - self.last_reserved_mem_check) < 0.200:
413
- return False
414
- self.last_reserved_mem_check = cur_clock
415
- return True
416
-
417
- def empty_cache_if_needed(self):
418
- mem_reserved = torch.cuda.memory_reserved()
419
- mem_threshold = 0.9 * self.device_mem_capacity
420
- if mem_reserved >= mem_threshold:
421
- mem_allocated = torch.cuda.memory_allocated()
422
- if mem_allocated <= 0.70 * mem_reserved:
423
- torch.cuda.empty_cache()
424
- tm = time.time()
425
- if self.verboseLevel >= 2:
426
- print(f"Empty Cuda cache at {tm}")
427
-
428
- def any_param_or_buffer(self, target_module: torch.nn.Module):
429
-
430
- for _ in target_module.parameters(recurse=False):
431
- return True
432
-
433
- for _ in target_module.buffers(recurse=False):
434
- return True
435
-
436
- return False
437
-
438
- def hook_me_light(self, target_module, model_name, blocks_name, previous_method, context):
439
-
440
- anyParam = self.any_param_or_buffer(target_module)
441
-
442
- def check_empty_cuda_cache(module, *args, **kwargs):
443
- if self.ready_to_check_mem():
444
- self.empty_cache_if_needed()
445
- return previous_method(*args, **kwargs)
446
-
447
- def load_module_blocks(module, *args, **kwargs):
448
- if blocks_name == None:
449
- if self.ready_to_check_mem():
450
- self.empty_cache_if_needed()
451
- else:
452
- loaded_block = self.loaded_blocks[model_name]
453
- if loaded_block == None or loaded_block != blocks_name:
454
- if loaded_block != None:
455
- self.gpu_unload_blocks(model_name, loaded_block)
456
- if self.ready_to_check_mem():
457
- self.empty_cache_if_needed()
458
- self.loaded_blocks[model_name] = blocks_name
459
- self.gpu_load_blocks(model_name, blocks_name)
460
- return previous_method(*args, **kwargs)
461
-
462
- if hasattr(target_module, "_mm_id"):
463
- orig_model_name = getattr(target_module, "_mm_id")
464
- if self.verboseLevel >= 2:
465
- print(
466
- f"Model '{model_name}' shares module '{target_module._get_name()}' with module '{orig_model_name}' "
467
- )
468
- assert not anyParam
469
- return
470
- setattr(target_module, "_mm_id", model_name)
471
-
472
- if blocks_name != None and anyParam:
473
- setattr(
474
- target_module,
475
- "forward",
476
- functools.update_wrapper(functools.partial(load_module_blocks, target_module), previous_method),
477
- )
478
- # print(f"new cache:{blocks_name}")
479
- else:
480
- setattr(
481
- target_module,
482
- "forward",
483
- functools.update_wrapper(functools.partial(check_empty_cuda_cache, target_module), previous_method),
484
- )
485
-
486
- def hook_me(self, target_module, model, model_name, module_id, previous_method):
487
- def check_change_module(module, *args, **kwargs):
488
- performEmptyCacheTest = False
489
- if not model_name in self.active_models_ids:
490
- new_model_name = getattr(module, "_mm_id")
491
- if not self.can_model_be_cotenant(new_model_name):
492
- self.unload_all(model_name)
493
- performEmptyCacheTest = False
494
- self.gpu_load(new_model_name)
495
- args, kwargs = self.move_args_to_gpu(*args, **kwargs)
496
- if performEmptyCacheTest:
497
- self.empty_cache_if_needed()
498
- return previous_method(*args, **kwargs)
499
-
500
- if hasattr(target_module, "_mm_id"):
501
- return
502
- setattr(target_module, "_mm_id", model_name)
503
-
504
- setattr(
505
- target_module,
506
- "forward",
507
- functools.update_wrapper(functools.partial(check_change_module, target_module), previous_method),
508
- )
509
-
510
- if not self.verboseLevel >= 1:
511
- return
512
-
513
- if module_id == None or module_id == "":
514
- model_class = model._get_name()
515
- print(f"Hooked in model '{model_name}' ({model_class})")
 
 
 
 
 
1
+ import functools
2
+ import gc
3
+ import os
4
+ import time
5
+ from dataclasses import dataclass
6
+
7
+ import torch
8
+ from diffusers.pipelines import DiffusionPipeline
9
+ from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
10
+
11
+
12
+ @dataclass
13
+ class OffloadConfig:
14
+ # high_cpu_memory: Whether to use pinned memory for offload optimization. This can effectively prevent increased model offload latency caused by memory swapping.
15
+ high_cpu_memory: bool = True
16
+ # parameters_level: Whether to enable parameter-level offload. This further reduces VRAM requirements but may result in increased latency.
17
+ parameters_level: bool = False
18
+ # compiler_transformer: Whether to enable compilation optimization for the transformer.
19
+ compiler_transformer: bool = False
20
+ compiler_cache: str = "/tmp/compile_cache"
21
+
22
+
23
+ class HfHook:
24
+ def __init__(self):
25
+ device_id = os.environ.get("LOCAL_RANK", 0)
26
+ self.execution_device = f"cuda:{device_id}"
27
+
28
+ def detach_hook(self, module):
29
+ pass
30
+
31
+
32
+ class Offload:
33
+ def __init__(self) -> None:
34
+ self.active_models = []
35
+ self.active_models_ids = []
36
+ self.active_subcaches = {}
37
+ self.models = {}
38
+ self.verboseLevel = 0
39
+ self.models_to_quantize = []
40
+ self.pinned_modules_data = {}
41
+ self.blocks_of_modules = {}
42
+ self.blocks_of_modules_sizes = {}
43
+ self.compile = False
44
+ self.device_mem_capacity = torch.cuda.get_device_properties(0).total_memory
45
+ self.last_reserved_mem_check = 0
46
+ self.loaded_blocks = {}
47
+ self.prev_blocks_names = {}
48
+ self.next_blocks_names = {}
49
+ device_id = os.environ.get("LOCAL_RANK", 0)
50
+ self.device_id = f"cuda:{device_id}"
51
+ self.default_stream = torch.cuda.default_stream(self.device_id) # torch.cuda.current_stream()
52
+ self.transfer_stream = torch.cuda.Stream()
53
+ self.async_transfers = False
54
+ self.last_run_model = None
55
+
56
+ def check_empty_cuda_cache(self): # Now a method of Offload
57
+ if torch.cuda.is_available():
58
+ torch.cuda.empty_cache()
59
+
60
+ @classmethod
61
+ def offload(cls, pipeline: DiffusionPipeline, config: OffloadConfig = OffloadConfig()):
62
+ """
63
+ Enable offloading for multiple models in the pipeline, supporting video generation inference on consumer-grade GPUs.
64
+ pipeline: the pipeline object
65
+ config: offload strategy configuration
66
+ """
67
+ self = cls()
68
+ self.pinned_modules_data = {}
69
+ if config.parameters_level:
70
+ model_budgets = {
71
+ "transformer": 600 * 1024 * 1024,
72
+ "text_encoder": 3 * 1024 * 1024 * 1024,
73
+ "text_encoder_2": 3 * 1024 * 1024 * 1024,
74
+ }
75
+ self.async_transfers = True
76
+ else:
77
+ model_budgets = {}
78
+
79
+ device_id = os.getenv("LOCAL_RANK", 0)
80
+ torch.set_default_device(f"cuda:{device_id}")
81
+ pipeline.hf_device_map = torch.device(f"cuda:{device_id}")
82
+ pipe_or_dict_of_modules = pipeline.components
83
+ if config.compiler_transformer:
84
+ pipeline.transformer.to("cuda")
85
+ models = {
86
+ k: v
87
+ for k, v in pipe_or_dict_of_modules.items()
88
+ if isinstance(v, torch.nn.Module) and not (config.compiler_transformer and k == "transformer")
89
+ }
90
+ print_info = {k: type(v) for k, v in models.items()}
91
+ print(f"offload models: {print_info}")
92
+ if config.compiler_transformer:
93
+ pipeline.text_encoder.to("cpu")
94
+ pipeline.text_encoder_2.to("cpu")
95
+ torch.cuda.empty_cache()
96
+ pipeline.transformer.to("cuda")
97
+ pipeline.vae.to("cuda")
98
+
99
+ def move_text_encoder_to_gpu(pipe):
100
+ torch.cuda.empty_cache()
101
+ pipe.text_encoder.to("cuda")
102
+ pipe.text_encoder_2.to("cuda")
103
+
104
+ def move_text_encoder_to_cpu(pipe):
105
+ pipe.text_encoder.to("cpu")
106
+ pipe.text_encoder_2.to("cpu")
107
+ torch.cuda.empty_cache()
108
+
109
+ setattr(pipeline, "text_encoder_to_cpu", functools.partial(move_text_encoder_to_cpu, pipeline))
110
+ setattr(pipeline, "text_encoder_to_gpu", functools.partial(move_text_encoder_to_gpu, pipeline))
111
+
112
+ for k, module in pipe_or_dict_of_modules.items():
113
+ if isinstance(module, torch.nn.Module):
114
+ for submodule_name, submodule in module.named_modules():
115
+ if not hasattr(submodule, "_hf_hook"):
116
+ setattr(submodule, "_hf_hook", HfHook())
117
+ return self
118
+
119
+ sizeofbfloat16 = torch.bfloat16.itemsize
120
+ modelPinned = config.high_cpu_memory
121
+ # Pin models in RAM
122
+ # Calculate the VRAM requirements of the computational modules to determine whether parameters-level offload is necessary.
123
+ for model_name, curr_model in models.items():
124
+ curr_model.to("cpu").eval()
125
+ pinned_parameters_data = {}
126
+ current_model_size = 0
127
+ print(f"{model_name} move to pinned memory:{modelPinned}")
128
+ for p in curr_model.parameters():
129
+ if isinstance(p, AffineQuantizedTensor):
130
+ if not modelPinned and p.tensor_impl.scale.dtype == torch.float32:
131
+ p.tensor_impl.scale = p.tensor_impl.scale.to(torch.bfloat16)
132
+ current_model_size += torch.numel(p.tensor_impl.scale) * sizeofbfloat16
133
+ current_model_size += torch.numel(p.tensor_impl.float8_data) * sizeofbfloat16 / 2
134
+ if modelPinned:
135
+ p.tensor_impl.float8_data = p.tensor_impl.float8_data.pin_memory()
136
+ p.tensor_impl.scale = p.tensor_impl.scale.pin_memory()
137
+ pinned_parameters_data[p] = [p.tensor_impl.float8_data, p.tensor_impl.scale]
138
+ else:
139
+ p.data = p.data.to(torch.bfloat16) if p.data.dtype == torch.float32 else p.data.to(p.data.dtype)
140
+ current_model_size += torch.numel(p.data) * p.data.element_size()
141
+ if modelPinned:
142
+ p.data = p.data.pin_memory()
143
+ pinned_parameters_data[p] = p.data
144
+
145
+ for buffer in curr_model.buffers():
146
+ buffer.data = (
147
+ buffer.data.to(torch.bfloat16)
148
+ if buffer.data.dtype == torch.float32
149
+ else buffer.data.to(buffer.data.dtype)
150
+ )
151
+ current_model_size += torch.numel(buffer.data) * buffer.data.element_size()
152
+ if modelPinned:
153
+ buffer.data = buffer.data.pin_memory()
154
+
155
+ if model_name not in self.models:
156
+ self.models[model_name] = curr_model
157
+
158
+ curr_model_budget = model_budgets.get(model_name, 0)
159
+ if curr_model_budget > 0 and curr_model_budget > current_model_size:
160
+ model_budgets[model_name] = 0
161
+
162
+ if modelPinned:
163
+ pinned_buffers_data = {b: b.data for b in curr_model.buffers()}
164
+ pinned_parameters_data.update(pinned_buffers_data)
165
+ self.pinned_modules_data[model_name] = pinned_parameters_data
166
+ gc.collect()
167
+ torch.cuda.empty_cache()
168
+
169
+ # if config.compiler_transformer:
170
+ # module = pipeline.transformer
171
+ # print("wrap transformer forward")
172
+ # # gpu model wrap
173
+ # for submodule_name, submodule in module.named_modules():
174
+ # if not hasattr(submodule, "_hf_hook"):
175
+ # setattr(submodule, "_hf_hook", HfHook())
176
+ #
177
+ # forward_method = getattr(module, "forward")
178
+ #
179
+ # def wrap_unload_all(*args, **kwargs):
180
+ # self.unload_all("transformer")
181
+ # return forward_method(*args, **kwargs)
182
+ #
183
+ # setattr(module, "forward", functools.update_wrapper(wrap_unload_all, forward_method))
184
+
185
+ # wrap forward methods
186
+ for model_name, curr_model in models.items():
187
+ current_budget = model_budgets.get(model_name, 0)
188
+ current_size = 0
189
+ self.loaded_blocks[model_name] = None
190
+ cur_blocks_prefix, prev_blocks_name, cur_blocks_name, cur_blocks_seq = None, None, None, -1
191
+
192
+ for submodule_name, submodule in curr_model.named_modules():
193
+ # attach a fake accelerate hook so that the _execution_device property always returns "cuda"
194
+ if not hasattr(submodule, "_hf_hook"):
195
+ setattr(submodule, "_hf_hook", HfHook())
196
+
197
+ if not submodule_name:
198
+ continue
199
+
200
+ # use parameters-level offload
201
+ if current_budget > 0:
202
+ if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
203
+ if cur_blocks_prefix == None:
204
+ cur_blocks_prefix = submodule_name + "."
205
+ else:
206
+ if not submodule_name.startswith(cur_blocks_prefix):
207
+ cur_blocks_prefix = submodule_name + "."
208
+ cur_blocks_name, cur_blocks_seq = None, -1
209
+ else:
210
+ if cur_blocks_prefix is not None:
211
+ if submodule_name.startswith(cur_blocks_prefix):
212
+ num = int(submodule_name[len(cur_blocks_prefix) :].split(".")[0])
213
+ if num != cur_blocks_seq and (cur_blocks_name == None or current_size > current_budget):
214
+ prev_blocks_name = cur_blocks_name
215
+ cur_blocks_name = cur_blocks_prefix + str(num)
216
+ cur_blocks_seq = num
217
+ else:
218
+ cur_blocks_prefix = None
219
+ prev_blocks_name = None
220
+ cur_blocks_name = None
221
+ cur_blocks_seq = -1
222
+
223
+ if hasattr(submodule, "forward"):
224
+ submodule_forward = getattr(submodule, "forward")
225
+ if not callable(submodule_forward):
226
+ print("***")
227
+ continue
228
+ if len(submodule_name.split(".")) == 1:
229
+ self.hook_me(submodule, curr_model, model_name, submodule_name, submodule_forward)
230
+ else:
231
+ self.hook_me_light(
232
+ submodule, model_name, cur_blocks_name, submodule_forward, context=submodule_name
233
+ )
234
+ current_size = self.add_module_to_blocks(model_name, cur_blocks_name, submodule, prev_blocks_name)
235
+
236
+ gc.collect()
237
+ torch.cuda.empty_cache()
238
+ return self
239
+
240
+ def add_module_to_blocks(self, model_name, blocks_name, submodule, prev_block_name):
241
+
242
+ entry_name = model_name if blocks_name is None else model_name + "/" + blocks_name
243
+ if entry_name in self.blocks_of_modules:
244
+ blocks_params = self.blocks_of_modules[entry_name]
245
+ blocks_params_size = self.blocks_of_modules_sizes[entry_name]
246
+ else:
247
+ blocks_params = []
248
+ self.blocks_of_modules[entry_name] = blocks_params
249
+ blocks_params_size = 0
250
+ if blocks_name != None:
251
+ prev_entry_name = None if prev_block_name == None else model_name + "/" + prev_block_name
252
+ self.prev_blocks_names[entry_name] = prev_entry_name
253
+ if not prev_block_name == None:
254
+ self.next_blocks_names[prev_entry_name] = entry_name
255
+
256
+ for p in submodule.parameters(recurse=False):
257
+ blocks_params.append(p)
258
+ if isinstance(p, AffineQuantizedTensor):
259
+ blocks_params_size += p.tensor_impl.float8_data.nbytes
260
+ blocks_params_size += p.tensor_impl.scale.nbytes
261
+ else:
262
+ blocks_params_size += p.data.nbytes
263
+
264
+ for p in submodule.buffers(recurse=False):
265
+ blocks_params.append(p)
266
+ blocks_params_size += p.data.nbytes
267
+
268
+ self.blocks_of_modules_sizes[entry_name] = blocks_params_size
269
+
270
+ return blocks_params_size
271
+
272
+ def can_model_be_cotenant(self, model_name):
273
+ cotenants_map = {
274
+ "text_encoder": ["vae", "text_encoder_2"],
275
+ "text_encoder_2": ["vae", "text_encoder"],
276
+ }
277
+ potential_cotenants = cotenants_map.get(model_name, None)
278
+ if potential_cotenants is None:
279
+ return False
280
+ for existing_cotenant in self.active_models_ids:
281
+ if existing_cotenant not in potential_cotenants:
282
+ return False
283
+ return True
284
+
285
+ @torch.compiler.disable()
286
+ def gpu_load_blocks(self, model_name, blocks_name, async_load=False):
287
+ if blocks_name != None:
288
+ self.loaded_blocks[model_name] = blocks_name
289
+
290
+ def cpu_to_gpu(stream_to_use, blocks_params, record_for_stream=None):
291
+ with torch.cuda.stream(stream_to_use):
292
+ for p in blocks_params:
293
+ if isinstance(p, AffineQuantizedTensor):
294
+ p.tensor_impl.float8_data = p.tensor_impl.float8_data.cuda(
295
+ non_blocking=True, device=self.device_id
296
+ )
297
+ p.tensor_impl.scale = p.tensor_impl.scale.cuda(non_blocking=True, device=self.device_id)
298
+ else:
299
+ p.data = p.data.cuda(non_blocking=True, device=self.device_id)
300
+
301
+ if record_for_stream != None:
302
+ if isinstance(p, AffineQuantizedTensor):
303
+ p.tensor_impl.float8_data.record_stream(record_for_stream)
304
+ p.tensor_impl.scale.record_stream(record_for_stream)
305
+ else:
306
+ p.data.record_stream(record_for_stream)
307
+
308
+ entry_name = model_name if blocks_name is None else model_name + "/" + blocks_name
309
+ if self.verboseLevel >= 2:
310
+ model = self.models[model_name]
311
+ model_class = model._get_name()
312
+ print(f"Loading model {entry_name} ({model_class}) in GPU")
313
+
314
+ if self.async_transfers and blocks_name != None:
315
+ first = self.prev_blocks_names[entry_name] == None
316
+ next_blocks_entry = self.next_blocks_names[entry_name] if entry_name in self.next_blocks_names else None
317
+ if first:
318
+ cpu_to_gpu(torch.cuda.current_stream(), self.blocks_of_modules[entry_name])
319
+ torch.cuda.synchronize()
320
+
321
+ if next_blocks_entry != None:
322
+ cpu_to_gpu(self.transfer_stream, self.blocks_of_modules[next_blocks_entry])
323
+
324
+ else:
325
+ cpu_to_gpu(self.default_stream, self.blocks_of_modules[entry_name])
326
+ torch.cuda.synchronize()
327
+
328
+ @torch.compiler.disable()
329
+ def gpu_unload_blocks(self, model_name, blocks_name):
330
+ if blocks_name != None:
331
+ self.loaded_blocks[model_name] = None
332
+
333
+ blocks_name = model_name if blocks_name is None else model_name + "/" + blocks_name
334
+
335
+ if self.verboseLevel >= 2:
336
+ model = self.models[model_name]
337
+ model_class = model._get_name()
338
+ print(f"Unloading model {blocks_name} ({model_class}) from GPU")
339
+
340
+ blocks_params = self.blocks_of_modules[blocks_name]
341
+
342
+ if model_name in self.pinned_modules_data:
343
+ pinned_parameters_data = self.pinned_modules_data[model_name]
344
+ for p in blocks_params:
345
+ if isinstance(p, AffineQuantizedTensor):
346
+ data = pinned_parameters_data[p]
347
+ p.tensor_impl.float8_data = data[0]
348
+ p.tensor_impl.scale = data[1]
349
+ else:
350
+ p.data = pinned_parameters_data[p]
351
+ else:
352
+ for p in blocks_params:
353
+ if isinstance(p, AffineQuantizedTensor):
354
+ p.tensor_impl.float8_data = p.tensor_impl.float8_data.cpu()
355
+ p.tensor_impl.scale = p.tensor_impl.scale.cpu()
356
+ else:
357
+ p.data = p.data.cpu()
358
+
359
+ @torch.compiler.disable()
360
+ def gpu_load(self, model_name):
361
+ model = self.models[model_name]
362
+ self.active_models.append(model)
363
+ self.active_models_ids.append(model_name)
364
+
365
+ self.gpu_load_blocks(model_name, None)
366
+
367
+ # torch.cuda.current_stream().synchronize()
368
+
369
+ @torch.compiler.disable()
370
+ def unload_all(self, model_name: str):
371
+ if len(self.active_models_ids) == 0 and self.last_run_model == model_name:
372
+ self.last_run_model = model_name
373
+ return
374
+ for model_name in self.active_models_ids:
375
+ self.gpu_unload_blocks(model_name, None)
376
+ loaded_block = self.loaded_blocks[model_name]
377
+ if loaded_block != None:
378
+ self.gpu_unload_blocks(model_name, loaded_block)
379
+ self.loaded_blocks[model_name] = None
380
+
381
+ self.active_models = []
382
+ self.active_models_ids = []
383
+ self.active_subcaches = []
384
+ torch.cuda.empty_cache()
385
+ gc.collect()
386
+ self.last_reserved_mem_check = time.time()
387
+ self.last_run_model = model_name
388
+
389
+ def move_args_to_gpu(self, *args, **kwargs):
390
+ new_args = []
391
+ new_kwargs = {}
392
+ for arg in args:
393
+ if torch.is_tensor(arg):
394
+ if arg.dtype == torch.float32:
395
+ arg = arg.to(torch.bfloat16).cuda(non_blocking=True, device=self.device_id)
396
+ else:
397
+ arg = arg.cuda(non_blocking=True, device=self.device_id)
398
+ new_args.append(arg)
399
+
400
+ for k in kwargs:
401
+ arg = kwargs[k]
402
+ if torch.is_tensor(arg):
403
+ if arg.dtype == torch.float32:
404
+ arg = arg.to(torch.bfloat16).cuda(non_blocking=True, device=self.device_id)
405
+ else:
406
+ arg = arg.cuda(non_blocking=True, device=self.device_id)
407
+ new_kwargs[k] = arg
408
+
409
+ return new_args, new_kwargs
410
+
411
+ def ready_to_check_mem(self):
412
+ if self.compile:
413
+ return False
414
+ cur_clock = time.time()
415
+ # don't check at every call whether the CUDA cache can be emptied, as querying the reserved memory value is a time-consuming operation
416
+ if (cur_clock - self.last_reserved_mem_check) < 0.200:
417
+ return False
418
+ self.last_reserved_mem_check = cur_clock
419
+ return True
420
+
421
+ def empty_cache_if_needed(self):
422
+ mem_reserved = torch.cuda.memory_reserved()
423
+ mem_threshold = 0.9 * self.device_mem_capacity
424
+ if mem_reserved >= mem_threshold:
425
+ mem_allocated = torch.cuda.memory_allocated()
426
+ if mem_allocated <= 0.70 * mem_reserved:
427
+ torch.cuda.empty_cache()
428
+ tm = time.time()
429
+ if self.verboseLevel >= 2:
430
+ print(f"Empty Cuda cache at {tm}")
431
+
432
+ def any_param_or_buffer(self, target_module: torch.nn.Module):
433
+
434
+ for _ in target_module.parameters(recurse=False):
435
+ return True
436
+
437
+ for _ in target_module.buffers(recurse=False):
438
+ return True
439
+
440
+ return False
441
+
442
+ def hook_me_light(self, target_module, model_name, blocks_name, previous_method, context):
443
+
444
+ anyParam = self.any_param_or_buffer(target_module)
445
+
446
+ def check_empty_cuda_cache(module, *args, **kwargs):
447
+ if self.ready_to_check_mem():
448
+ self.empty_cache_if_needed()
449
+ return previous_method(*args, **kwargs)
450
+
451
+ def load_module_blocks(module, *args, **kwargs):
452
+ if blocks_name == None:
453
+ if self.ready_to_check_mem():
454
+ self.empty_cache_if_needed()
455
+ else:
456
+ loaded_block = self.loaded_blocks[model_name]
457
+ if loaded_block == None or loaded_block != blocks_name:
458
+ if loaded_block != None:
459
+ self.gpu_unload_blocks(model_name, loaded_block)
460
+ if self.ready_to_check_mem():
461
+ self.empty_cache_if_needed()
462
+ self.loaded_blocks[model_name] = blocks_name
463
+ self.gpu_load_blocks(model_name, blocks_name)
464
+ return previous_method(*args, **kwargs)
465
+
466
+ if hasattr(target_module, "_mm_id"):
467
+ orig_model_name = getattr(target_module, "_mm_id")
468
+ if self.verboseLevel >= 2:
469
+ print(
470
+ f"Model '{model_name}' shares module '{target_module._get_name()}' with module '{orig_model_name}' "
471
+ )
472
+ assert not anyParam
473
+ return
474
+ setattr(target_module, "_mm_id", model_name)
475
+
476
+ if blocks_name != None and anyParam:
477
+ setattr(
478
+ target_module,
479
+ "forward",
480
+ functools.update_wrapper(functools.partial(load_module_blocks, target_module), previous_method),
481
+ )
482
+ # print(f"new cache:{blocks_name}")
483
+ else:
484
+ setattr(
485
+ target_module,
486
+ "forward",
487
+ functools.update_wrapper(functools.partial(check_empty_cuda_cache, target_module), previous_method),
488
+ )
489
+
490
+ def hook_me(self, target_module, model, model_name, module_id, previous_method):
491
+ def check_change_module(module, *args, **kwargs):
492
+ performEmptyCacheTest = False
493
+ if not model_name in self.active_models_ids:
494
+ new_model_name = getattr(module, "_mm_id")
495
+ if not self.can_model_be_cotenant(new_model_name):
496
+ self.unload_all(model_name)
497
+ performEmptyCacheTest = False
498
+ self.gpu_load(new_model_name)
499
+ args, kwargs = self.move_args_to_gpu(*args, **kwargs)
500
+ if performEmptyCacheTest:
501
+ self.empty_cache_if_needed()
502
+ return previous_method(*args, **kwargs)
503
+
504
+ if hasattr(target_module, "_mm_id"):
505
+ return
506
+ setattr(target_module, "_mm_id", model_name)
507
+
508
+ setattr(
509
+ target_module,
510
+ "forward",
511
+ functools.update_wrapper(functools.partial(check_change_module, target_module), previous_method),
512
+ )
513
+
514
+ if not self.verboseLevel >= 1:
515
+ return
516
+
517
+ if module_id == None or module_id == "":
518
+ model_class = model._get_name()
519
+ print(f"Hooked in model '{model_name}' ({model_class})")