import requests
import gradio as gr
from gradio.mix import Parallel
from transformers import (
    ViTFeatureExtractor,
    ImageClassificationPipeline,
    PerceiverForImageClassificationConvProcessing,
    PerceiverFeatureExtractor,
    VisionEncoderDecoderModel,
    AutoTokenizer,
)
import os
import json
import os
# get description, examples and initial_prompt_value from the local file spaces_info.py
# (description and examples are redefined below for this demo)
from spaces_info import description, examples, initial_prompt_value

# some constants: the BLOOM Inference API endpoint and the Hugging Face API token
API_URL = os.getenv("API_URL", "https://api-inference.huggingface.co/models/bigscience/bloom")
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}

print(API_URL)


def query(payload):
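    """POST a generation payload to the BLOOM Inference API and return the parsed JSON response."""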
    print(payload)
    response = requests.post(API_URL, json=payload, headers=headers)
    print(response)
    return response.json()



def inference(input_sentence, max_length, sample_or_greedy, seed=42):
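    """Continue input_sentence with BLOOM, using sampling ("Sample") or greedy decoding, and return the prompt plus the generated text."""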
    if sample_or_greedy == "Sample":
        parameters = {
            "max_new_tokens": max_length,
            "top_p": 0.9,
            "do_sample": True,
            "seed": seed,
            "early_stopping": False,
            "length_penalty": 0.0,
            "eos_token_id": None,
        }
    else:
        parameters = {
            "max_new_tokens": max_length,
            "do_sample": False,
            "seed": seed,
            "early_stopping": False,
            "length_penalty": 0.0,
            "eos_token_id": None,
        }

    payload = {"inputs": input_sentence, "parameters": parameters,"options" : {"use_cache": False} }

    data = query(payload)

    if "error" in data:
        return (None, None, f"<span style='color:red'>ERROR: {data['error']} </span>")

    generation = data[0]["generated_text"].split(input_sentence, 1)[1]
    print(generation)
    return input_sentence + generation





def self_caption(image):
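  """Caption the image with the ydshieh/vit-gpt2-coco-en encoder-decoder model, then expand the caption into a short story with BLOOM."""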
  repo_name = "ydshieh/vit-gpt2-coco-en"
  test_image = image
  feature_extractor2 = ViTFeatureExtractor.from_pretrained(repo_name)
  tokenizer = AutoTokenizer.from_pretrained(repo_name)
  model2 = VisionEncoderDecoderModel.from_pretrained(repo_name)
  pixel_values = feature_extractor2(test_image, return_tensors="pt").pixel_values
  print("Pixel Values")
  print(pixel_values)
  # autoregressively generate a caption with beam search
  generated_ids = model2.generate(pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True)

  # decode the generated token ids into text
  preds = tokenizer.batch_decode(generated_ids.sequences, skip_special_tokens=True)
  preds = [pred.strip() for pred in preds]
  print("Predictions")
  print(preds)
  print("The preds type is : ",type(preds))
  pred_keys = ["Prediction"]
  pred_value = preds

  pred_dictionary = dict(zip(pred_keys, pred_value))
  print("Pred dictionary")
  print(pred_dictionary)
 
  preds = ' '.join(preds)
  #inference(input_sentence, max_length, sample_or_greedy, seed=42)
  story = inference(preds, 64, "Sample", 42) 

  return story


def classify_image(image):
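  """Classify the image with deepmind/vision-perceiver-conv and return a {label: score} dict for the Gradio Label output."""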
  # Perceiver IO conv model fine-tuned for image classification
  # (an alternative checkpoint would be google/vit-base-patch16-224 with the ViT classes)
  feature_extractor = PerceiverFeatureExtractor.from_pretrained("deepmind/vision-perceiver-conv")
  model = PerceiverForImageClassificationConvProcessing.from_pretrained("deepmind/vision-perceiver-conv")
  image_pipe = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor)
 
  
  results = image_pipe(image)
  
  print("RESULTS")
  print(results)
  # convert to format Gradio expects
  output = {}
  for prediction in results:
    predicted_label = prediction['label']
    score = prediction['score']
    output[predicted_label] = score
  print("OUTPUT")
  print(output)
  return output


image = gr.inputs.Image(type="pil")
label = gr.outputs.Label(num_top_classes=5)
examples = [ ["cats.jpg"], ["batter.jpg"],["drinkers.jpg"] ]
title = "Generate a Story from an Image using BLOOM"
description = "Demo for classifying images with Perceiver IO. To use it, simply upload an image and click 'submit', a story is autogenerated as well, story generated using Bigscience/BLOOM"
article = "<p style='text-align: center'></p>"

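# Run the two interfaces side by side with gradio.mix.Parallel:
# img_info1 classifies the image, img_info2 captions it and generates a story.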
img_info1 = gr.Interface(
    fn=classify_image,
    inputs=image,
    outputs=label,
)

img_info2 = gr.Interface(
    fn=self_caption,
    inputs=image,
    outputs=[gr.outputs.Textbox(label='Story')],
)

Parallel(img_info1,img_info2, inputs=image, title=title, description=description, examples=examples, enable_queue=True).launch(debug=True)