emanuelaboros committed on
Commit
b828aa0
·
1 Parent(s): 197e8c2

repaired import

Browse files
Files changed (1) hide show
  1. modeling_ocrqa.py +5 -9
modeling_ocrqa.py CHANGED
@@ -45,14 +45,6 @@ def normalize_text(s: str, unicode_normalize: Optional[str] = "NFKC") -> str:
45
  s = unicodedata.normalize(unicode_normalize, s).lower()
46
  return s.translate(NORMALIZATION_TABLE)
47
 
48
- MODEL_NAME = "impresso-project/impresso-langident"
49
-
50
- lang_pipeline = pipeline(
51
- "langident",
52
- model=MODEL_NAME,
53
- trust_remote_code=True,
54
- device="cpu",
55
- )
56
 
57
  def filter_text(text: str, bloom_filter: BloomFilter):
58
 
@@ -100,6 +92,10 @@ class QAAssessmentModel(PreTrainedModel):
100
  # print(f"{bin_filename} not found locally, downloading from Hugging Face hub...")
101
  self.ocrqa_assessors[lang] = hf_hub_download(repo_id=self.config.config._name_or_path,
102
  filename=model_filename)
 
 
 
 
103
 
104
  def forward(self, input_ids, **kwargs):
105
  if isinstance(input_ids, str):
@@ -112,7 +108,7 @@ class QAAssessmentModel(PreTrainedModel):
112
 
113
  predictions, probabilities = [], []
114
  for text in texts:
115
- langs = lang_pipeline(input_ids)
116
  # [{'label': 'fr', 'confidence': 99.87}]
117
  if len(langs) > 0:
118
  lang = langs[0]['label']
 
45
  s = unicodedata.normalize(unicode_normalize, s).lower()
46
  return s.translate(NORMALIZATION_TABLE)
47
 
 
 
 
 
 
 
 
 
48
 
49
  def filter_text(text: str, bloom_filter: BloomFilter):
50
 
 
92
  # print(f"{bin_filename} not found locally, downloading from Hugging Face hub...")
93
  self.ocrqa_assessors[lang] = hf_hub_download(repo_id=self.config.config._name_or_path,
94
  filename=model_filename)
95
+ self.lang_pipeline = pipeline("langident",
96
+ model="impresso-project/impresso-langident",
97
+ trust_remote_code=True,
98
+ device="cpu")
99
 
100
  def forward(self, input_ids, **kwargs):
101
  if isinstance(input_ids, str):
 
108
 
109
  predictions, probabilities = [], []
110
  for text in texts:
111
+ langs = self.lang_pipeline(input_ids)
112
  # [{'label': 'fr', 'confidence': 99.87}]
113
  if len(langs) > 0:
114
  lang = langs[0]['label']