from transformers import MarkupLMProcessor, MarkupLMForQuestionAnswering
import torch

import logging
import json

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler()  # Logs to stderr by default
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)

# Based on:
# https://github.com/NielsRogge/Transformers-Tutorials/blob/master/MarkupLM/Inference_with_MarkupLM_for_question_answering_on_web_pages.ipynb
# https://www.naveedafzal.com/posts/scraping-websites-by-asking-questions-with-markuplm/
class EndpointHandler:
    def __init__(self, path=""):
        # Load the processor (feature extractor + tokenizer) and the QA model
        # logger.debug("Loading model from: " + path)

        # WE ARE CURRENTLY NOT USING OUR REPO'S MODEL
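        # (To load this repo's own weights instead, passing `path` to from_pretrained()
        # below would presumably be enough; that switch has not been verified here.)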
        self.processor = MarkupLMProcessor.from_pretrained("microsoft/markuplm-large-finetuned-websrc")
        self.model = MarkupLMForQuestionAnswering.from_pretrained("microsoft/markuplm-large-finetuned-websrc")

    def __call__(self, data):
        # Use logger.debug() with a json dump for complex objects:
        logger.debug("Full input: %s", json.dumps(data, indent=2))
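
        # The request payload is assumed to look roughly like the following
        # (field names inferred from the .get() calls below, not from any endpoint spec):
        # {
        #     "inputs": {
        #         "context": "<html>...</html>",
        #         "question": "What is the price?"
        #     }
        # }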

        # Extract inputs from data; default "inputs" to an empty dict so the .get() calls below don't fail
        http_inputs = data.get("inputs", {})
        html = http_inputs.get("context", "")
        question = http_inputs.get("question", "")
        logger.debug("HTML: %s", json.dumps(html, indent=2))
        logger.debug("Question: %s", json.dumps(question, indent=2))

        # Encode the inputs
        encoding = self.processor(html, questions=question, return_tensors="pt")

        # Log the shape of each encoded tensor for debugging
        for k, v in encoding.items():
            logger.debug("%s: %s", k, v.shape)

        # Perform inference
        with torch.no_grad():
            outputs = self.model(**encoding)

        # Extract the answer
        answer_start_index = outputs.start_logits.argmax()
        answer_end_index = outputs.end_logits.argmax()

        predict_answer_tokens = encoding.input_ids[0, answer_start_index : answer_end_index + 1]
        answer = self.processor.decode(predict_answer_tokens, skip_special_tokens=True)

        # Get the score
        start_score = outputs.start_logits[0, answer_start_index].item()
        end_score = outputs.end_logits[0, answer_end_index].item()
        score = (start_score + end_score) / 2

        logger.debug("Answer: %s", answer)
        logger.debug("Score: %s", score)

        return {"answer": answer, "score": score}
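

# Minimal local smoke test (illustrative only): in production, Inference Endpoints
# instantiates EndpointHandler itself, and the sample HTML/question below are made up.
if __name__ == "__main__":
    endpoint_handler = EndpointHandler()
    sample = {
        "inputs": {
            "context": "<html><body><h1>Acme Widget</h1><p>Price: $19.99</p></body></html>",
            "question": "What is the price?",
        }
    }
    print(endpoint_handler(sample))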