codic committed on
Commit bcfd2c7 · verified · 1 Parent(s): 5b1ad96

Update app.py

Files changed (1):
  app.py (+24 -48)
app.py CHANGED
@@ -6,12 +6,11 @@ import os
 import logging
 from pathlib import Path
 from gliner import GLiNER
-from io import BytesIO
 import cv2
 import re
 from PIL import Image
 import traceback
-import tempfile
+import io  # For in-memory file handling
 from difflib import SequenceMatcher
 
 # Configure logging
@@ -25,16 +24,12 @@ os.environ['TRANSFORMERS_CACHE'] = '/tmp/.gliner_models/cache'
 def initialize_models():
     """Initialize models with error handling and retries"""
     try:
-        # Initialize EasyOCR
         logger.info("Initializing EasyOCR...")
         reader = easyocr.Reader(['en', 'ar'],
                                 download_enabled=True,
                                 model_storage_directory='/tmp/.easyocr_models')
-
-        # Initialize GLiNER
         logger.info("Initializing GLiNER...")
         model_path = Path(os.environ['GLINER_HOME']) / 'gliner_large-v2.1'
-
         if not model_path.exists():
             logger.info("Downloading GLiNER model...")
             model_path.parent.mkdir(parents=True, exist_ok=True)
@@ -42,10 +37,8 @@ def initialize_models():
             model.save_pretrained(str(model_path))
         else:
             model = GLiNER.from_pretrained(str(model_path))
-
         logger.info("Models initialized successfully")
         return reader, model
-
     except Exception as e:
         logger.error(f"Model initialization failed: {str(e)}")
         raise
@@ -59,29 +52,31 @@ except Exception as e:
 def clean_extracted_text(text):
     """Clean the extracted text with proper error handling"""
     try:
-        # Preserve Arabic and basic Latin characters along with digits and common punctuation
         cleaned = re.sub(
             r'[^\u0600-\u06FF\u0750-\u077F\u08A0-\u08FFA-Za-z0-9\s@.,-]',
             '',
             text
         )
-        # Normalize whitespace
         return re.sub(r'\s+', ' ', cleaned).strip()
     except Exception as e:
         logger.error(f"Text cleaning failed: {traceback.format_exc()}")
-        return text  # Return raw text as fallback
+        return text
 
-def preprocess_image(image):
-    """Image preprocessing with validation"""
+def preprocess_image(image, max_dim=1024):
+    """Image preprocessing with validation and optional resizing"""
     try:
         if not isinstance(image, np.ndarray):
             image = np.array(image)
-
-        if len(image.shape) == 2:  # Already grayscale
+        # Optional: Resize if the image is too large (keeping aspect ratio)
+        h, w = image.shape[:2]
+        if max(h, w) > max_dim:
+            scaling = max_dim / float(max(h, w))
+            image = cv2.resize(image, (int(w * scaling), int(h * scaling)))
+        # Convert to grayscale if needed
+        if len(image.shape) == 2:
             gray = image
         else:
             gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
-
         denoised = cv2.medianBlur(gray, 3)
         _, thresh = cv2.threshold(denoised, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
         return thresh
@@ -103,47 +98,35 @@ def clean_and_deduplicate(entities):
     for label, values in entities.items():
         unique = []
         for val in values:
-            # Validate and clean specific fields
             if label.lower() == "email":
                 match = re.search(r'[\w\.-]+@[\w\.-]+', val)
                 val = match.group(0) if match else val
             elif label.lower() == "phone":
                 match = re.search(r'\+?\d[\d\s\-]{7,}\d', val)
                 val = match.group(0) if match else val
-
-            # Avoid adding duplicates or near-duplicates
             if not any(similar(val, exist) for exist in unique):
                 unique.append(val)
         cleaned_results[label] = unique
     return cleaned_results
 
 def process_single_image(image, threshold=0.3, nested_ner=True, progress=gr.Progress()):
-    """Process single image with detailed error handling and improved entity cleanup"""
+    """Process single image with detailed error handling, optimized I/O, and entity cleanup"""
     try:
-        # Validate input
         if image is None:
             raise ValueError("No image provided")
-
         progress(0.1, "Validating input...")
         if not isinstance(image, (Image.Image, np.ndarray)):
             raise TypeError(f"Invalid image type: {type(image)}")
-
-        # Preprocessing
         progress(0.2, "Preprocessing image...")
         preprocessed = preprocess_image(image)
-
-        # OCR
         progress(0.4, "Performing OCR...")
         try:
             ocr_results = reader.readtext(preprocessed, detail=0, paragraph=True)
         except Exception as e:
             logger.error(f"OCR failed: {traceback.format_exc()}")
             raise RuntimeError("OCR processing failed") from e
-
         raw_text = " ".join(ocr_results)
         clean_text = clean_extracted_text(raw_text)
-
-        # Entity extraction
         progress(0.6, "Extracting entities...")
         try:
             labels = ["person name", "company name", "job title", "phone", "email", "address"]
@@ -156,38 +139,32 @@ def process_single_image(image, threshold=0.3, nested_ner=True, progress=gr.Prog
         except Exception as e:
             logger.error(f"Entity extraction failed: {traceback.format_exc()}")
             raise RuntimeError("Entity extraction failed") from e
-
-        # Format raw results into a dictionary by label
         results = {label.title(): [] for label in labels}
         for entity in entities:
             label = entity["label"].title()
             if label in results:
                 results[label].append(entity["text"])
-
-        # Post-process the extracted entities for deduplication and validation
        cleaned_entities = clean_and_deduplicate(results)
-
-        # Create temporary CSV file with final results
-        with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as tmp_file:
-            pd.DataFrame([{
-                k: "; ".join(v) for k, v in cleaned_entities.items()
-            }]).to_csv(tmp_file.name, index=False)
-            csv_path = tmp_file.name
-
+
+        # Generate CSV output in-memory to reduce disk I/O
+        csv_io = io.BytesIO()
+        pd.DataFrame([{k: "; ".join(v) for k, v in cleaned_entities.items()}]).to_csv(csv_io, index=False)
+        csv_io.seek(0)
+
         return (
             clean_text,  # Text output (str)
             {k: "; ".join(v) for k, v in cleaned_entities.items()},  # JSON output (dict)
-            csv_path,  # File path (str)
+            csv_io,  # In-memory file (BytesIO)
             ""  # Empty error message (str)
         )
 
     except Exception as e:
         logger.error(f"Processing failed: {traceback.format_exc()}")
         return (
-            "",  # Empty text
-            {},  # Empty JSON
-            None,  # No file
-            f"Error: {str(e)}\n{traceback.format_exc()}"  # Error details
+            "",
+            {},
+            None,
+            f"Error: {str(e)}\n{traceback.format_exc()}"
         )
 
 # Gradio Interface setup
@@ -207,7 +184,6 @@ with gr.Blocks() as app:
     error_output = gr.Textbox(label="Error Details", visible=False)
     csv_download_single = gr.File(label="Download Results")
 
-    # Update click handler to show errors
     submit_single.click(
        fn=process_single_image,
        inputs=[single_image, threshold_single, nested_ner_single],
@@ -223,4 +199,4 @@ app.launch(
     debug=True,
     show_error=True,
     share=False
-)
+)
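
For reference, a minimal, standalone sketch of the in-memory CSV pattern this commit switches to (io.BytesIO plus DataFrame.to_csv in place of a temporary file); the sample entity dict and its values are invented for illustration, and writing CSV to a binary buffer assumes pandas 1.2 or newer.

import io

import pandas as pd

# Hypothetical entities in the shape clean_and_deduplicate() returns
cleaned_entities = {
    "Person Name": ["Jane Doe"],
    "Email": ["jane@example.com"],
    "Phone": ["+20 100 123 4567"],
}

# Join each label's values into one cell and write the CSV into an
# in-memory buffer instead of a temporary file on disk
# (pandas >= 1.2 accepts binary file objects in to_csv).
csv_io = io.BytesIO()
pd.DataFrame([{k: "; ".join(v) for k, v in cleaned_entities.items()}]).to_csv(csv_io, index=False)
csv_io.seek(0)  # rewind so downstream readers start at the beginning

print(csv_io.getvalue().decode("utf-8"))

Running it prints a one-row CSV with one column per label, without writing anything to disk.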