from diffusers import StableDiffusionXLPipeline, AutoencoderKL
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from PIL import Image, ImageOps
import gradio as gr
import user_history

# Hub repositories: the SDXL base weights plus a VAE published for fp16 use.
_SDXL_BASE_REPO = "stabilityai/stable-diffusion-xl-base-1.0"
_FP16_VAE_REPO = "madebyollin/sdxl-vae-fp16-fix"

vae = AutoencoderKL.from_pretrained(_FP16_VAE_REPO, torch_dtype=torch.float16)

# Half-precision SDXL pipeline that swaps in the VAE loaded above.
pipe = StableDiffusionXLPipeline.from_pretrained(
    _SDXL_BASE_REPO,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
    vae=vae,
)
pipe.to("cuda")

# Switch the UNet to channels_last memory format, then wrap it with
# torch.compile ("reduce-overhead" mode, whole-graph capture).
pipe.unet.to(memory_format=torch.channels_last)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

@torch.no_grad()
def call(
    pipe,
    prompt: Union[str, List[str]] = None,
    prompt2: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    denoising_end: Optional[float] = None,
    guidance_scale: float = 5.0,
    guidance_scale2: float = 5.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt2: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
    original_size: Optional[Tuple[int, int]] = None,
    crops_coords_top_left: Tuple[int, int] = (0, 0),
    target_size: Optional[Tuple[int, int]] = None,
    negative_original_size: Optional[Tuple[int, int]] = None,
    negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
    negative_target_size: Optional[Tuple[int, int]] = None,
):
    """Run SDXL sampling that alternates between an upright and an upside-down prompt.

    Even-indexed denoising steps predict noise for ``prompt`` on the latents
    as they are. Odd-indexed steps predict noise for ``prompt2`` on
    ``latents.flip(2)`` and flip that prediction back before the scheduler
    step. The shared latents are therefore steered toward ``prompt`` in one
    orientation and toward ``prompt2`` in the flipped one.

    NOTE(review): only dim 2 is flipped. Assuming latents are laid out as
    (batch, channels, height, width), that is a vertical mirror, while the UI
    rotates the image by 180 degrees. The prompt-2 view is then also mirrored
    left-to-right. Confirm this is intended.

    Args:
        pipe: a loaded SDXL pipeline. Its scheduler, UNet, VAE and helper
            methods (``check_inputs``, ``encode_prompt``, ``prepare_latents``,
            ...) are called directly.
        prompt, negative_prompt, guidance_scale: conditioning for even steps.
        prompt2, negative_prompt2, guidance_scale2: conditioning for odd
            (flipped) steps. ``prompt2`` must produce the same batch size as
            ``prompt``.
        prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds,
        negative_pooled_prompt_embeds: optional precomputed embeddings for
            ``prompt``, forwarded to ``pipe.encode_prompt``.
        guidance_rescale: accepted for signature compatibility; not applied.
        The remaining arguments appear to mirror the SDXL pipeline's
        ``__call__`` and are used the same way here.

    Returns:
        ``StableDiffusionXLPipelineOutput(images=image)``, or ``(image,)`` when
        ``return_dict`` is False. With ``output_type == "latent"`` the
        "image" is the final latent tensor.

    Raises:
        Whatever ``pipe.check_inputs`` raises for invalid arguments of either
        prompt branch, and ValueError when the two prompts encode to different
        batch sizes.
    """
    # 0. Default height and width to the pipeline's native sample size.
    height = height or pipe.default_sample_size * pipe.vae_scale_factor
    width = width or pipe.default_sample_size * pipe.vae_scale_factor

    original_size = original_size or (height, width)
    target_size = target_size or (height, width)

    # 1. Validate both branches; prompt2 was previously never checked.
    pipe.check_inputs(
        prompt,
        None,
        height,
        width,
        callback_steps,
        negative_prompt,
        None,
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    )
    pipe.check_inputs(
        prompt2,
        None,
        height,
        width,
        callback_steps,
        negative_prompt2,
        None,
        None,
        None,
        None,
        None,
    )

    # 2. Batch size comes from the first prompt (or its precomputed embeddings).
    if isinstance(prompt, str):
        batch_size = 1
    elif isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = pipe._execution_device

    # `guidance_scale` follows the Imagen paper's weight `w`
    # (https://arxiv.org/pdf/2205.11487.pdf); a value of 1 means no guidance.
    # Either branch may request guidance, so look at both scales.
    do_classifier_free_guidance = guidance_scale > 1.0 or guidance_scale2 > 1.0

    # 3. Encode both prompts.
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    )

    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(
        prompt=prompt,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        # Forward caller-supplied embeddings instead of silently discarding them.
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    (
        prompt2_embeds,
        negative_prompt2_embeds,
        pooled_prompt2_embeds,
        negative_pooled_prompt2_embeds,
    ) = pipe.encode_prompt(
        prompt=prompt2,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=do_classifier_free_guidance,
        negative_prompt=negative_prompt2,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        pooled_prompt_embeds=None,
        negative_pooled_prompt_embeds=None,
        lora_scale=text_encoder_lora_scale,
    )

    if prompt2_embeds.shape[0] != prompt_embeds.shape[0]:
        raise ValueError(
            f"prompt and prompt2 must produce the same batch size, got "
            f"{prompt_embeds.shape[0]} and {prompt2_embeds.shape[0]}."
        )

    # 4. Prepare timesteps.
    pipe.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = pipe.scheduler.timesteps

    # 5. Prepare latent variables (shared by both branches).
    num_channels_latents = pipe.unet.config.in_channels
    latents = pipe.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs.
    extra_step_kwargs = pipe.prepare_extra_step_kwargs(generator, eta)

    # 7. Prepare added time ids & embeddings for each branch.
    add_text_embeds = pooled_prompt_embeds
    add_text2_embeds = pooled_prompt2_embeds

    add_time_ids = pipe._get_add_time_ids(
        original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
    )
    add_time2_ids = pipe._get_add_time_ids(
        original_size, crops_coords_top_left, target_size, dtype=prompt2_embeds.dtype
    )

    if negative_original_size is not None and negative_target_size is not None:
        negative_add_time_ids = pipe._get_add_time_ids(
            negative_original_size,
            negative_crops_coords_top_left,
            negative_target_size,
            dtype=prompt_embeds.dtype,
        )
        # Previously missing: without this the CFG concat below raised NameError.
        negative_add_time2_ids = pipe._get_add_time_ids(
            negative_original_size,
            negative_crops_coords_top_left,
            negative_target_size,
            dtype=prompt2_embeds.dtype,
        )
    else:
        negative_add_time_ids = add_time_ids
        negative_add_time2_ids = add_time2_ids

    if do_classifier_free_guidance:
        # Unconditional half first, matching `.chunk(2)` in the loop.
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
        add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

        prompt2_embeds = torch.cat([negative_prompt2_embeds, prompt2_embeds], dim=0)
        add_text2_embeds = torch.cat([negative_pooled_prompt2_embeds, add_text2_embeds], dim=0)
        add_time2_ids = torch.cat([negative_add_time2_ids, add_time2_ids], dim=0)

    prompt_embeds = prompt_embeds.to(device)
    add_text_embeds = add_text_embeds.to(device)
    add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

    prompt2_embeds = prompt2_embeds.to(device)
    add_text2_embeds = add_text2_embeds.to(device)
    add_time2_ids = add_time2_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

    # Conditioning per branch: index 0 for even (upright) steps, 1 for odd (flipped) steps.
    branches = (
        (prompt_embeds, {"text_embeds": add_text_embeds, "time_ids": add_time_ids}, guidance_scale),
        (prompt2_embeds, {"text_embeds": add_text2_embeds, "time_ids": add_time2_ids}, guidance_scale2),
    )

    # 8. Denoising loop.
    num_warmup_steps = max(len(timesteps) - num_inference_steps * pipe.scheduler.order, 0)

    # 8.1 Apply denoising_end: keep only timesteps at or above the cutoff.
    if denoising_end is not None and isinstance(denoising_end, float) and 0 < denoising_end < 1:
        num_train_timesteps = pipe.scheduler.config.num_train_timesteps
        discrete_timestep_cutoff = int(round(num_train_timesteps - denoising_end * num_train_timesteps))
        num_inference_steps = sum(1 for ts in timesteps if ts >= discrete_timestep_cutoff)
        timesteps = timesteps[:num_inference_steps]

    with pipe.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            is_flipped_step = i % 2 == 1
            encoder_hidden_states, added_cond_kwargs, scale = branches[i % 2]

            # Odd steps look at the canvas upside down. This applies with or
            # without CFG; the old non-CFG path fed unflipped latents here.
            model_latents = latents.flip(2) if is_flipped_step else latents

            # Expand the latents if we are doing classifier-free guidance.
            latent_model_input = torch.cat([model_latents] * 2) if do_classifier_free_guidance else model_latents
            latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

            # Predict the noise residual.
            noise_pred = pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # Perform guidance with this branch's scale.
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + scale * (noise_pred_text - noise_pred_uncond)

            # Map the flipped prediction back into the upright frame.
            if is_flipped_step:
                noise_pred = noise_pred.flip(2)

            # Compute the previous noisy sample x_t -> x_t-1.
            latents = pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            # Call the callback, if provided.
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipe.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    # 9. Decode (unless latents were requested), watermark and post-process.
    if output_type == "latent":
        image = latents
    else:
        # Upcast the VAE to float32 if its config asks for it (fp16 can overflow).
        needs_upcasting = pipe.vae.dtype == torch.float16 and pipe.vae.config.force_upcast

        if needs_upcasting:
            pipe.upcast_vae()
            latents = latents.to(next(iter(pipe.vae.post_quant_conv.parameters())).dtype)

        image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]

        # Cast back to fp16 if needed.
        if needs_upcasting:
            pipe.vae.to(dtype=torch.float16)

        # Apply the watermark if available.
        if pipe.watermark is not None:
            image = pipe.watermark.apply_watermark(image)

        image = pipe.image_processor.postprocess(image, output_type=output_type)

    # Offload all models.
    pipe.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return StableDiffusionXLPipelineOutput(images=image)

