Commit
·
b828aa0
1
Parent(s):
197e8c2
repaired import
Browse files- modeling_ocrqa.py +5 -9
modeling_ocrqa.py
CHANGED
@@ -45,14 +45,6 @@ def normalize_text(s: str, unicode_normalize: Optional[str] = "NFKC") -> str:
|
|
45 |
s = unicodedata.normalize(unicode_normalize, s).lower()
|
46 |
return s.translate(NORMALIZATION_TABLE)
|
47 |
|
48 |
-
MODEL_NAME = "impresso-project/impresso-langident"
|
49 |
-
|
50 |
-
lang_pipeline = pipeline(
|
51 |
-
"langident",
|
52 |
-
model=MODEL_NAME,
|
53 |
-
trust_remote_code=True,
|
54 |
-
device="cpu",
|
55 |
-
)
|
56 |
|
57 |
def filter_text(text: str, bloom_filter: BloomFilter):
|
58 |
|
@@ -100,6 +92,10 @@ class QAAssessmentModel(PreTrainedModel):
|
|
100 |
# print(f"{bin_filename} not found locally, downloading from Hugging Face hub...")
|
101 |
self.ocrqa_assessors[lang] = hf_hub_download(repo_id=self.config.config._name_or_path,
|
102 |
filename=model_filename)
|
|
|
|
|
|
|
|
|
103 |
|
104 |
def forward(self, input_ids, **kwargs):
|
105 |
if isinstance(input_ids, str):
|
@@ -112,7 +108,7 @@ class QAAssessmentModel(PreTrainedModel):
|
|
112 |
|
113 |
predictions, probabilities = [], []
|
114 |
for text in texts:
|
115 |
-
langs = lang_pipeline(input_ids)
|
116 |
# [{'label': 'fr', 'confidence': 99.87}]
|
117 |
if len(langs) > 0:
|
118 |
lang = langs[0]['label']
|
|
|
45 |
s = unicodedata.normalize(unicode_normalize, s).lower()
|
46 |
return s.translate(NORMALIZATION_TABLE)
|
47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
def filter_text(text: str, bloom_filter: BloomFilter):
|
50 |
|
|
|
92 |
# print(f"{bin_filename} not found locally, downloading from Hugging Face hub...")
|
93 |
self.ocrqa_assessors[lang] = hf_hub_download(repo_id=self.config.config._name_or_path,
|
94 |
filename=model_filename)
|
95 |
+
self.lang_pipeline = pipeline("langident",
|
96 |
+
model="impresso-project/impresso-langident",
|
97 |
+
trust_remote_code=True,
|
98 |
+
device="cpu")
|
99 |
|
100 |
def forward(self, input_ids, **kwargs):
|
101 |
if isinstance(input_ids, str):
|
|
|
108 |
|
109 |
predictions, probabilities = [], []
|
110 |
for text in texts:
|
111 |
+
langs = self.lang_pipeline(input_ids)
|
112 |
# [{'label': 'fr', 'confidence': 99.87}]
|
113 |
if len(langs) > 0:
|
114 |
lang = langs[0]['label']
|