try:
    # Presumably the Hugging Face `spaces` package; when present, its `GPU`
    # decorator wraps `xclip_pred` below.
    import spaces
    gpu_decorator = spaces.GPU
    # NOTE(review): this import shares the `try` with `import spaces`. An
    # ImportError raised by `.load_model` is therefore also caught below. It is
    # reported as "No GPU decorator found", and the `except` branch rebinds
    # `gpu_decorator` to the no-op even though `spaces.GPU` was available.
    # Conversely, when `spaces` is missing, `load_xclip` is never bound in this
    # module. Confirm that this coupling is intended.
    from .load_model import load_xclip
except ImportError:
    print("No GPU decorator found. Running on CPU.")
    # Define a no-operation decorator as fallback
    def gpu_decorator(func):
        """Return `func` unchanged (fallback used when `spaces` cannot be imported)."""
        return func

    
import PIL
import torch

from .prompts import GetPromptList

# Part names in the order they are tokenized and fed to the model as part
# queries, and the order used to align masks and per-part scores in xclip_pred.
ORG_PART_ORDER = [
    'back', 'beak', 'belly', 'breast',
    'crown', 'forehead', 'eyes', 'legs',
    'wings', 'nape', 'tail', 'throat',
]

# The same twelve part names in an alternative order (apparently head-to-tail,
# e.g. for display). Not referenced in this part of the module.
ORDERED_PARTS = [
    'crown', 'forehead', 'nape', 'eyes',
    'beak', 'throat', 'breast', 'belly',
    'back', 'wings', 'legs', 'tail',
]

def encode_descs_xclip(owlvit_det_processor: callable, model: callable, descs: list[str], device: str, max_batch_size: int = 512):
    """Encode text descriptions with the OWL-ViT text tower, in batches.

    Args:
        owlvit_det_processor: Processor called as
            ``processor(text=..., padding="max_length", truncation=True, return_tensors="pt")``.
            Its result must support ``.to(device)`` and be unpackable as keyword
            arguments.
        model: Object exposing ``model.owlvit.get_text_features(**tokens)``.
        descs: Descriptions to encode, in order.
        device: Device the tokens are sent to and the result is returned on.
        max_batch_size: Maximum number of descriptions per forward pass.

    Returns:
        The per-batch text features concatenated along dim 0 (presumably one
        row per description), cast to float32 and placed on ``device``.

    Raises:
        ValueError: If ``descs`` is empty or ``max_batch_size`` is not positive.
    """
    if max_batch_size <= 0:
        raise ValueError(f"max_batch_size must be positive, got {max_batch_size}")
    if not descs:
        raise ValueError("descs must contain at least one description")

    text_embeds = []
    with torch.no_grad():
        # Step through descs in chunks of max_batch_size. The previous
        # `len(descs) // max_batch_size + 1` batch count issued a trailing empty
        # batch whenever len(descs) was an exact multiple of max_batch_size.
        for start in range(0, len(descs), max_batch_size):
            query_descs = descs[start:start + max_batch_size]
            query_tokens = owlvit_det_processor(text=query_descs, padding="max_length", truncation=True, return_tensors="pt").to(device)
            query_embeds = model.owlvit.get_text_features(**query_tokens)
            # Accumulate on CPU as float32 and move to `device` once at the end
            # (presumably to keep intermediate batches off the accelerator).
            text_embeds.append(query_embeds.cpu().float())
    return torch.cat(text_embeds, dim=0).to(device)

