import gradio as gr
import spaces
import torch
from loadimg import load_img
from torchvision import transforms
from transformers import AutoModelForImageSegmentation, pipeline
from diffusers import FluxFillPipeline
from PIL import Image, ImageOps

# from sam2.sam2_image_predictor import SAM2ImagePredictor
import numpy as np
from simple_lama_inpainting import SimpleLama
from contextlib import contextmanager


@contextmanager
def float32_high_matmul_precision():
    """Temporarily relax float32 matmul precision (TF32) and restore the previous setting on exit."""
    previous = torch.get_float32_matmul_precision()
    torch.set_float32_matmul_precision("high")
    try:
        yield
    finally:
        torch.set_float32_matmul_precision(previous)


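# Load the FLUX.1 Fill inpainting/outpainting pipeline (bfloat16) and the BiRefNet
# segmentation model once at import time; both stay resident on the GPU.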
pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
).to("cuda")

birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to("cuda")

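# BiRefNet preprocessing: fixed 1024x1024 input normalized with ImageNet statistics.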
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


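# Pad the image with white on the requested sides and build the matching fill mask
# (black over the original pixels, white over the padded area to be generated).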
def prepare_image_and_mask(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
):
    image = load_img(image).convert("RGB")
    # expand image (left,top,right,bottom)
    background = ImageOps.expand(
        image,
        border=(padding_left, padding_top, padding_right, padding_bottom),
        fill="white",
    )
    mask = Image.new("RGB", image.size, "black")
    mask = ImageOps.expand(
        mask,
        border=(padding_left, padding_top, padding_right, padding_bottom),
        fill="white",
    )
    return background, mask


def outpaint(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
    prompt="",
    num_inference_steps=50,
    guidance_scale=28,
):
    background, mask = prepare_image_and_mask(
        image, padding_top, padding_bottom, padding_left, padding_right
    )

    result = pipe(
        prompt=prompt,
        height=background.height,
        width=background.width,
        image=background,
        mask_image=mask,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]

    result = result.convert("RGBA")

    return result


def inpaint(
    image,
    mask,
    prompt="",
    num_inference_steps=50,
    guidance_scale=28,
):
    background = image.convert("RGB")
    mask = mask.convert("L")

    result = pipe(
        prompt=prompt,
        height=background.height,
        width=background.width,
        image=background,
        mask_image=mask,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]

    result = result.convert("RGBA")

    return result


def rmbg(image=None, url=None):
    if image is None:
        image = url
    if image is None:
        raise gr.Error("Provide an image or the URL of an image")
    image = load_img(image).convert("RGB")
    image_size = image.size
    input_images = transform_image(image).unsqueeze(0).to("cuda")
    with float32_high_matmul_precision():
        # Prediction
        with torch.no_grad():
            preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image_size)
    image.putalpha(mask)
    return image


# def mask_generation(image=None, d=None):
#     # use bfloat16 for the entire notebook
#     # torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
#     # # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
#     # if torch.cuda.get_device_properties(0).major >= 8:
#     #     torch.backends.cuda.matmul.allow_tf32 = True
#     #     torch.backends.cudnn.allow_tf32 = True
#     d = eval(d)  # convert this to dictionary
#     with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
#         predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")
#         predictor.set_image(image)
#         input_point = np.array(d["input_points"])
#         input_label = np.array(d["input_labels"])
#         masks, scores, logits = predictor.predict(
#             point_coords=input_point,
#             point_labels=input_label,
#             multimask_output=True,
#         )
#     sorted_ind = np.argsort(scores)[::-1]
#     masks = masks[sorted_ind]
#     scores = scores[sorted_ind]
#     logits = logits[sorted_ind]

#     out = []
#     for i in range(len(masks)):
#         m = Image.fromarray(masks[i] * 255).convert("L")
#         comp = Image.composite(image, m, m)
#         out.append((comp, f"image {i}"))

#     return out


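# Object removal with LaMa; SimpleLama is instantiated per request, so its weights
# are loaded inside the GPU-allocated call rather than at import time.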
def erase(image=None, mask=None):
    simple_lama = SimpleLama()
    image = load_img(image)
    mask = load_img(mask).convert("L")
    return simple_lama(image, mask)


# Initialize Whisper model
whisper = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-large-v3",
    chunk_length_s=30,
    device="cuda" if torch.cuda.is_available() else "cpu",
)


def transcribe(audio, task="transcribe"):
    if audio is None:
        raise gr.Error("No audio file submitted!")

    text = whisper(
        audio, batch_size=8, generate_kwargs={"task": task}, return_timestamps=True
    )["text"]
    return text


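# Single GPU entry point (spaces.GPU) shared by every tab: the first argument is a
# fixed API selector supplied by a non-interactive gr.Number in each interface.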
@spaces.GPU(duration=120)
def main(*args):
    api_num = args[0]
    args = args[1:]
    if api_num == 1:
        return rmbg(*args)
    elif api_num == 2:
        return outpaint(*args)
    elif api_num == 3:
        return inpaint(*args)
    # elif api_num == 4:
    #     return mask_generation(*args)
    elif api_num == 5:
        return erase(*args)
    elif api_num == 6:
        return transcribe(*args)


rmbg_tab = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(1, interactive=False),
        "image",
        gr.Text("", label="url"),
    ],
    outputs=["image"],
    api_name="rmbg",
    examples=[[1, "./assets/Inpainting mask.png", ""]],
    cache_examples=False,
    description="pass an image or a url of an image",
)

outpaint_tab = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(2, interactive=False),
        gr.Image(label="image", type="pil"),
        gr.Number(label="padding top"),
        gr.Number(label="padding bottom"),
        gr.Number(label="padding left"),
        gr.Number(label="padding right"),
        gr.Text(label="prompt"),
        gr.Number(value=50, label="num_inference_steps"),
        gr.Number(value=28, label="guidance_scale"),
    ],
    outputs=["image"],
    api_name="outpainting",
    examples=[[2, "./assets/rocket.png", 100, 0, 0, 0, "", 50, 28]],
    cache_examples=False,
)


inpaint_tab = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(3, interactive=False),
        gr.Image(label="image", type="pil"),
        gr.Image(label="mask", type="pil"),
        gr.Text(label="prompt"),
        gr.Number(value=50, label="num_inference_steps"),
        gr.Number(value=28, label="guidance_scale"),
    ],
    outputs=["image"],
    api_name="inpaint",
    examples=[[3, "./assets/rocket.png", "./assets/Inpainting mask.png"]],
    cache_examples=False,
    description="it is recommended that you use https://github.com/la-voliere/react-mask-editor when creating an image mask in JS and then inverse it before sending it to this space",
)


# sam2_tab = gr.Interface(
#     main,
#     inputs=[
#         gr.Number(4, interactive=False),
#         gr.Image(type="pil"),
#         gr.Text(),
#     ],
#     outputs=gr.Gallery(),
#     examples=[
#         [
#             4,
#             "./assets/truck.jpg",
#             '{"input_points": [[500, 375], [1125, 625]], "input_labels": [1, 0]}',
#         ]
#     ],
#     api_name="sam2",
#     cache_examples=False,
# )

erase_tab = gr.Interface(
    main,
    inputs=[
        gr.Number(5, interactive=False),
        gr.Image(type="pil"),
        gr.Image(type="pil"),
    ],
    outputs=gr.Image(),
    examples=[
        [
            5,
            "./assets/rocket.png",
            "./assets/Inpainting mask.png",
        ]
    ],
    api_name="erase",
    cache_examples=False,
)

transcribe_tab = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(6, interactive=False),
        gr.Audio(type="filepath"),
        gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
    ],
    outputs="text",
    api_name="transcribe",
    description="Upload an audio file to extract text using Whisper Large V3",
)

demo = gr.TabbedInterface(
    [
        rmbg_tab,
        outpaint_tab,
        inpaint_tab,
        #  sam2_tab,
        erase_tab,
        transcribe_tab,
    ],
    [
        "remove background",
        "outpainting",
        "inpainting",
        #  "sam2",
        "erase",
        "transcribe",
    ],
    title="Utilities that require GPU",
)


demo.launch()
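
# Minimal client sketch (an illustration, not part of this app): calling the endpoints
# defined above with gradio_client, assuming the app is reachable at http://127.0.0.1:7860.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://127.0.0.1:7860")
#   # api_num=1 selects rmbg; pass either a local image or a URL string.
#   result = client.predict(1, handle_file("./assets/rocket.png"), "", api_name="/rmbg")
#   print(result)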