markuplm-large / handler.py
from transformers import MarkupLMProcessor, MarkupLMForQuestionAnswering
import torch
import logging
import json
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler()  # Defaults to stderr; pass sys.stdout explicitly if stdout is required
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
# Based on: https://github.com/NielsRogge/Transformers-Tutorials/blob/master/MarkupLM/Inference_with_MarkupLM_for_question_answering_on_web_pages.ipynb
# See also: https://www.naveedafzal.com/posts/scraping-websites-by-asking-questions-with-markuplm/
class EndpointHandler:
def __init__(self, path=""):
# Load model, tokenizer, and feature extractor
# logger.debug("Loading model from: " + path)
# WE ARE CURRENTLY NOT USING OUR REPO'S MODEL
self.processor = MarkupLMProcessor.from_pretrained("microsoft/markuplm-large-finetuned-websrc")
self.model = MarkupLMForQuestionAnswering.from_pretrained("microsoft/markuplm-large-finetuned-websrc")
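
    # Expected request payload (assumed shape, inferred from the .get() calls in __call__ below):
    # {"inputs": {"context": "<html>...</html>", "question": "..."}}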
    def __call__(self, data):
        # Use logger.debug() with a json dump for complex objects:
        logger.debug("Full input: %s", json.dumps(data, indent=2))
        # Extract inputs from data; default to an empty dict so the .get() calls below don't fail
        httpInputs = data.get("inputs", {})
        html = httpInputs.get("context", "")
        question = httpInputs.get("question", "")
        logger.debug("HTML: %s", json.dumps(html, indent=2))
        logger.debug("Question: %s", json.dumps(question, indent=2))
        # Encode the inputs (the processor parses the raw HTML into text nodes and xpaths)
        # truncation=True could be added here to guard against pages longer than the model's input limit
        encoding = self.processor(html, questions=question, return_tensors="pt")
        for k, v in encoding.items():
            logger.debug("%s: %s", k, v.shape)
        # Perform inference
        with torch.no_grad():
            outputs = self.model(**encoding)
        # Extract the answer span from the highest-scoring start/end positions
        answer_start_index = outputs.start_logits.argmax()
        answer_end_index = outputs.end_logits.argmax()
        predict_answer_tokens = encoding.input_ids[0, answer_start_index : answer_end_index + 1]
        answer = self.processor.decode(predict_answer_tokens, skip_special_tokens=True)
        # Get the score (mean of the raw start/end logits, not a normalized probability)
        start_score = outputs.start_logits[0, answer_start_index].item()
        end_score = outputs.end_logits[0, answer_end_index].item()
        score = (start_score + end_score) / 2
        logger.debug("Answer: %s", answer)
        logger.debug("Score: %s", score)
        return {"answer": answer, "score": score}