import glob
import pickle

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
from scipy.ndimage import center_of_mass
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

def combine_ims(im1, im2, val=128):
    """Blend two same-sized images using a constant-gray mask.

    With val=128, Image.composite produces a roughly 50/50 mix of the
    original image (im1) and the color mask (im2).
    """
    blend_mask = Image.new("L", im1.size, val)
    return Image.composite(im1, im2, blend_mask)

def get_class_centers(segmentation_mask, class_dict):
    """Compute the (y, x) center of mass for every class present in the mask."""
    # Shift the 0-based predictions to match the 1-based ids in class_dict
    # (class_names is later indexed with id - 1).
    segmentation_mask = segmentation_mask.numpy() + 1
    class_centers = {}
    for class_index in class_dict:
        class_mask = (segmentation_mask == class_index).astype(int)
        # center_of_mass returns a (y, x) tuple, or (nan, nan) if the class is absent.
        class_centers[class_index] = center_of_mass(class_mask)

    # Keep only classes that actually appear, with integer pixel coordinates.
    return {k: list(map(int, v)) for k, v in class_centers.items() if not np.isnan(sum(v))}

def visualize_mask(predicted_semantic_map, class_ids, class_colors):
    """Map each pixel's class index to its RGB color and return a PIL image."""
    h, w = predicted_semantic_map.shape
    color_indexes = predicted_semantic_map.numpy().astype(np.uint8).flatten()

    # Look up each pixel's class id, then that id's RGB color.
    colors = class_colors[class_ids[color_indexes]]
    output = colors.reshape(h, w, 3).astype(np.uint8)
    return Image.fromarray(output)

def get_out_image(image, predicted_semantic_map):
    class_centers = get_class_centers(predicted_semantic_map, class_dict)
    mask = visualize_mask(predicted_semantic_map, class_ids, class_colors)
    image_mask = combine_ims(image, mask, val=128)
    draw = ImageDraw.Draw(image_mask)

    # Draw each detected class name at its center of mass and collect it as a tag.
    extracted_tags = []
    for class_id, (y, x) in class_centers.items():
        class_name = str(class_names[class_id - 1])
        extracted_tags.append(class_name)
        draw.text((x, y), class_name, fill='black')

    # Join the detected classes into a single " | "-separated string.
    tags_string = " | ".join(extracted_tags)

    return image_mask, tags_string


def gradio_process(image):
    inputs = processor(images=image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # PIL size is (w, h); target_sizes expects (h, w), hence the reversal.
    predicted_semantic_map = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]

    return get_out_image(image, predicted_semantic_map)

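# ADE20K class metadata pickled as three parallel lists: names, ids, and RGB colors.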
with open('ade20k_classes.pickle', 'rb') as f:
    class_names, class_ids, class_colors = pickle.load(f)
class_names, class_ids, class_colors = np.array(class_names), np.array(class_ids), np.array(class_colors)
class_dict = dict(zip(class_ids, class_names))

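# Load the ADE20K-finetuned Mask2Former checkpoint; inference runs on CPU.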
device = torch.device("cpu")
processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-large-ade-semantic")
model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-large-ade-semantic").to(device)
model.eval()

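# Wire everything into a Gradio demo: image in, blended overlay and tag string out.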
demo = gr.Interface(
    gradio_process,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Image(type="pil"), gr.Textbox(label="Detected classes")],
    title="Semantic Segmentation",
    examples=glob.glob('./examples/*.jpg'),
    allow_flagging="never",
)

demo.launch()