# Negative prompt shared by both prompt branches. It is assembled from
# individual terms joined with ", " into a single comma-separated string.
NEGATIVE_PROMPTS = ", ".join((
    "text",
    "watermark",
    "low-quality",
    "signature",
    "moiré pattern",
    "downsampling",
    "aliasing",
    "distorted",
    "blurry",
    "glossy",
    "blur",
    "jpeg artifacts",
    "compression artifacts",
    "poorly drawn",
    "low-resolution",
    "bad",
    "distortion",
    "twisted",
    "excessive",
    "exaggerated pose",
    "exaggerated limbs",
    "grainy",
    "symmetrical",
    "duplicate",
    "error",
    "pattern",
    "beginner",
    "pixelated",
    "fake",
    "hyper",
    "glitch",
    "overexposed",
    "high-contrast",
    "bad-contrast",
))

def rotate_output(has_flipped):
    """Toggle the 180-degree CSS rotation of the result image.

    Args:
        has_flipped: current rotation state held in a ``gr.State``.

    Returns:
        A 3-tuple ``(image update, button update, new state)``. The image gets
        the CSS class for the new orientation, the button label names the
        prompt the next click reveals, and the state is inverted.
    """
    css_class, label = (
        ("not_rotated", "Rotate to see prompt 2!")
        if has_flipped
        else ("rotated", "Rotate to see prompt 1!")
    )
    return gr.Image(elem_classes=css_class), gr.Button(label), not has_flipped

