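# NOTE: the lines below are page-header residue captured with this file (not code):
#   author avatar: emanuelaboros · commit 70b421f "moved to language-identifier" · size 5.26 kB
#   viewer links: raw / history / blame / contribute / delete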
import torch
import torch.nn as nn
from transformers import PreTrainedModel
import logging
import floret
import os
from huggingface_hub import hf_hub_download
from .configuration_ocrqa import ImpressoConfig
logger = logging.getLogger(__name__)
from pybloomfilter import BloomFilter
from transformers import pipeline
import unicodedata
from typing import Optional
# Character classes that normalization replaces with a single space.
QUOTES_PUNCT = "„•<>!\"#%&'’"
ASCII_PUNCT = "()*,./:;?"
BRACKETS_SPECIAL = "[]\\~_{}"
UNICODE_PUNCT = "\xa1\xab\xb7\xbb\xbf"
DASH_CARET = "—^`"
SPECIAL_SYMBOLS = "¦§£="
HYPHEN = "-"

# Every decimal digit is collapsed onto "0".
DIGITS = "0123456789"

# Translation table for ``str.translate``: punctuation/symbols -> " ",
# digits -> "0". All other characters are left untouched.
NORMALIZATION_TABLE = str.maketrans(
    {
        **dict.fromkeys(
            "".join(
                (
                    QUOTES_PUNCT,
                    ASCII_PUNCT,
                    BRACKETS_SPECIAL,
                    UNICODE_PUNCT,
                    DASH_CARET,
                    SPECIAL_SYMBOLS,
                    HYPHEN,
                )
            ),
            " ",
        ),
        **dict.fromkeys(DIGITS, "0"),
    }
)
def normalize_text(s: str, unicode_normalize: Optional[str] = "NFKC") -> str:
    """Map punctuation to spaces and digits to '0' via ``NORMALIZATION_TABLE``.

    Args:
        s: Text to normalize.
        unicode_normalize: Unicode normalization form passed to
            ``unicodedata.normalize``. When it is falsy (``None`` or ``""``)
            the text is neither Unicode-normalized nor lowercased.

    Returns:
        The translated text; when ``unicode_normalize`` is set the text is
        Unicode-normalized and lowercased before translation.
    """
    if not unicode_normalize:
        # No normalization form requested: translate the raw text as-is.
        return s.translate(NORMALIZATION_TABLE)

    folded = unicodedata.normalize(unicode_normalize, s).lower()
    return folded.translate(NORMALIZATION_TABLE)
def filter_text(text: str, bloom_filter: BloomFilter):
    """Split the distinct tokens of ``text`` by membership in ``bloom_filter``.

    The text is normalized with ``normalize_text`` and split on whitespace.

    Args:
        text: Raw input text.
        bloom_filter: Container supporting ``token in bloom_filter``.

    Returns:
        A dict with two sets of distinct tokens: ``"known"`` (found in the
        filter) and ``"unknown"`` (not found).
    """
    # Deduplicate first; the result only ever reports distinct tokens.
    distinct_tokens = set(normalize_text(text).split())

    # Tokens the bloom filter recognises; everything else is unknown.
    recognised = {tok for tok in distinct_tokens if tok in bloom_filter}
    return {"known": recognised, "unknown": distinct_tokens - recognised}
class QAAssessmentModel(PreTrainedModel):
    """OCR quality scorer backed by one Bloom filter per language.

    For each input text the language is detected, the matching Bloom filter
    is selected (falling back to ``"other"``), and the score is the share of
    distinct normalized tokens that the filter contains.
    """

    config_class = ImpressoConfig

    def get_bloomfilter(self, model_id: str, filename: str):
        """Fetch ``filename`` from the Hub repo ``model_id`` and open it as a Bloom filter."""
        return BloomFilter.open(hf_hub_download(repo_id=model_id, filename=filename))

    def __init__(self, config):
        """Load the per-language Bloom filters and the language-identification pipeline.

        Args:
            config: Configuration object whose ``config`` attribute exposes
                ``filename`` (a mapping with the keys ``"en"``, ``"de"``,
                ``"fr"`` and ``"other"``) and ``_name_or_path`` (the Hub repo
                id to download the filter files from).

        Raises:
            KeyError: If one of the four language keys is missing from
                ``filename``.
        """
        super().__init__(config)
        self.config = config

        # Dummy for device checking
        self.dummy_param = nn.Parameter(torch.zeros(1))

        # NOTE(review): the settings live one level down, on ``config.config``;
        # presumably populated from the kwargs given to ``from_pretrained`` —
        # confirm against ImpressoConfig.
        settings = self.config.config
        # Resolve all filenames up front so a missing key fails before any download.
        bin_filenames = {
            lang: settings.filename[lang] for lang in ("en", "de", "fr", "other")
        }

        self.ocrqa_assessors = {}
        for lang, model_filename in bin_filenames.items():
            logger.info("Loading model for %s: %s", lang, model_filename)
            self.ocrqa_assessors[lang] = self.get_bloomfilter(
                model_id=settings._name_or_path,
                filename=model_filename,
            )
        logger.debug("Loaded OCR QA assessors: %s", self.ocrqa_assessors)

        self.lang_pipeline = pipeline(
            "langident",
            model="impresso-project/language-identifier",
            trust_remote_code=True,
            device="cpu",
        )

    @staticmethod
    def _extract_language(detection):
        """Return the language code from a language-identifier result, or ``None``.

        Accepts either a single mapping or a sequence of mappings (the first
        entry is used), and reads the ``"language"`` key with ``"label"`` as a
        fallback.
        NOTE(review): both result shapes appear in this file's history (the
        old comment showed ``[{'label': 'fr', ...}]`` while the code indexed
        ``['language']``) — confirm against the pipeline's current output.
        """
        if isinstance(detection, (list, tuple)):
            detection = detection[0] if detection else None
        if isinstance(detection, dict):
            return detection.get("language") or detection.get("label")
        return None

    def forward(self, input_ids, **kwargs):
        """Score the OCR quality of one or several texts.

        Args:
            input_ids: Despite its name, a raw text (``str``) or a list of
                raw texts.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            A list with one float per input text: the number of distinct
            known tokens divided by the number of distinct tokens (plus a
            tiny epsilon, so an empty text scores ``0.0``).

        Raises:
            ValueError: If ``input_ids`` is neither a string nor a list of
                strings.
        """
        if isinstance(input_ids, str):
            # A single string is treated as a batch of one.
            texts = [input_ids]
        elif isinstance(input_ids, list) and all(isinstance(t, str) for t in input_ids):
            texts = input_ids
        else:
            raise ValueError(f"Unexpected input type: {type(input_ids)}")

        predictions = []
        for text in texts:
            # Detect the language of THIS text (not of the whole batch).
            detection = self.lang_pipeline(text)
            logger.debug("Detected languages: %s", detection)

            lang = self._extract_language(detection)
            if lang:
                logger.info("Detected language: %s", lang)
            else:
                lang = "other"
                logger.warning("Language detection failed, using 'other' as default.")

            if lang not in self.ocrqa_assessors:
                logger.warning(
                    "Language '%s' not found in bin_filename, using 'other' as default.",
                    lang,
                )
                lang = "other"

            # Process the text using the selected filter
            result = filter_text(text, self.ocrqa_assessors[lang])
            known_count = len(result["known"])
            unknown_count = len(result["unknown"])

            # The epsilon keeps the division defined when the text has no tokens.
            score = known_count / (known_count + unknown_count + 0.000001)
            predictions.append(score)
        return predictions

    @property
    def device(self):
        """Device of the model's first parameter (the dummy parameter set in ``__init__``)."""
        return next(self.parameters()).device

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Build the model from keyword arguments only, without loading weights.

        Positional arguments are ignored. An ``ImpressoConfig`` is created
        from ``kwargs`` and passed to the constructor, which loads the Bloom
        filters and the language-identification pipeline.
        """
        config = ImpressoConfig(**kwargs)
        return cls(config)