from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict
from PIL.Image import Image

from ..logging import get_logger
from ..parallel import ParallelBackendEnum
from ..processors import ProcessorMixin
from ..typing import ArtifactType, SchedulerType, TokenizerType
from ..utils import resolve_component_cls


logger = get_logger()

# TODO(aryan): we most likely don't need this. take a look after refactoring more
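# These keys hold shared per-batch metadata (resolution, frame rate, kwargs dicts,
# normalization stats) rather than per-sample tensors, so the collation methods below
# keep the first sample's value for them instead of concatenating across the batch.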
# fmt: off
IGNORE_KEYS_FOR_COLLATION = {"height", "width", "num_frames", "frame_rate", "rope_interpolation_scale", "return_dict", "attention_kwargs", "cross_attention_kwargs", "joint_attention_kwargs", "latents_mean", "latents_std"}
# fmt: on


class ModelSpecification:
    r"""
    The ModelSpecification class is an interface for diffusion training recipes. It provides a loose
    structure for organizing training code: trainer implementations use this interface to load models,
    prepare conditions, prepare latents, run the forward pass, and so on.
    """

    def __init__(
        self,
        pretrained_model_name_or_path: Optional[str] = None,
        tokenizer_id: Optional[str] = None,
        tokenizer_2_id: Optional[str] = None,
        tokenizer_3_id: Optional[str] = None,
        text_encoder_id: Optional[str] = None,
        text_encoder_2_id: Optional[str] = None,
        text_encoder_3_id: Optional[str] = None,
        transformer_id: Optional[str] = None,
        vae_id: Optional[str] = None,
        text_encoder_dtype: torch.dtype = torch.bfloat16,
        text_encoder_2_dtype: torch.dtype = torch.bfloat16,
        text_encoder_3_dtype: torch.dtype = torch.bfloat16,
        transformer_dtype: torch.dtype = torch.bfloat16,
        vae_dtype: torch.dtype = torch.bfloat16,
        revision: Optional[str] = None,
        cache_dir: Optional[str] = None,
        condition_model_processors: Optional[List[ProcessorMixin]] = None,
        latent_model_processors: Optional[List[ProcessorMixin]] = None,
    ) -> None:
        self.pretrained_model_name_or_path = pretrained_model_name_or_path
        self.tokenizer_id = tokenizer_id
        self.tokenizer_2_id = tokenizer_2_id
        self.tokenizer_3_id = tokenizer_3_id
        self.text_encoder_id = text_encoder_id
        self.text_encoder_2_id = text_encoder_2_id
        self.text_encoder_3_id = text_encoder_3_id
        self.transformer_id = transformer_id
        self.vae_id = vae_id
        self.text_encoder_dtype = text_encoder_dtype
        self.text_encoder_2_dtype = text_encoder_2_dtype
        self.text_encoder_3_dtype = text_encoder_3_dtype
        self.transformer_dtype = transformer_dtype
        self.vae_dtype = vae_dtype
        self.revision = revision
        self.cache_dir = cache_dir
        self.condition_model_processors = condition_model_processors or []
        self.latent_model_processors = latent_model_processors or []

        self.transformer_config: Optional[Dict[str, Any]] = None
        self.vae_config: Optional[Dict[str, Any]] = None

        self._load_configs()

    # TODO(aryan): revisit how to do this better without user having to worry about it
    def _resolution_dim_keys(self) -> Dict[str, Tuple[int, ...]]:
        raise NotImplementedError(
            f"ModelSpecification::_resolution_dim_keys is not implemented for {self.__class__.__name__}"
        )

    def load_condition_models(self) -> Dict[str, torch.nn.Module]:
        raise NotImplementedError(
            f"ModelSpecification::load_condition_models is not implemented for {self.__class__.__name__}"
        )

    def load_latent_models(self) -> Dict[str, torch.nn.Module]:
        raise NotImplementedError(
            f"ModelSpecification::load_latent_models is not implemented for {self.__class__.__name__}"
        )

    def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
        raise NotImplementedError(
            f"ModelSpecification::load_diffusion_models is not implemented for {self.__class__.__name__}"
        )

    def load_pipeline(
        self,
        tokenizer: Optional[TokenizerType] = None,
        tokenizer_2: Optional[TokenizerType] = None,
        tokenizer_3: Optional[TokenizerType] = None,
        text_encoder: Optional[torch.nn.Module] = None,
        text_encoder_2: Optional[torch.nn.Module] = None,
        text_encoder_3: Optional[torch.nn.Module] = None,
        transformer: Optional[torch.nn.Module] = None,
        vae: Optional[torch.nn.Module] = None,
        scheduler: Optional[SchedulerType] = None,
        enable_slicing: bool = False,
        enable_tiling: bool = False,
        enable_model_cpu_offload: bool = False,
        training: bool = False,
        **kwargs,
    ) -> DiffusionPipeline:
        raise NotImplementedError(
            f"ModelSpecification::load_pipeline is not implemented for {self.__class__.__name__}"
        )

    def prepare_conditions(self, **kwargs) -> Dict[str, Any]:
        for processor in self.condition_model_processors:
            result = processor(**kwargs)
            result_keys = set(result.keys())
            repeat_keys = result_keys.intersection(kwargs.keys())
            if repeat_keys:
                logger.warning(
                    f"Processor {processor.__class__.__name__} returned keys that already exist in "
                    f"conditions: {repeat_keys}. Overwriting the existing values, but this may not "
                    f"be intended. Please rename the keys in the processor to avoid conflicts."
                )
            kwargs.update(result)
        return kwargs
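
    # For reference: prepare_conditions treats each processor as a callable that takes the
    # current condition kwargs and returns a dict of new or updated entries, which is merged
    # back into the kwargs. A hypothetical sketch (shown as a plain callable for brevity;
    # real processors subclass ProcessorMixin):
    #
    #     def caption_dropout(prompt: str = "", **kwargs) -> Dict[str, str]:
    #         return {"prompt": ""}  # drop the caption, e.g. for classifier-free guidance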

    def prepare_latents(self, **kwargs) -> Dict[str, Any]:
        for processor in self.latent_model_processors:
            result = processor(**kwargs)
            result_keys = set(result.keys())
            repeat_keys = result_keys.intersection(kwargs.keys())
            if repeat_keys:
                logger.warning(
                    f"Processor {processor.__class__.__name__} returned keys that already exist in "
                    f"latents: {repeat_keys}. Overwriting the existing values, but this may not "
                    f"be intended. Please rename the keys in the processor to avoid conflicts."
                )
            kwargs.update(result)
        return kwargs

    def collate_conditions(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        keys = list(data[0].keys())
        collated_data = {}
        for key in keys:
            if key in IGNORE_KEYS_FOR_COLLATION:
                collated_data[key] = data[0][key]
                continue
            collated_d = [d[key] for d in data]
            if isinstance(collated_d[0], torch.Tensor):
                collated_d = torch.cat(collated_d)
            collated_data[key] = collated_d
        return collated_data

    def collate_latents(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        keys = list(data[0].keys())
        collated_data = {}
        for key in keys:
            if key in IGNORE_KEYS_FOR_COLLATION:
                collated_data[key] = data[0][key]
                continue
            collated_d = [d[key] for d in data]
            # TODO(aryan): Support multi-resolution collation
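            # torch.cat joins along dim 0 (the batch dimension), which assumes every
            # sample in the batch shares the same latent shape.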
            if isinstance(collated_d[0], torch.Tensor):
                collated_d = torch.cat(collated_d)
            collated_data[key] = collated_d
        return collated_data

    def forward(
        self, transformer: torch.nn.Module, generator: Optional[torch.Generator] = None, **kwargs
    ) -> Dict[str, torch.Tensor]:
        raise NotImplementedError(f"ModelSpecification::forward is not implemented for {self.__class__.__name__}")

    def validation(
        self,
        pipeline: DiffusionPipeline,
        prompt: Optional[str] = None,
        image: Optional[Image] = None,
        video: Optional[List[Image]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_frames: Optional[int] = None,
        frame_rate: Optional[int] = None,
        generator: Optional[torch.Generator] = None,
    ) -> List[ArtifactType]:
        raise NotImplementedError(f"ModelSpecification::validation is not implemented for {self.__class__.__name__}")

    def _save_lora_weights(
        self,
        directory: str,
        transformer: torch.nn.Module,
        transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
        scheduler: Optional[SchedulerType] = None,
    ) -> None:
        r"""
        Save the LoRA state dicts of the model to the given directory.

        This API is not backwards compatible and will be changed in the near future.
        """
        raise NotImplementedError(
            f"ModelSpecification::_save_lora_weights is not implemented for {self.__class__.__name__}"
        )

    def _save_model(
        self,
        directory: str,
        transformer: torch.nn.Module,
        transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
        scheduler: Optional[SchedulerType] = None,
    ) -> None:
        r"""
        Save the model state dicts to the given directory.

        This API is not backwards compatible and will be changed in the near future.
        """
        raise NotImplementedError(
            f"ModelSpecification::_save_model is not implemented for {self.__class__.__name__}"
        )

    def apply_tensor_parallel(
        self,
        backend: ParallelBackendEnum,
        device_mesh: torch.distributed.DeviceMesh,
        text_encoder: torch.nn.Module,
        text_encoder_2: torch.nn.Module,
        text_encoder_3: torch.nn.Module,
        transformer: torch.nn.Module,
        vae: torch.nn.Module,
    ) -> None:
        raise NotImplementedError(
            f"ModelSpecification::apply_tensor_parallel is not implemented for {self.__class__.__name__}"
        )

    def _load_configs(self) -> None:
        self._load_transformer_config()
        self._load_vae_config()
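
    # Configs are resolved from one of two layouts: a standalone component repo whose
    # config.json names the component class under "_class_name", or a full pipeline repo
    # whose model_index.json maps the "transformer"/"vae" entries to their classes.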
    def _load_transformer_config(self) -> None:
        if self.transformer_id is not None:
            transformer_cls = resolve_component_cls(
                self.transformer_id,
                component_name="_class_name",
                filename="config.json",
                revision=self.revision,
                cache_dir=self.cache_dir,
            )
            self.transformer_config = transformer_cls.load_config(
                self.transformer_id, revision=self.revision, cache_dir=self.cache_dir
            )
        else:
            transformer_cls = resolve_component_cls(
                self.pretrained_model_name_or_path,
                component_name="transformer",
                filename="model_index.json",
                revision=self.revision,
                cache_dir=self.cache_dir,
            )
            self.transformer_config = transformer_cls.load_config(
                self.pretrained_model_name_or_path,
                subfolder="transformer",
                revision=self.revision,
                cache_dir=self.cache_dir,
            )
        self.transformer_config = FrozenDict(**self.transformer_config)

    def _load_vae_config(self) -> None:
        if self.vae_id is not None:
            vae_cls = resolve_component_cls(
                self.vae_id,
                component_name="_class_name",
                filename="config.json",
                revision=self.revision,
                cache_dir=self.cache_dir,
            )
            self.vae_config = vae_cls.load_config(self.vae_id, revision=self.revision, cache_dir=self.cache_dir)
        else:
            vae_cls = resolve_component_cls(
                self.pretrained_model_name_or_path,
                component_name="vae",
                filename="model_index.json",
                revision=self.revision,
                cache_dir=self.cache_dir,
            )
            self.vae_config = vae_cls.load_config(
                self.pretrained_model_name_or_path, subfolder="vae", revision=self.revision, cache_dir=self.cache_dir
            )
        self.vae_config = FrozenDict(**self.vae_config)