def simple_call(prompt1, prompt2, profile: gr.OAuthProfile | None=None):
    """Generate one 768x768 upside-down image for the UI and record it in the user history.

    Args:
        prompt1: prompt for the upright orientation.
        prompt2: prompt for the flipped orientation.
        profile: OAuth profile of the signed-in user, or None.

    Returns:
        The first generated image from ``call``.
    """
    # Fixed seed: the same prompt pair always yields the same image.
    generator = [torch.Generator(device="cuda").manual_seed(5)]
    res = call(
        pipe,
        prompt1,
        prompt2,
        width=768,
        height=768,
        num_images_per_prompt=1,
        num_inference_steps=50,
        guidance_scale=5.0,
        guidance_scale2=8.0,
        negative_prompt=NEGATIVE_PROMPTS,
        negative_prompt2=NEGATIVE_PROMPTS,
        generator=generator,
    )
    image1 = res.images[0]

    # Save the generated image to the user's history (the original comment
    # says this only happens when logged in; presumably user_history skips
    # saving when profile is None — confirm).
    # Metadata keys previously had prompt1/prompt2 swapped.
    user_history.save_image(
        label=f"{prompt1} / {prompt2}",
        image=image1,
        profile=profile,
        metadata={
            "prompt1": prompt1,
            "prompt2": prompt2,
        },
    )

    return image1