# def encode_descs_clip(model: callable, descs: list[str], device: str, max_batch_size: int = 512):
#     total_num_batches = len(descs) // max_batch_size + 1
#     with torch.no_grad():
#         text_embeds = []
#         for batch_idx in range(total_num_batches):
#             desc = descs[batch_idx*max_batch_size:(batch_idx+1)*max_batch_size]
#             query_tokens = clip.tokenize(desc).to(device)
#             text_embeds.append(model.encode_text(query_tokens).cpu().float())
#     text_embeds = torch.cat(text_embeds, dim=0)
#     text_embeds = torch.nn.functional.normalize(text_embeds, dim=-1)
#     return text_embeds.to(device)
@gpu_decorator
def xclip_pred(new_desc: dict,
               new_part_mask: dict,
               new_class: str,
               org_desc: str,
               image: PIL.Image,
               model: callable,
               owlvit_processor: callable,
               device: str,
               return_img_embeds: bool = False,
               use_precompute_embeddings = True,
               image_name: str = None,
               cub_embeds: torch.Tensor = None,
               cub_idx2name: dict = None,
               descriptors: dict = None):
    """Classify ``image`` with the XCLIP part model, optionally with one edited class.

    The model returns per-part logits for every class. Parts disabled in
    ``new_part_mask`` are multiplied away, and each class score is recomputed
    as the sum of its remaining part logits. The whole computation assumes a
    single image (results are read back with ``.item()`` / index ``0``).

    Args:
        new_desc: Edited descriptions keyed by the part names in
            ``ORG_PART_ORDER``. Only read when ``new_class`` is not None.
        new_part_mask: Per-part multipliers keyed by part name (presumably 0/1).
            Only read when ``new_class`` is not None; otherwise every part is kept.
        new_class: Name of the class to add or override, or None to score the
            unmodified class set.
        org_desc: Passed to ``GetPromptList`` when ``cub_embeds`` is None.
        image: Image to embed when precomputed embeddings are not loaded.
        model: XCLIP model. It is used via ``owlvit`` (text features),
            ``image_embedder``, ``cls_head`` and ``model(...)``.
        owlvit_processor: OWL-ViT processor used for text and image inputs.
        device: Kept for interface compatibility. The device is chosen here:
            CUDA if the model can be moved there, otherwise a fallback.
        return_img_embeds: Also return the image embeddings.
        use_precompute_embeddings: Load ``data/image_embeddings/{image_name}.pt``
            instead of embedding ``image``. This only applies when
            ``image_name`` is given.
        image_name: File stem of the precomputed image embedding.
        cub_embeds: Precomputed description embeddings of the base classes.
            When given, the new class (if any) is appended after them; if its
            name collides with a base class it gets a ``_custom`` suffix.
        cub_idx2name: Class index -> name for ``cub_embeds``. Required with
            ``cub_embeds``.
        descriptors: Class name -> descriptions, used for the returned
            ``descriptions`` fields. It is never modified; a copy is extended
            with the new class.

    Returns:
        A dict with keys ``pred_class``, ``pred_score`` (top-1 softmax
        probability), ``pred_desc_scores`` (part -> masked logit of the
        predicted class), ``descriptions``, ``modified_class``,
        ``modified_score``, ``modified_desc_scores`` and
        ``modified_descriptions``. Returns ``(dict, image_embeds)`` instead
        when ``return_img_embeds`` is True.

    Raises:
        ValueError: If ``cub_embeds`` is given without both ``cub_idx2name``
            and ``descriptors``.
    """
    # Check if running in a Hugging Face Space: there the model can be moved to CUDA.
    try:
        model.to('cuda')
        device = 'cuda'
    except Exception:
        # This was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.
        # Use a device string: torch.load's map_location accepts a str,
        # torch.device, dict or callable, but not the bare int returned by
        # current_device().
        device = f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else 'cpu'

    n_parts = len(ORG_PART_ORDER)

    # Reorder the edited descriptions and the part mask into ORG_PART_ORDER.
    if new_class is not None:
        new_desc_list = [new_desc[part] for part in ORG_PART_ORDER]
        desc_mask = [new_part_mask[part] for part in ORG_PART_ORDER]
    else:
        new_desc_list = None
        desc_mask = [1] * n_parts

    if cub_embeds is None:
        # Build the prompts from org_desc. The edited class either replaces an
        # existing entry or is appended with the next free index.
        getprompt = GetPromptList(org_desc)
        if new_class is not None:
            if new_class not in getprompt.desc:
                getprompt.name2idx[new_class] = len(getprompt.name2idx)
            getprompt.desc[new_class] = new_desc_list

        idx2name = {idx: name for name, idx in getprompt.name2idx.items()}
        modified_class_idx = getprompt.name2idx[new_class] if new_class is not None else None
        n_classes = len(getprompt.name2idx)
        descs, *_ = getprompt('chatgpt-no-template', max_len=n_parts, pad=True)
        query_embeds = encode_descs_xclip(owlvit_processor, model, descs, device)
        # Work on a copy so the caller's mapping is never mutated. Fall back to
        # the prompt list's own descriptions when none were supplied.
        descriptors = dict(getprompt.desc if descriptors is None else descriptors)
    else:
        if cub_idx2name is None or descriptors is None:
            raise ValueError("cub_idx2name and descriptors are required when cub_embeds is given")
        cub_embeds = cub_embeds.to(device)
        # Copy: the previous `descriptors |= ...` leaked custom classes into the
        # caller's dict across calls.
        descriptors = dict(descriptors)
        n_base = len(cub_idx2name)
        if new_class is not None:
            if new_class in cub_idx2name.values():
                new_class = f"{new_class}_custom"
            modified_class_idx = n_base  # the new class is appended after the base classes
            n_classes = n_base + 1
            idx2name = cub_idx2name | {modified_class_idx: new_class}
            # Encode under no_grad (via the shared helper) so no autograd graph is built.
            new_class_embed = encode_descs_xclip(owlvit_processor, model, new_desc_list, device)
            query_embeds = torch.cat([cub_embeds, new_class_embed], dim=0)
        else:
            n_classes = n_base
            query_embeds = cub_embeds
            idx2name = cub_idx2name
            modified_class_idx = None

    if new_class is not None:
        # Keep the returned descriptions consistent with what was actually scored.
        descriptors[new_class] = new_desc_list

    model.cls_head.num_classes = n_classes

    with torch.no_grad():
        part_embeds = owlvit_processor(text=[ORG_PART_ORDER], return_tensors="pt").to(device)
        if use_precompute_embeddings and image_name is not None:
            image_embeds = torch.load(f'data/image_embeddings/{image_name}.pt', weights_only=True, map_location=device).to(device)
        else:
            # Without an image_name there is no file to load, so embed the image directly.
            image_input = owlvit_processor(images=image, return_tensors='pt').to(device)
            image_embeds, _ = model.image_embedder(pixel_values=image_input['pixel_values'])

        pred_logits, part_logits = model(image_embeds, part_embeds, query_embeds, None)

        # Zero out disabled parts. The 1-D mask broadcasts over the leading
        # dims of part_logits (b, c, n). float64 matches the original mask
        # dtype, so the numerics are unchanged.
        mask = torch.tensor(desc_mask, dtype=torch.float64, device=part_logits.device)
        part_logits = part_logits * mask
        # Overwrite the model's pred_logits with the masked per-part sum.
        pred_logits = part_logits.sum(dim=-1)

        pred_class_idx = torch.argmax(pred_logits, dim=-1).cpu()
        pred_class_name = idx2name[pred_class_idx.item()]

        softmax_scores = torch.softmax(pred_logits, dim=-1).cpu()
        softmax_score_top1 = softmax_scores.max(dim=-1).values.item()

        part_scores = part_logits[0, pred_class_idx].cpu().squeeze(0)
        part_scores_dict = dict(zip(ORG_PART_ORDER, part_scores.tolist()))

        if modified_class_idx is not None:
            modified_score = softmax_scores[0, modified_class_idx].item()
            modified_part_scores = part_logits[0, modified_class_idx].cpu().squeeze(0)
            modified_part_scores_dict = dict(zip(ORG_PART_ORDER, modified_part_scores.tolist()))
        else:
            modified_score = None
            modified_part_scores_dict = None

    output_dict = {"pred_class": pred_class_name,
                   "pred_score": softmax_score_top1,
                   "pred_desc_scores": part_scores_dict,
                   "descriptions": descriptors[pred_class_name],
                   "modified_class": new_class,
                   "modified_score": modified_score,
                   "modified_desc_scores": modified_part_scores_dict,
                   "modified_descriptions": descriptors.get(new_class),
                   }
    return (output_dict, image_embeds) if return_img_embeds else output_dict


