diff --git "a/pipeline.py" "b/pipeline.py"
new file mode 100644--- /dev/null
+++ "b/pipeline.py"
@@ -0,0 +1,1903 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+import torchsde
+
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)
+
+from diffusers.utils.import_utils import is_invisible_watermark_available
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import (
+    FromSingleFileMixin,
+    IPAdapterMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
+from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
+from diffusers.models.attention_processor import (
+    AttnProcessor2_0,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+
+
+if is_invisible_watermark_available():
+    from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> # pip install accelerate transformers safetensors diffusers
+
+        >>> import torch
+        >>> import numpy as np
+        >>> from PIL import Image
+
+        >>> from transformers import DPTFeatureExtractor, DPTForDepthEstimation
+        >>> from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL
+        >>> from diffusers.utils import load_image
+
+
+        >>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
+        >>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
+        >>> controlnet = ControlNetModel.from_pretrained(
+        ...     "diffusers/controlnet-depth-sdxl-1.0-small",
+        ...     variant="fp16",
+        ...     use_safetensors=True,
+        ...     torch_dtype=torch.float16,
+        ... ).to("cuda")
+        >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to("cuda")
+        >>> pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
+        ...     "stabilityai/stable-diffusion-xl-base-1.0",
+        ...     controlnet=controlnet,
+        ...     vae=vae,
+        ...     variant="fp16",
+        ...     use_safetensors=True,
+        ...     torch_dtype=torch.float16,
+        ... ).to("cuda")
+        >>> pipe.enable_model_cpu_offload()
+
+
+        >>> def get_depth_map(image):
+        ...     image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
+        ...     with torch.no_grad(), torch.autocast("cuda"):
+        ...         depth_map = depth_estimator(image).predicted_depth
+
+        ...     depth_map = torch.nn.functional.interpolate(
+        ...         depth_map.unsqueeze(1),
+        ...         size=(1024, 1024),
+        ...         mode="bicubic",
+        ...         align_corners=False,
+        ...     )
+        ...     depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
+        ...     depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
+        ...     depth_map = (depth_map - depth_min) / (depth_max - depth_min)
+        ...     image = torch.cat([depth_map] * 3, dim=1)
+        ...     image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
+        ...     image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
+        ...     return image
+
+
+        >>> prompt = "A robot, 4k photo"
+        >>> image = load_image(
+        ...     "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+        ...     "/kandinsky/cat.png"
+        ... ).resize((1024, 1024))
+        >>> controlnet_conditioning_scale = 0.5  # recommended for good generalization
+        >>> depth_image = get_depth_map(image)
+
+        >>> images = pipe(
+        ...     prompt,
+        ...     image=image,
+        ...     control_image=depth_image,
+        ...     strength=0.99,
+        ...     num_inference_steps=50,
+        ...     controlnet_conditioning_scale=controlnet_conditioning_scale,
+        ... ).images
+        >>> images[0].save(f"robot_cat.png")
+        ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+class BatchedBrownianTree:
+    """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
+
+    def __init__(self, x, t0, t1, seed=None, **kwargs):
+        self.cpu_tree = True
+        if "cpu" in kwargs:
+            self.cpu_tree = kwargs.pop("cpu")
+        t0, t1, self.sign = self.sort(t0, t1)
+        w0 = kwargs.get('w0', torch.zeros_like(x))
+        if seed is None:
+            seed = torch.randint(0, 2 ** 63 - 1, []).item()
+        self.batched = True
+        try:
+            assert len(seed) == x.shape[0]
+            w0 = w0[0]
+        except TypeError:
+            seed = [seed]
+            self.batched = False
+        if self.cpu_tree:
+            self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
+        else:
+            self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
+
+    @staticmethod
+    def sort(a, b):
+        return (a, b, 1) if a < b else (b, a, -1)
+
+    def __call__(self, t0, t1):
+        t0, t1, sign = self.sort(t0, t1)
+        if self.cpu_tree:
+            w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
+        else:
+            w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
+
+        return w if self.batched else w[0]
+
+
+class BrownianTreeNoiseSampler:
+    """A noise sampler backed by a torchsde.BrownianTree.
+
+    Args:
+        x (Tensor): The tensor whose shape, device and dtype to use to generate
+            random samples.
+        sigma_min (float): The low end of the valid interval.
+        sigma_max (float): The high end of the valid interval.
+        seed (int or List[int]): The random seed. If a list of seeds is
+            supplied instead of a single integer, then the noise sampler will
+            use one BrownianTree per batch item, each with its own seed.
+        transform (callable): A function that maps sigma to the sampler's
+            internal timestep.
+    """
+
+    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False):
+        self.transform = transform
+        t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
+        self.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu)
+
+    def __call__(self, sigma, sigma_next):
+        t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
+        return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
+
+
+class OmniZeroPipeline(
+    DiffusionPipeline,
+    StableDiffusionMixin,
+    TextualInversionLoaderMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    FromSingleFileMixin,
+    IPAdapterMixin,
+):
+    r"""
+    Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    The pipeline also inherits the following loading methods:
+        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            Frozen text-encoder. Stable Diffusion uses the text portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        text_encoder_2 ([` CLIPTextModelWithProjection`]):
+            Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+            specifically the
+            [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+            variant.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        tokenizer_2 (`CLIPTokenizer`):
+            Second Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+        controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
+            Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
+            as a list, the outputs from each ControlNet are added together to create one combined additional
+            conditioning.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
+            Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the
+            config of `stabilityai/stable-diffusion-xl-refiner-1-0`.
+        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
+            Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+            `stabilityai/stable-diffusion-xl-base-1-0`.
+        add_watermarker (`bool`, *optional*):
+            Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
+            watermark output images. If not defined, it will default to True if the package is installed, otherwise no
+            watermarker will be used.
+        feature_extractor ([`~transformers.CLIPImageProcessor`]):
+            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+    """
+
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+    _optional_components = [
+        "tokenizer",
+        "tokenizer_2",
+        "text_encoder",
+        "text_encoder_2",
+        "feature_extractor",
+        "image_encoder",
+    ]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        text_encoder_2: CLIPTextModelWithProjection,
+        tokenizer: CLIPTokenizer,
+        tokenizer_2: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
+        scheduler: KarrasDiffusionSchedulers,
+        requires_aesthetics_score: bool = False,
+        force_zeros_for_empty_prompt: bool = True,
+        add_watermarker: Optional[bool] = None,
+        feature_extractor: CLIPImageProcessor = None,
+        image_encoder: CLIPVisionModelWithProjection = None,
+    ):
+        super().__init__()
+
+        if isinstance(controlnet, (list, tuple)):
+            controlnet = MultiControlNetModel(controlnet)
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            text_encoder_2=text_encoder_2,
+            tokenizer=tokenizer,
+            tokenizer_2=tokenizer_2,
+            unet=unet,
+            controlnet=controlnet,
+            scheduler=scheduler,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
+        )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+        self.control_image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+        )
+        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+        if add_watermarker:
+            self.watermark = StableDiffusionXLWatermarker()
+        else:
+            self.watermark = None
+
+        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+        self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
+
+        self.ays_noise_sigmas = {"SD1": [14.6146412293, 6.4745760956,  3.8636745985,  2.6946151520, 1.8841921177,  1.3943805092,  0.9642583904,  0.6523686016, 0.3977456272,  0.1515232662,  0.0291671582],
+                "SDXL":[14.6146412293, 6.3184485287,  3.7681790315,  2.1811480769, 1.3405244945,  0.8620721141,  0.5550693289,  0.3798540708, 0.2332364134,  0.1114188177,  0.0291671582],
+                "SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]}
+        
+    @staticmethod
+    def _loglinear_interp(t_steps, num_steps):
+        xs = np.linspace(0, 1, len(t_steps))
+        ys = np.log(t_steps[::-1])
+
+        new_xs = np.linspace(0, 1, num_steps)
+        new_ys = np.interp(new_xs, xs, ys)
+
+        return np.exp(new_ys)[::-1].copy()
+    
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+    def encode_prompt(
+        self,
+        prompt: str,
+        prompt_2: Optional[str] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+        negative_prompt: Optional[str] = None,
+        negative_prompt_2: Optional[str] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
+        clip_skip: Optional[int] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                used in both text-encoders
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
+            lora_scale (`float`, *optional*):
+                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+        """
+        device = device or self._execution_device
+
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if self.text_encoder is not None:
+                if not USE_PEFT_BACKEND:
+                    adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+                else:
+                    scale_lora_layers(self.text_encoder, lora_scale)
+
+            if self.text_encoder_2 is not None:
+                if not USE_PEFT_BACKEND:
+                    adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+                else:
+                    scale_lora_layers(self.text_encoder_2, lora_scale)
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt is not None:
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        # Define tokenizers and text encoders
+        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+        text_encoders = (
+            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+        )
+
+        if prompt_embeds is None:
+            prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+            # textual inversion: process multi-vector tokens if necessary
+            prompt_embeds_list = []
+            prompts = [prompt, prompt_2]
+            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+                if isinstance(self, TextualInversionLoaderMixin):
+                    prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+                text_inputs = tokenizer(
+                    prompt,
+                    padding="max_length",
+                    max_length=tokenizer.model_max_length,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+
+                text_input_ids = text_inputs.input_ids
+                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                    text_input_ids, untruncated_ids
+                ):
+                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+                    logger.warning(
+                        "The following part of your input was truncated because CLIP can only handle sequences up to"
+                        f" {tokenizer.model_max_length} tokens: {removed_text}"
+                    )
+
+                prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+                # We are only ALWAYS interested in the pooled output of the final text encoder
+                pooled_prompt_embeds = prompt_embeds[0]
+                if clip_skip is None:
+                    prompt_embeds = prompt_embeds.hidden_states[-2]
+                else:
+                    # "2" because SDXL always indexes from the penultimate layer.
+                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+                prompt_embeds_list.append(prompt_embeds)
+
+            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+        # get unconditional embeddings for classifier free guidance
+        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+        elif do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt = negative_prompt or ""
+            negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+            # normalize str to list
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_prompt_2 = (
+                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+            )
+
+            uncond_tokens: List[str]
+            if prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = [negative_prompt, negative_prompt_2]
+
+            negative_prompt_embeds_list = []
+            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+                if isinstance(self, TextualInversionLoaderMixin):
+                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+                max_length = prompt_embeds.shape[1]
+                uncond_input = tokenizer(
+                    negative_prompt,
+                    padding="max_length",
+                    max_length=max_length,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+
+                negative_prompt_embeds = text_encoder(
+                    uncond_input.input_ids.to(device),
+                    output_hidden_states=True,
+                )
+                # We are only ALWAYS interested in the pooled output of the final text encoder
+                negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+                negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+        if self.text_encoder_2 is not None:
+            prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+        else:
+            prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+
+            if self.text_encoder_2 is not None:
+                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+            else:
+                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+            bs_embed * num_images_per_prompt, -1
+        )
+        if do_classifier_free_guidance:
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+                bs_embed * num_images_per_prompt, -1
+            )
+
+        if self.text_encoder is not None:
+            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
+
+        if self.text_encoder_2 is not None:
+            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+    # Copied from ..stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None, unconditional_noising_factor=1.0):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        needs_encoding = not isinstance(image, torch.Tensor)
+        if needs_encoding:
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+
+        avg_image = torch.mean(image, dim=0, keepdim=True).to(dtype=torch.float32)
+        seed = int(torch.sum(avg_image).item()) % 1000000007
+        torch.manual_seed(seed)
+        additional_noise_for_uncond = torch.rand_like(image) * unconditional_noising_factor
+
+        if output_hidden_states:
+            if needs_encoding:
+                image_encoded = self.image_encoder(image, output_hidden_states=True)
+                image_enc_hidden_states = image_encoded.hidden_states[-2]
+            else:
+                image_enc_hidden_states = image.unsqueeze(0).unsqueeze(0)
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+          
+            if needs_encoding:
+                uncond_image_encoded = self.image_encoder(additional_noise_for_uncond, output_hidden_states=True)
+                uncond_image_enc_hidden_states = uncond_image_encoded.hidden_states[-2]
+            else:
+                uncond_image_enc_hidden_states = additional_noise_for_uncond.unsqueeze(0).unsqueeze(0)
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            if needs_encoding:
+                image_encoded = self.image_encoder(image)
+                image_embeds = image_encoded.image_embeds
+            else:
+                image_embeds = image.unsqueeze(0).unsqueeze(0)
+            if needs_encoding:
+                uncond_image_encoded = self.image_encoder(additional_noise_for_uncond)
+                uncond_image_embeds = uncond_image_encoded.image_embeds
+            else:
+                uncond_image_embeds = additional_noise_for_uncond.unsqueeze(0).unsqueeze(0)
+
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+            return image_embeds, uncond_image_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+    ):
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            image_embeds = []
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                single_negative_image_embeds = torch.stack(
+                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                )
+
+                if do_classifier_free_guidance:
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                    single_image_embeds = single_image_embeds.to(device)
+
+                image_embeds.append(single_image_embeds)
+        else:
+            repeat_dims = [1]
+            image_embeds = []
+            for single_image_embeds in ip_adapter_image_embeds:
+                if do_classifier_free_guidance:
+                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                    single_negative_image_embeds = single_negative_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+                    )
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                else:
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                image_embeds.append(single_image_embeds)
+
+        return image_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    def check_inputs(
+        self,
+        prompt,
+        prompt_2,
+        image,
+        strength,
+        num_inference_steps,
+        callback_steps,
+        negative_prompt=None,
+        negative_prompt_2=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
+        ip_adapter_image=None,
+        ip_adapter_image_embeds=None,
+        controlnet_conditioning_scale=1.0,
+        control_guidance_start=0.0,
+        control_guidance_end=1.0,
+        callback_on_step_end_tensor_inputs=None,
+    ):
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+        if num_inference_steps is None:
+            raise ValueError("`num_inference_steps` cannot be None.")
+        elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0:
+            raise ValueError(
+                f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
+                f" {type(num_inference_steps)}."
+            )
+
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+            )
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt_2 is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
+        if prompt_embeds is not None and pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+            )
+
+        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+            )
+
+        # `prompt` needs more sophisticated handling when there are multiple
+        # conditionings.
+        if isinstance(self.controlnet, MultiControlNetModel):
+            if isinstance(prompt, list):
+                logger.warning(
+                    f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+                    " prompts. The conditionings will be fixed across the prompts."
+                )
+
+        # Check `image`
+        is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+            self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+        )
+        if (
+            isinstance(self.controlnet, ControlNetModel)
+            or is_compiled
+            and isinstance(self.controlnet._orig_mod, ControlNetModel)
+        ):
+            self.check_image(image, prompt, prompt_embeds)
+        elif (
+            isinstance(self.controlnet, MultiControlNetModel)
+            or is_compiled
+            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+        ):
+            if not isinstance(image, list):
+                raise TypeError("For multiple controlnets: `image` must be type `list`")
+
+            # When `image` is a nested list:
+            # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
+            elif any(isinstance(i, list) for i in image):
+                raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+            elif len(image) != len(self.controlnet.nets):
+                raise ValueError(
+                    f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
+                )
+
+            for image_ in image:
+                self.check_image(image_, prompt, prompt_embeds)
+        else:
+            assert False
+
+        # Check `controlnet_conditioning_scale`
+        if (
+            isinstance(self.controlnet, ControlNetModel)
+            or is_compiled
+            and isinstance(self.controlnet._orig_mod, ControlNetModel)
+        ):
+            if not isinstance(controlnet_conditioning_scale, float):
+                raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+        elif (
+            isinstance(self.controlnet, MultiControlNetModel)
+            or is_compiled
+            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+        ):
+            if isinstance(controlnet_conditioning_scale, list):
+                if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+                    raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+            elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+                self.controlnet.nets
+            ):
+                raise ValueError(
+                    "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+                    " the same length as the number of controlnets"
+                )
+        else:
+            assert False
+
+        if not isinstance(control_guidance_start, (tuple, list)):
+            control_guidance_start = [control_guidance_start]
+
+        if not isinstance(control_guidance_end, (tuple, list)):
+            control_guidance_end = [control_guidance_end]
+
+        if len(control_guidance_start) != len(control_guidance_end):
+            raise ValueError(
+                f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+            )
+
+        if isinstance(self.controlnet, MultiControlNetModel):
+            if len(control_guidance_start) != len(self.controlnet.nets):
+                raise ValueError(
+                    f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+                )
+
+        for start, end in zip(control_guidance_start, control_guidance_end):
+            if start >= end:
+                raise ValueError(
+                    f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+                )
+            if start < 0.0:
+                raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+            if end > 1.0:
+                raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+            raise ValueError(
+                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+            )
+
+        if ip_adapter_image_embeds is not None:
+            if not isinstance(ip_adapter_image_embeds, list):
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+                )
+            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                )
+
+    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image
+    def check_image(self, image, prompt, prompt_embeds):
+        image_is_pil = isinstance(image, PIL.Image.Image)
+        image_is_tensor = isinstance(image, torch.Tensor)
+        image_is_np = isinstance(image, np.ndarray)
+        image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+        image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+        if (
+            not image_is_pil
+            and not image_is_tensor
+            and not image_is_np
+            and not image_is_pil_list
+            and not image_is_tensor_list
+            and not image_is_np_list
+        ):
+            raise TypeError(
+                f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+            )
+
+        if image_is_pil:
+            image_batch_size = 1
+        else:
+            image_batch_size = len(image)
+
+        if prompt is not None and isinstance(prompt, str):
+            prompt_batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            prompt_batch_size = len(prompt)
+        elif prompt_embeds is not None:
+            prompt_batch_size = prompt_embeds.shape[0]
+
+        if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+            raise ValueError(
+                f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+            )
+
+    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image
+    def prepare_control_image(
+        self,
+        image,
+        width,
+        height,
+        batch_size,
+        num_images_per_prompt,
+        device,
+        dtype,
+        do_classifier_free_guidance=False,
+        guess_mode=False,
+    ):
+        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+        image_batch_size = image.shape[0]
+
+        if image_batch_size == 1:
+            repeat_by = batch_size
+        else:
+            # image batch size is the same as prompt batch size
+            repeat_by = num_images_per_prompt
+
+        image = image.repeat_interleave(repeat_by, dim=0)
+
+        image = image.to(device=device, dtype=dtype)
+
+        if do_classifier_free_guidance and not guess_mode:
+            image = torch.cat([image] * 2)
+
+        return image
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+    def get_timesteps(self, num_inference_steps, strength, device):
+        # get the original timestep using init_timestep
+        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+        t_start = max(num_inference_steps - init_timestep, 0)
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+        return timesteps, num_inference_steps - t_start
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents
+    def prepare_latents(
+        self, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator=None, add_noise=True, seed=None
+    ):
+
+        if image is None:
+            shape = (
+                batch_size,
+                num_channels_latents,
+                int(height) // self.vae_scale_factor,
+                int(width) // self.vae_scale_factor,
+            )
+            init_latents = torch.zeros(shape, device=device, dtype=dtype)
+        else:
+            if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+                raise ValueError(
+                    f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+                )
+
+            latents_mean = latents_std = None
+            if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
+                latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
+            if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
+                latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
+
+            # Offload text encoder if `enable_model_cpu_offload` was enabled
+            if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+                self.text_encoder_2.to("cpu")
+                torch.cuda.empty_cache()
+            
+            image = image.to(device=device, dtype=dtype)
+
+            if image.shape[1] == 4:
+                init_latents = image
+
+            else:
+                # make sure the VAE is in float32 mode, as it overflows in float16
+                if self.vae.config.force_upcast:
+                    image = image.float()
+                    self.vae.to(dtype=torch.float32)
+
+                if isinstance(generator, list) and len(generator) != batch_size:
+                    raise ValueError(
+                        f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                        f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+                    )
+
+                elif isinstance(generator, list):
+                    init_latents = [
+                        retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                        for i in range(batch_size)
+                    ]
+                    init_latents = torch.cat(init_latents, dim=0)
+                else:
+                    init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+                if self.vae.config.force_upcast:
+                    self.vae.to(dtype)
+
+                init_latents = init_latents.to(dtype)
+                if latents_mean is not None and latents_std is not None:
+                    latents_mean = latents_mean.to(device=self.device, dtype=dtype)
+                    latents_std = latents_std.to(device=self.device, dtype=dtype)
+                    init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
+                else:
+                    init_latents = self.vae.config.scaling_factor * init_latents
+
+            if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+                # expand init_latents for batch_size
+                additional_image_per_prompt = batch_size // init_latents.shape[0]
+                init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+            elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+                raise ValueError(
+                    f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+                )
+            else:
+                init_latents = torch.cat([init_latents], dim=0)
+
+        if add_noise:
+            return self.add_noise(init_latents, timestep, device, dtype, generator, seed)
+
+        latents = init_latents
+
+        return latents
+    
+    def add_noise(self, latents, timestep, device, dtype, generator=None, seed=None):
+        if seed is not None:
+            if device == "cpu":
+                generator = torch.manual_seed(seed)
+            else:
+                generator = torch.cuda.manual_seed(seed)
+        noise = torch.randn(
+            torch.Size(latents.shape),
+            dtype=torch.float32,
+            layout=torch.strided,
+            generator=generator,
+            device=device,
+        ).to(device)
+        new_latents = self.scheduler.add_noise(
+            latents.to(device), noise, timestep
+        )
+        return new_latents.to(device, dtype=dtype)
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
+    def _get_add_time_ids(
+        self,
+        original_size,
+        crops_coords_top_left,
+        target_size,
+        aesthetic_score,
+        negative_aesthetic_score,
+        negative_original_size,
+        negative_crops_coords_top_left,
+        negative_target_size,
+        dtype,
+        text_encoder_projection_dim=None,
+    ):
+        if self.config.requires_aesthetics_score:
+            add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
+            add_neg_time_ids = list(
+                negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
+            )
+        else:
+            add_time_ids = list(original_size + crops_coords_top_left + target_size)
+            add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
+
+        passed_add_embed_dim = (
+            self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+        )
+        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+        if (
+            expected_add_embed_dim > passed_add_embed_dim
+            and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
+        ):
+            raise ValueError(
+                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
+            )
+        elif (
+            expected_add_embed_dim < passed_add_embed_dim
+            and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
+        ):
+            raise ValueError(
+                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
+            )
+        elif expected_add_embed_dim != passed_add_embed_dim:
+            raise ValueError(
+                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+            )
+
+        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+        add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
+
+        return add_time_ids, add_neg_time_ids
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+    def upcast_vae(self):
+        dtype = self.vae.dtype
+        self.vae.to(dtype=torch.float32)
+        use_torch_2_0_or_xformers = isinstance(
+            self.vae.decoder.mid_block.attentions[0].processor,
+            (
+                AttnProcessor2_0,
+                XFormersAttnProcessor,
+                LoRAXFormersAttnProcessor,
+                LoRAAttnProcessor2_0,
+            ),
+        )
+        # if xformers or torch_2_0 is used attention block does not need
+        # to be in float32 which can save lots of memory
+        if use_torch_2_0_or_xformers:
+            self.vae.post_quant_conv.to(dtype)
+            self.vae.decoder.conv_in.to(dtype)
+            self.vae.decoder.mid_block.to(dtype)
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def clip_skip(self):
+        return self._clip_skip
+
+    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1
+
+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        image: PipelineImageInput = None,
+        mask_image: PipelineImageInput = None,
+        i2i_mask_guidance_start: Optional[float] = 0.0,
+        i2i_mask_guidance_end: Optional[float] = 1.0,
+        control_image: PipelineImageInput = None,
+        control_mask = None,
+        identity_control_indices = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        strength: float = 0.8,
+        num_inference_steps: int = 50,
+        timesteps: Optional[List[int]] = None,
+        sigmas: Optional[List[float]] = None,
+        guidance_scale: float = 5.0,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        seed: Optional[int] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
+        guess_mode: bool = False,
+        control_guidance_start: Union[float, List[float]] = 0.0,
+        control_guidance_end: Union[float, List[float]] = 1.0,
+        original_size: Tuple[int, int] = None,
+        crops_coords_top_left: Tuple[int, int] = (0, 0),
+        target_size: Tuple[int, int] = None,
+        negative_original_size: Optional[Tuple[int, int]] = None,
+        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+        negative_target_size: Optional[Tuple[int, int]] = None,
+        aesthetic_score: float = 6.0,
+        negative_aesthetic_score: float = 2.5,
+        clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+                instead.
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                used in both text-encoders
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+                    `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+                The initial image will be used as the starting point for the image generation process. Can also accept
+                image latents as `image`, if passing latents directly, it will not be encoded again.
+            control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+                    `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+                The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
+                the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
+                also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
+                height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
+                specified in init, images must be passed as a list such that each element of the list can be correctly
+                batched for input to a single controlnet.
+            height (`int`, *optional*, defaults to the size of control_image):
+                The height in pixels of the generated image. Anything below 512 pixels won't work well for
+                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+                and checkpoints that are not specifically fine-tuned on low resolutions.
+            width (`int`, *optional*, defaults to the size of control_image):
+                The width in pixels of the generated image. Anything below 512 pixels won't work well for
+                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+                and checkpoints that are not specifically fine-tuned on low resolutions.
+            strength (`float`, *optional*, defaults to 0.8):
+                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+                starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+                essentially ignores `image`.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will ge generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generate image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
+                to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
+                corresponding scale as a list.
+            guess_mode (`bool`, *optional*, defaults to `False`):
+                In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
+                you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
+            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+                The percentage of total steps at which the controlnet starts applying.
+            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+                The percentage of total steps at which the controlnet stops applying.
+            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+                explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                For most cases, `target_size` should be set to the desired height and width of the generated image. If
+                not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+                micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
+                micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                To negatively condition the generation process based on a target image resolution. It should be as same
+                as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+            aesthetic_score (`float`, *optional*, defaults to 6.0):
+                Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
+                Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
+                Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
+                simulate an aesthetic score of the generated image by influencing the negative text condition.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that calls at the end of each denoising steps during the inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`
+            containing the output images.
+        """
+
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
+        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+        # align format for control guidance
+        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+            control_guidance_start, control_guidance_end = (
+                mult * [control_guidance_start],
+                mult * [control_guidance_end],
+            )
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            prompt_2,
+            control_image,
+            strength,
+            num_inference_steps,
+            callback_steps,
+            negative_prompt,
+            negative_prompt_2,
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
+            controlnet_conditioning_scale,
+            control_guidance_start,
+            control_guidance_end,
+            callback_on_step_end_tensor_inputs,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
+        global_pool_conditions = (
+            controlnet.config.global_pool_conditions
+            if isinstance(controlnet, ControlNetModel)
+            else controlnet.nets[0].config.global_pool_conditions
+        )
+        guess_mode = guess_mode or global_pool_conditions
+
+        # 3.1. Encode input prompt
+        text_encoder_lora_scale = (
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+        )
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            prompt,
+            prompt_2,
+            device,
+            num_images_per_prompt,
+            self.do_classifier_free_guidance,
+            negative_prompt,
+            negative_prompt_2,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
+            clip_skip=self.clip_skip,
+        )
+
+        # 3.2 Encode ip_adapter_image
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
+        # 4. Prepare image and controlnet_conditioning_image
+        if image is not None:
+            image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+        else:
+            strength = 1.0
+
+        if mask_image is not None:
+            mask_image = self.image_processor.preprocess(mask_image, height=height, width=width).to(dtype=torch.float32)
+
+        if isinstance(controlnet, ControlNetModel):
+            control_image = self.prepare_control_image(
+                image=control_image,
+                width=width,
+                height=height,
+                batch_size=batch_size * num_images_per_prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                device=device,
+                dtype=controlnet.dtype,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
+                guess_mode=guess_mode,
+            )
+            height, width = control_image.shape[-2:]
+        elif isinstance(controlnet, MultiControlNetModel):
+            control_images = []
+
+            for control_image_ in control_image:
+                control_image_ = self.prepare_control_image(
+                    image=control_image_,
+                    width=width,
+                    height=height,
+                    batch_size=batch_size * num_images_per_prompt,
+                    num_images_per_prompt=num_images_per_prompt,
+                    device=device,
+                    dtype=controlnet.dtype,
+                    do_classifier_free_guidance=self.do_classifier_free_guidance,
+                    guess_mode=guess_mode,
+                )
+
+                control_images.append(control_image_)
+
+            control_image = control_images
+            height, width = control_image[0].shape[-2:]
+        else:
+            assert False
+
+        # 4.1 Region control
+        controlnet_masks = []
+        if control_mask is not None:
+            for mask in control_mask:
+                mask = np.array(mask)
+                mask_tensor = torch.from_numpy(mask).to(device=device, dtype=prompt_embeds.dtype)
+                mask_tensor = mask_tensor[:, :, 0] / 255.
+                mask_tensor = mask_tensor[None, None]
+                h, w = mask_tensor.shape[-2:]
+                control_mask_list = []
+                for scale in [8, 8, 8, 16, 16, 16, 32, 32, 32]:
+                    # Python uses IEEE 754 rounding rules, we need to add a small value to round like the unet model
+                    w_n = round((w + 0.01) / 8)
+                    h_n = round((h + 0.01) / 8)
+                    if scale in [16, 32]:
+                        w_n = round((w_n + 0.01) / 2)
+                        h_n = round((h_n + 0.01) / 2)
+                    if scale == 32:
+                        w_n = round((w_n + 0.01) / 2)
+                        h_n = round((h_n + 0.01) / 2)
+                    scale_mask_weight_image_tensor = F.interpolate(
+                        mask_tensor,(h_n, w_n), mode='bilinear')
+                    control_mask_list.append(scale_mask_weight_image_tensor)
+                controlnet_masks.append(control_mask_list)
+
+        # 5. Prepare timesteps
+        full_num_inference_steps = int(num_inference_steps / strength) if strength > 0 else num_inference_steps
+
+        if timesteps is None:
+            self.scheduler.set_timesteps(full_num_inference_steps + 1, device=device)
+            sigmas = self._loglinear_interp(self.ays_noise_sigmas["SDXL"], full_num_inference_steps + 1)
+            sigmas[-1] = 0
+            log_sigmas = np.log(np.array((1 - self.scheduler.alphas_cumprod) / self.scheduler.alphas_cumprod) ** 0.5)
+            timesteps = np.array([self.scheduler._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+            timesteps = timesteps[-(num_inference_steps + 1):-1]
+            if hasattr(self.scheduler, "sigmas"):
+                self.scheduler.sigmas = torch.from_numpy(sigmas)[-(num_inference_steps + 1):]
+            self.scheduler.timesteps = torch.from_numpy(timesteps).to(self.device, dtype=torch.int64)
+            self.scheduler.num_inference_steps = len(self.scheduler.timesteps)
+
+        else:
+            if "timesteps" in inspect.signature(self.scheduler.set_timesteps).parameters:
+                self.scheduler.set_timesteps(full_num_inference_steps + 1, timesteps=timesteps, device=device)
+            else:
+                self.scheduler.set_timesteps(full_num_inference_steps + 1, device=device)
+
+        latent_timestep = self.scheduler.timesteps[:1].repeat(batch_size * num_images_per_prompt)
+        self._num_timesteps = len(self.scheduler.timesteps)
+
+        # 6. Prepare latent variables
+        if latents is None:
+            num_channels_latents = self.unet.config.in_channels
+            latents = self.prepare_latents(
+                image,
+                latent_timestep,
+                batch_size * num_images_per_prompt,
+                num_channels_latents,
+                height,
+                width,
+                prompt_embeds.dtype,
+                device,
+                generator,
+                True,
+                seed
+            )
+
+            untouched_latents = self.prepare_latents(
+                image,
+                latent_timestep,
+                batch_size * num_images_per_prompt,
+                num_channels_latents,
+                height,
+                width,
+                prompt_embeds.dtype,
+                device,
+                generator,
+                False,
+                seed,
+            )
+
+            if mask_image is not None:
+                #resize i2i mask to the same size as the latents and reduce it to 1 channel
+                mask_image = mask_image[:, 0:1, :, :]
+                mask_image = F.interpolate(mask_image, (latents.shape[-2], latents.shape[-1]), mode="bilinear")
+                mask_image = mask_image.to(device=device, dtype=prompt_embeds.dtype)
+
+        else:
+            untouched_latents = latents.clone()
+
+        if hasattr(self.scheduler, "sigmas"):
+            sigmas = self.scheduler.sigmas
+            sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+            seeds = [seed] * len(latents) if seed is not None else generator.seed()
+            brownian_tree_noise_sampler = BrownianTreeNoiseSampler(latents, sigma_min, sigma_max, seed=seeds, cpu=False)
+        else:
+            brownian_tree_noise_sampler = None
+
+        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 7.1 Create tensor stating which controlnets to keep
+        controlnet_keep = []
+        for i in range(len(timesteps)):
+            keeps = [
+                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+                for s, e in zip(control_guidance_start, control_guidance_end)
+            ]
+            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
+
+        # 7.2 Prepare added time ids & embeddings
+        if isinstance(control_image, list):
+            original_size = original_size or control_image[0].shape[-2:]
+        else:
+            original_size = original_size or control_image.shape[-2:]
+        target_size = target_size or (height, width)
+
+        if negative_original_size is None:
+            negative_original_size = original_size
+        if negative_target_size is None:
+            negative_target_size = target_size
+        add_text_embeds = pooled_prompt_embeds
+
+        if self.text_encoder_2 is None:
+            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+        else:
+            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+        add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+            original_size,
+            crops_coords_top_left,
+            target_size,
+            aesthetic_score,
+            negative_aesthetic_score,
+            negative_original_size,
+            negative_crops_coords_top_left,
+            negative_target_size,
+            dtype=prompt_embeds.dtype,
+            text_encoder_projection_dim=text_encoder_projection_dim,
+        )
+        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+
+        if self.do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+            add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+            add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
+
+        prompt_embeds = prompt_embeds.to(device)
+        add_text_embeds = add_text_embeds.to(device)
+        add_time_ids = add_time_ids.to(device)
+
+        # 8. Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+                # controlnet(s) inference
+                if guess_mode and self.do_classifier_free_guidance:
+                    # Infer ControlNet only for the conditional batch.
+                    control_model_input = latents
+                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+                    controlnet_added_cond_kwargs = {
+                        "text_embeds": add_text_embeds.chunk(2)[1],
+                        "time_ids": add_time_ids.chunk(2)[1],
+                    }
+                else:
+                    control_model_input = latent_model_input
+                    controlnet_prompt_embeds = prompt_embeds
+                    controlnet_added_cond_kwargs = added_cond_kwargs
+
+                if isinstance(controlnet_keep[i], list):
+                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+                else:
+                    controlnet_cond_scale = controlnet_conditioning_scale
+                    if isinstance(controlnet_cond_scale, list):
+                        controlnet_cond_scale = controlnet_cond_scale[0]
+                    cond_scale = controlnet_cond_scale * controlnet_keep[i]
+                
+                if ip_adapter_image_embeds is None and ip_adapter_image is not None:
+                    encoder_hidden_states = self.unet.process_encoder_hidden_states(prompt_embeds, {"image_embeds": image_embeds})
+                    ip_adapter_image_embeds = encoder_hidden_states[1]
+
+                down_block_res_samples = None
+                mid_block_res_sample = None
+
+                for controlnet_index in range(len(self.controlnet.nets)):
+                    ip_adapter_index = next((y for x, y in identity_control_indices if x == controlnet_index), None)
+                    if ip_adapter_index is not None:
+                        control_prompt_embeds = ip_adapter_image_embeds[ip_adapter_index].squeeze(1)
+                    else:
+                        control_prompt_embeds = controlnet_prompt_embeds
+                    down_samples, mid_sample = self.controlnet.nets[controlnet_index](
+                        control_model_input,
+                        t,
+                        encoder_hidden_states=control_prompt_embeds,
+                        controlnet_cond=control_image[controlnet_index],
+                        conditioning_scale=cond_scale[controlnet_index],
+                        guess_mode=guess_mode,
+                        added_cond_kwargs=controlnet_added_cond_kwargs,
+                        return_dict=False,
+                    )
+
+                    if len(controlnet_masks) > controlnet_index and controlnet_masks[controlnet_index] is not None:
+                        down_samples = [
+                            down_sample * mask_weight
+                            for down_sample, mask_weight in zip(down_samples, controlnet_masks[controlnet_index])
+                        ]
+                        mid_sample *= controlnet_masks[controlnet_index][-1]
+
+                    if down_block_res_samples is None and mid_block_res_sample is None:
+                        down_block_res_samples = down_samples
+                        mid_block_res_sample = mid_sample
+                    else:
+                        down_block_res_samples = [
+                            samples_prev + samples_curr
+                            for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
+                        ]
+                        mid_block_res_sample += mid_sample
+
+                if guess_mode and self.do_classifier_free_guidance:
+                    # Infered ControlNet only for the conditional batch.
+                    # To apply the output of ControlNet to both the unconditional and conditional batches,
+                    # add 0 to the unconditional batch to keep it unchanged.
+                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+                    added_cond_kwargs["image_embeds"] = image_embeds
+
+                # predict the noise residual
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
+                    down_block_additional_residuals=down_block_res_samples,
+                    mid_block_additional_residual=mid_block_res_sample,
+                    added_cond_kwargs=added_cond_kwargs,
+                    return_dict=False,
+                )[0]
+
+                # perform guidance
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                # check if scheduler.step supports variance noise
+                if "variance_noise" in inspect.signature(self.scheduler.step).parameters and brownian_tree_noise_sampler is not None:
+                    sigmas = self.scheduler.sigmas
+                    noise = brownian_tree_noise_sampler(sigmas[i], sigmas[i + 1]).to(device=device, dtype=latents.dtype)
+                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False, variance_noise=noise)[0]
+                else:
+                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+                if image is not None and mask_image is not None:
+                    timesteps_count = len(timesteps) - 1
+                    i2i_mask_guidance_end_index = int(i2i_mask_guidance_end * timesteps_count)
+                    i2i_mask_guidance_start_index = int(i2i_mask_guidance_start * timesteps_count)
+                    if i < i2i_mask_guidance_end_index and i >= i2i_mask_guidance_start_index:
+                        inpaint_timesteps = torch.tensor([t-1]).repeat(
+                            batch_size * num_images_per_prompt
+                        )
+                        new_noisy_latents = self.add_noise(untouched_latents, inpaint_timesteps, device, latents.dtype, generator, seed)
+                        latents = torch.where(mask_image != 1, new_noisy_latents, latents)
+
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        step_idx = i // getattr(self.scheduler, "order", 1)
+                        callback(step_idx, t, latents)
+
+        # If we do sequential model offloading, let's offload unet and controlnet
+        # manually for max memory savings
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.unet.to("cpu")
+            self.controlnet.to("cpu")
+            torch.cuda.empty_cache()
+
+        if not output_type == "latent":
+            # make sure the VAE is in float32 mode, as it overflows in float16
+            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+            if needs_upcasting:
+                self.upcast_vae()
+                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+            # unscale/denormalize the latents
+            # denormalize with the mean and std if available and not None
+            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+            if has_latents_mean and has_latents_std:
+                latents_mean = (
+                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents_std = (
+                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+            else:
+                latents = latents / self.vae.config.scaling_factor
+
+            image = self.vae.decode(latents, return_dict=False)[0]
+
+            # cast back to fp16 if needed
+            if needs_upcasting:
+                self.vae.to(dtype=torch.float16)
+        else:
+            image = latents
+            return StableDiffusionXLPipelineOutput(images=image)
+
+        # apply watermark if available
+        if self.watermark is not None:
+            image = self.watermark.apply_watermark(image)
+
+        image = self.image_processor.postprocess(image, output_type=output_type)
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (image,)
+
+        return StableDiffusionXLPipelineOutput(images=image)