# Page CSS: #result_image animates transform changes over 2s, and the
# "rotated" class (toggled by rotate_output) turns it upside down.
css = '''
#result_image{ transition: transform 2s ease-in-out }
#result_image.rotated{transform: rotate(180deg)}
'''
# Main UI: the header, two prompt boxes with a Run button, and the output
# image with a rotate toggle. It is built standalone and rendered later inside
# a tab of app_with_history. Component construction order defines the layout.
with gr.Blocks() as app:
    gr.Markdown(
        '''
        <center>
            <h1>Upside Down Diffusion</h1>
            <p>Code by Alex Carlier, <a href="https://colab.research.google.com/drive/1rjDQOn11cTHAf3Oeq87Hfl_Vh41NbTl4?usp=sharing">Google Colab</a>, follow them on <a href="https://twitter.com/alexcarliera">Twitter</a></p>
            <p>A space by <a href="https://twitter.com/angrypenguinPNG">AP</a> with contributions from <a href="https://twitter.com/multimodalart">MultimodalArt</a></p>
        </center>
        <hr>
        <p>
            Enter your first prompt to craft an image that will show when upright. Then, add a second prompt to reveal a mesmerizing surprise when you flip the image upside down!  ✨
        </p>
        <p>
            <em>For best results, please include the prompt in the following format: Art Style and Object. Here is an example: Prompt 1: A sketch of a turtle, Prompt 2: A sketch of a tree. Both prompts need to have the same style!</em>
        </p>
        '''
    )

    # Whether the output image is currently shown rotated; toggled by rotate_output.
    has_flipped = gr.State(value=False)
    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            prompt1 = gr.Textbox(label="Prompt 1", info="Prompt for the side up", placeholder="A sketch of a...")
            prompt2 = gr.Textbox(label="Prompt 2", info="Prompt for the side down", placeholder="A sketch of a...")
            run_btn = gr.Button("Run")

        # Right column: output. elem_id/elem_classes are the hooks the page
        # CSS uses for the rotation.
        with gr.Column():
            result_image1 = gr.Image(label="Output", elem_id="result_image", elem_classes="not_rotated")
            rotate_button = gr.Button("Rotate to see prompt 2!")


    # Generate an image. simple_call's `profile` argument is not in `inputs`;
    # presumably Gradio fills it from the login session based on its
    # gr.OAuthProfile annotation.
    # NOTE(review): this only updates the image value, so the rotation state
    # and CSS class persist across runs and a new result may appear already
    # rotated.
    run_btn.click(
        simple_call,
        inputs=[prompt1, prompt2],
        outputs=[result_image1]
    )
    # Toggle the rotation. queue=False and show_progress=False keep this
    # purely cosmetic toggle lightweight.
    rotate_button.click(
        rotate_output,
        inputs=[has_flipped],
        outputs=[result_image1, rotate_button, has_flipped],
        queue=False,
        show_progress=False
    )

# Top-level app. The generator UI goes in the first tab and the per-user
# history in the second; tab order follows statement order. The page CSS is
# attached here because this is the Blocks that gets launched.
with gr.Blocks(css=css) as app_with_history:
    with gr.Tab("Upside Down Diffusion"):
        app.render()
    with gr.Tab("Past generations"):
        user_history.render()

# Enable the request queue, holding at most 20 pending requests.
app_with_history.queue(max_size=20)

if __name__ == "__main__":
    app_with_history.launch(debug=True)