# def sachit_pred(new_desc: list, 
#                 new_class: str,
#                 org_desc: str,
#                 image: PIL.Image,
#                 model: callable,
#                 preprocess: callable,
#                 device: str,
#                 ):

#     # replace the description if the new class is in the description, otherwise add a new class
#     getprompt = GetPromptList(org_desc)
    
#     if new_class not in getprompt.desc:
#         getprompt.name2idx[new_class] = len(getprompt.name2idx)
#     getprompt.desc[new_class] = new_desc
    
#     idx2name = dict(zip(getprompt.name2idx.values(), getprompt.name2idx.keys()))
#     modified_class_idx = getprompt.name2idx[new_class]
    
#     descs, class_idxs, class_mapping, org_desc_mapper, class_list = getprompt('Sachit-descriptors', max_len=12, pad=True)
    
#     text_embeds = encode_descs_clip(model, descs, device)
    
#     with torch.no_grad():
#         image_embed = model.encode_image(preprocess(image).unsqueeze(0).to(device))
#         desc_mask = torch.tensor(class_idxs)
#         desc_mask = torch.where(desc_mask == -1, 0, 1).unsqueeze(0).to(device)
        
#         sim = torch.matmul(image_embed.float(), text_embeds.T)
#         sim = (sim * desc_mask).view(1, -1, 12)
#         pred_scores = torch.sum(sim, dim=-1)
#         pred_class_idx = torch.argmax(pred_scores, dim=-1).cpu()
#         pred_class = idx2name[pred_class_idx.item()]
        
#         softmax_scores = torch.nn.functional.softmax(pred_scores, dim=-1).cpu()
#         top1_score = torch.topk(softmax_scores, k=1, dim=-1)[0].squeeze(-1).item()
#         modified_score = softmax_scores[0, modified_class_idx].item()
        
#         pred_desc_scores = sim[0, pred_class_idx].cpu().squeeze(0)
#         modified_class_scores = sim[0, modified_class_idx].cpu().squeeze(0)
        
    
#     output_dict = {"pred_class": pred_class,
#                    "pred_score": top1_score,
#                    "pred_desc_scores": pred_desc_scores.tolist(),
#                    "descriptions": getprompt.desc[pred_class],
#                    "modified_class": new_class,
#                    "modified_score": modified_score,
#                    "modified_desc_scores": modified_class_scores.tolist(),
#                    "modified_descriptions": getprompt.desc[new_class],
#                    }
    
#     return output_dict