import re
import gradio as gr

import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import requests
from io import BytesIO
import json
import os


# Load the Donut processor + model fine-tuned for invoice header extraction.
# NOTE: both calls download weights from the HF hub at import time, so app
# startup blocks until they are cached locally.
processor = DonutProcessor.from_pretrained("to-be/donut-base-finetuned-invoices")
model = VisionEncoderDecoderModel.from_pretrained("to-be/donut-base-finetuned-invoices")

# Run inference on GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


def update_status(state):
    """Toggle the two animated "busy" indicator images for the demo.

    Parameters
    ----------
    state : str
        Current demo state held in a ``gr.State``; one of
        ``'start_or_clear'``, ``'processing'`` or ``'finished_processing'``.

    Returns
    -------
    tuple
        A pair of ``gr.update`` objects, one for each indicator image.

    NOTE(review): the original reassigned the *local* ``state`` variable in
    each branch; that has no effect, because a ``gr.State`` only changes when
    the new value is returned as an output (and this function's callers only
    wire the two image components). The dead assignments were removed.
    An unrecognized state previously fell through and returned ``None``
    (which Gradio would reject); it now hides the indicators instead.
    """
    if state in ("start_or_clear", "finished_processing"):
        # A new run is starting: show the spinner gifs.
        return (gr.update(value="snowangel.gif", visible=True),
                gr.update(value="snowangel.gif", visible=True))
    # 'processing' (or any unexpected value): hide the spinner gifs.
    return (gr.update(value="", visible=False),
            gr.update(value="", visible=False))

def process_document(image,sendimg):
    """Run Donut inference on an uploaded invoice and return extracted fields.

    Parameters
    ----------
    image : numpy.ndarray
        The uploaded document as delivered by ``gr.Image`` — assumed to be an
        RGB array (TODO confirm against the Image component's type setting).
    sendimg : bool
        Whether the user consented to usage-data collection; when falsy a
        placeholder image is sent to the notification channel instead.

    Returns
    -------
    tuple
        ``(fields, image)`` — the parsed key/value dict for the ``gr.JSON``
        output, and the unchanged input image for the display component.
    """
    # Pick what gets sent to the notification channel: the real document only
    # with consent, otherwise a local placeholder. (Original used
    # ``== True``/``== False`` branches, which left ``im1`` unbound for any
    # non-bool input; a plain truthiness test closes that hole.)
    if sendimg:
        im1 = Image.fromarray(image)
    else:
        im1 = Image.open('./no_image.jpg')

    # Tracking/notification are best effort and must never break inference.
    _track_demo_count()
    _notify_telegram(im1)

    # prepare encoder inputs
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # prepare decoder inputs: Donut is prompted with its task start token
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

    # generate answer (greedy decoding, bounded by the decoder's max length)
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # postprocess: drop special tokens, then the leading task token
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token

    # NOTE(review): the original called ``img2.update(visible=False)`` here;
    # inside a Blocks event handler that return value is simply discarded, so
    # the call was a no-op and has been removed.
    return processor.token2json(sequence), image


def _track_demo_count():
    """Ping the visitor-badge API to count served demos (best effort).

    The original call was unguarded, so any network hiccup aborted the whole
    inference request; failures are now logged and ignored.
    """
    try:
        requests.get('https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fto-be%2Finvoice_document_headers_extraction_with_donut%2Fdemo&label=demos%20served&labelColor=%23edd239&countColor=%23d9e3f0', timeout=10)
    except requests.RequestException:
        print("visitor badge api error")


def _notify_telegram(im1):
    """Send the (possibly redacted) document to a Telegram channel (best effort).

    Parameters
    ----------
    im1 : PIL.Image.Image
        Image to upload — the user's document or the placeholder.
    """
    TOKEN = os.getenv('TELEGRAM_BOT_TOKEN')
    CHAT_ID = os.getenv('TELEGRAM_CHANNEL_ID')
    # NOTE(review): hard-coded Telegram datacenter IP over HTTPS — certificate
    # hostname verification against a bare IP may fail; confirm this endpoint
    # still works, or use api.telegram.org.
    url = f'https://149.154.167.220/bot{TOKEN}/sendPhoto?chat_id={CHAT_ID}'
    bio = BytesIO()
    bio.name = 'image.jpeg'
    im1.save(bio, 'JPEG')
    bio.seek(0)
    media = {"type": "photo", "media": "attach://photo", "caption": "New doc is being tried out:"}
    data = {"media": json.dumps(media)}
    try:
        # Original used a bare ``except:``, hiding every error including
        # programming bugs; only network/HTTP errors are swallowed now.
        requests.post(url, files={'photo': bio}, data=data, timeout=10)
    except requests.RequestException:
        print("telegram api error")

