"""OCR quality assessment model: scores a text by the share of its tokens found in
per-language Bloom filters of known word forms, using language identification to
select the appropriate filter."""

import logging
import os
import unicodedata
from typing import Optional

import floret
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from pybloomfilter import BloomFilter
from transformers import PreTrainedModel, pipeline

from .configuration_ocrqa import ImpressoConfig

logger = logging.getLogger(__name__)

# Character classes that are mapped to whitespace during normalization.
QUOTES_PUNCT = "„•<>!\"#%&'’"
ASCII_PUNCT = "()*,./:;?"
BRACKETS_SPECIAL = "[]\\~_{}"
UNICODE_PUNCT = "\xa1\xab\xb7\xbb\xbf"
DASH_CARET = "—^`"
SPECIAL_SYMBOLS = "¦§£="
HYPHEN = "-"
# Digits are collapsed to "0".
DIGITS = "0123456789"

# Translation table: punctuation and symbols -> space, digits -> "0".
NORMALIZATION_TABLE = str.maketrans(
    {
        char: " "
        for char in (
            QUOTES_PUNCT
            + ASCII_PUNCT
            + BRACKETS_SPECIAL
            + UNICODE_PUNCT
            + DASH_CARET
            + SPECIAL_SYMBOLS
            + HYPHEN
        )
    }
    | {char: "0" for char in DIGITS}
)


def normalize_text(s: str, unicode_normalize: Optional[str] = "NFKC") -> str:
    """Normalize text: optionally apply Unicode normalization and lowercasing,
    then map punctuation to spaces and digits to '0'."""
    if unicode_normalize:
        s = unicodedata.normalize(unicode_normalize, s).lower()
    return s.translate(NORMALIZATION_TABLE)
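
# Illustrative example (added for clarity, not in the original module): after
# NFKC folding and the translation table above, punctuation turns into spaces
# and digits collapse to "0", e.g.
#   normalize_text("Hello, World! 42").split() -> ["hello", "world", "00"]

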
def filter_text(text: str, bloom_filter: BloomFilter) -> dict:
    """Tokenize normalized text and partition the tokens into known/unknown sets."""
    knowns = set()
    unknowns = set()

    normalized_text = normalize_text(text)
    tokens = normalized_text.split()

    for token in tokens:
        if token in bloom_filter:
            knowns.add(token)
        else:
            unknowns.add(token)

    return {"known": knowns, "unknown": unknowns}


class QAAssessmentModel(PreTrainedModel):
    config_class = ImpressoConfig

    def get_bloomfilter(self, model_id: str, filename: str) -> BloomFilter:
        """Download a serialized Bloom filter from the Hub and memory-map it."""
        return BloomFilter.open(hf_hub_download(repo_id=model_id, filename=filename))

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Dummy parameter so the module owns at least one tensor (e.g. for .device).
        self.dummy_param = nn.Parameter(torch.zeros(1))

        # One Bloom filter of known word forms per supported language.
        bin_filenames = {
            "en": self.config.config.filename["en"],
            "de": self.config.config.filename["de"],
            "fr": self.config.config.filename["fr"],
            "other": self.config.config.filename["other"],
        }

        self.ocrqa_assessors = {}
        for lang, model_filename in bin_filenames.items():
            logger.info(f"Loading model for {lang}: {model_filename}")
            self.ocrqa_assessors[lang] = self.get_bloomfilter(
                model_id=self.config.config._name_or_path, filename=model_filename
            )
        logger.debug(f"Loaded OCR-QA assessors: {self.ocrqa_assessors}")

        # Language identification pipeline used to route each text to the right filter.
        self.lang_pipeline = pipeline(
            "langident",
            model="impresso-project/language-identifier",
            trust_remote_code=True,
            device="cpu",
        )

    def forward(self, input_ids, **kwargs):
        # Accept either a single string or a list of strings.
        if isinstance(input_ids, str):
            texts = [input_ids]
        elif isinstance(input_ids, list) and all(isinstance(t, str) for t in input_ids):
            texts = input_ids
        else:
            raise ValueError(f"Unexpected input type: {type(input_ids)}")

        predictions = []
        for text in texts:
            # Identify the language of the current text (not the whole batch).
            langs = self.lang_pipeline(text)
            logger.debug(f"Language identification output: {langs}")

            if len(langs) > 0:
                lang = langs["language"]
                logger.info(f"Detected language: {lang}")
            else:
                lang = "other"
                logger.warning("Language detection failed, using 'other' as default.")

            if lang not in self.ocrqa_assessors:
                logger.warning(
                    f"Language '{lang}' has no dedicated Bloom filter, using 'other' as default."
                )
                lang = "other"

            result = filter_text(text, self.ocrqa_assessors[lang])
            known_count = len(result["known"])
            unknown_count = len(result["unknown"])

            # Share of known token types; the epsilon avoids division by zero on empty input.
            score = known_count / (known_count + unknown_count + 0.000001)
            predictions.append(score)

        return predictions

    @property
    def device(self):
        return next(self.parameters()).device

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Build the config directly from the keyword arguments; no weights are
        # loaded here, since the model's state lives in the Bloom-filter files
        # fetched in __init__.
        config = ImpressoConfig(**kwargs)
        model = cls(config)
        return model
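

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not part of the original file).
# `filter_text` only needs membership tests, so a plain set stands in for a
# Bloom filter below.  The commented-out lines show how the full model could
# be loaded; "impresso-project/ocrqa" is a placeholder repository id, and the
# module must be run as a package module (python -m ...) for the relative
# import of ImpressoConfig to resolve.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    known_words = {"the", "quick", "brown", "fox"}  # stand-in for a BloomFilter
    report = filter_text("The quick brôwn fox, 1898!", known_words)
    print("known:", sorted(report["known"]))
    print("unknown:", sorted(report["unknown"]))

    # model = QAAssessmentModel.from_pretrained("impresso-project/ocrqa")  # placeholder id
    # print(model(["The quick brown fox jumps over the lazy dog."]))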