# ---------------------------------------------------------------------------
# Static HTML/CSS fragments rendered by the Gradio UI below.
# ---------------------------------------------------------------------------

# Page header: demo title flanked by two animated gifs.
title = '<table align="center" border="0" cellpadding="1" cellspacing="1" ><tbody><tr><td style="text-align:center"><img alt="" src="https://huggingface.co/spaces/to-be/invoice_document_headers_extraction_with_donut/resolve/main/circling_small.gif" style="float:right; height:50px; width:50px" /></td><td style="text-align:center"><h1>Demo: invoice header extraction with Donut</h1></td><td style="text-align:center"><img alt="" src="https://huggingface.co/spaces/to-be/invoice_document_headers_extraction_with_donut/resolve/main/circling2_small.gif" style="float:left; height:50px; width:50px" /></td></tr></tbody></table>'
# Changelog with links to the accompanying Medium / Towards Data Science articles.
paragraph0 = '<p><strong>(update 29/03/2023: for more info, you can read <a href="https://toon-beerten.medium.com/hands-on-document-data-extraction-with-transformer-7130df3b6132">my article on medium</a>)<br />(update 28/04/2023: want to finetune with your own data? Read&nbsp;<a href="https://towardsdatascience.com/ocr-free-document-data-extraction-with-transformers-1-2-b5a826bc2ac3">this article</a>)</strong></p>'
# Short explanation of the Donut model and links to the paper / base model.
paragraph1 = '<p>Basic idea of the base 🍩 model is to give it an image as input and extract indexes as text. No bounding boxes or confidences are generated.<br /> I finetuned it on invoices. For more info, see the <a href="https://arxiv.org/abs/2111.15664">original paper</a>&nbsp;and the 🤗&nbsp;<a href="https://huggingface.co/naver-clova-ix/donut-base">model</a>.</p>'
# Training details and the list of extracted indexes.
paragraph2 = '<p><strong>Training</strong>:<br />The model was trained with a few thousand of annotated invoices and non-invoices (for those the doctype will be &#39;Other&#39;). They span across different countries and languages. They are always one page only. The dataset is proprietary unfortunately.&nbsp;Model is set to input resolution of 1280x1920 pixels. So any sample you want to try with higher dpi than 150 has no added value.<br />It was trained for about 4 hours on a&nbsp;NVIDIA RTX A4000 for 20k steps with a val_metric of&nbsp;0.03413819904382196 at the end.<br />The <u>following indexes</u> were included in the train set:</p><ul><li><span style="font-family:Calibri"><span style="color:black">DocType</span></span></li><li><span style="font-family:Calibri"><span style="color:black">Currency</span></span></li><li><span style="font-family:Calibri"><span style="color:black">DocumentDate</span></span></li><li><span style="font-family:Calibri"><span style="color:black">GrossAmount</span></span></li><li><span style="font-family:Calibri"><span style="color:black">InvoiceNumber</span></span></li><li><span style="font-family:Calibri"><span style="color:black">NetAmount</span></span></li><li><span style="font-family:Calibri"><span style="color:black">TaxAmount</span></span></li><li><span style="font-family:Calibri"><span style="color:black">OrderNumber</span></span></li><li><span style="font-family:Calibri"><span style="color:black">CreditorCountry</span></span></li></ul>'
# Benchmark results on the validation set.
paragraph3 = '<p><strong>Benchmark observations:</strong><br />From all documents in the validation set,&nbsp; 60% of them had all indexes captured correctly.</p><p>Here are the results per index:</p><p style="margin-left:40px"><img alt="" src="https://s3.amazonaws.com/moonup/production/uploads/1677749023966-6335a49ceb6132ca653239a0.png" style="height:70%; width:70%" /></p><p>Some other observations:<br />- when trying with a non invoice document, it&#39;s quite reliably identified as Doctype: &#39;Other&#39;<br />- validation set contained mostly same layout invoices as the train set. If it was validated against completely differently sourced invoices, the results would be different<br />- Document date is able to be recognized across different notations, however, it&#39;s often wrong because the data set was not diverse (as in time span of dates) enough</p>'
# Earlier gr.Interface-based wiring, kept for reference (superseded by the Blocks layout below).
#demo = gr.Interface(fn=process_document,inputs=gr_image,outputs="json",title="Demo: Donut 🍩 for invoice header retrieval", description=description,
#    article=article,enable_queue=True, examples=[["example.jpg"], ["example_2.jpg"], ["example_3.jpg"]], cache_examples=False)
# Usage instructions for visitors.
paragraph4 = '<p><strong>Try it out:</strong><br />To use it, simply upload your image and click &#39;submit&#39;, or click one of the examples to load them.<br /><em>(because this is running on the free cpu tier, it will take about 40 secs before you see a result. On a GPU it takes less than 2 seconds)</em></p><p>&nbsp;</p><p>Have fun&nbsp;😎</p><p>Toon Beerten</p>'
# Footnote explaining the data-collection checkbox.
smallprint = '<p>✤&nbsp;<span style="font-size:11px">To get an idea of the usage, you can opt to let me get personally notified via Telegram with the image uploaded. All data will be automatically deleted after 48 hours</span></p>'
# Custom CSS: make the output image (#inp) fill its column.
css = "#inp {height: auto !important; width: 100% !important;}"
# Visitor-counter badge shown at the bottom of the page.
visit_badge = '<a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fto-be%2Finvoice_document_headers_extraction_with_donut"><img src="https://api.visitorbadge.io/api/combined?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fto-be%2Finvoice_document_headers_extraction_with_donut&labelColor=%23edd239&countColor=%23d9e3f0&style=flat" /></a>'

# Earlier CSS experiments, kept for reference:
# css = "@media screen and (max-width: 600px) { .output_image, .input_image {height:20rem !important; width: 100% !important;} }"
# css = ".output_image, .input_image {height: 600px !important}"

#css = ".image-preview {height: auto !important;}"
#css='div {margin-left: auto; margin-right: auto; width: 100%;background-image: url("background.gif"); repeat 0 0;}')

# Build the Blocks UI. NOTE(review): the ``.style(...)`` calls are the
# Gradio 3.x API and were removed in Gradio 4 — confirm the pinned gradio
# version before upgrading.
with gr.Blocks(css=css) as demo:
    # Demo state for the (currently disconnected) spinner logic in
    # update_status; it is created but never updated by any wired event.
    state = gr.State(value='start_or_clear')

    # Static header and explanatory text.
    gr.HTML(title)
    gr.HTML(paragraph0)
    gr.HTML(paragraph1)
    gr.HTML(paragraph2)
    gr.HTML(paragraph3)
    gr.HTML(paragraph4)

    # Row 1: upload widget next to the clickable example documents.
    with gr.Row().style():
        with gr.Column(scale=1):
            inp = gr.Image(label='Upload invoice here:')   #.style(height=400)          
        with gr.Column(scale=2):
             gr.Examples([["example.jpg"], ["example_2.jpg"], ["example_3.jpg"]], inputs=[inp],label='Or use one of these examples:')
    # Row 2: hidden busy gif, the Extract button, and the consent checkbox.
    with gr.Row().style(equal_height=True,height=200,rounded=False):       
        with gr.Column(scale=1):
            img2 = gr.Image("drinking.gif",label=' ',visible=False).style(rounded=True)
        with gr.Column(scale=2):
            btn = gr.Button(" ↓   Extract   ↓ ")
        with gr.Column(scale=2):
            #img3 = gr.Image("snowangel.gif",label=' ',visible=False).style(rounded=True)
            sendimg = gr.Checkbox(value=True, label="Allow usage data collection for at most 48 hours ✤")
    # Row 3: echoed document image and the extracted JSON side by side.
    with gr.Row().style():
        with gr.Column(scale=2):
            imgout = gr.Image(label='Uploaded document:',elem_id="inp")
        with gr.Column(scale=1):
            jsonout = gr.JSON(label='Extracted information:')
    # Spinner wiring for update_status, currently disabled.
    #imgout.clear(fn=update_status,inputs=state,outputs=[img2,img3])    
    #imgout.change(fn=update_status,inputs=state,outputs=[img2,img3])        
    # The only live event: run inference on button click.
    btn.click(fn=process_document, inputs=[inp,sendimg], outputs=[jsonout,imgout])

    gr.HTML(smallprint)
    gr.HTML(visit_badge)
    

# Start the Gradio server (blocking call).
demo.launch()