Update app.py
Browse files
app.py
CHANGED
@@ -6,12 +6,11 @@ import os
|
|
6 |
import logging
|
7 |
from pathlib import Path
|
8 |
from gliner import GLiNER
|
9 |
-
from io import BytesIO
|
10 |
import cv2
|
11 |
import re
|
12 |
from PIL import Image
|
13 |
import traceback
|
14 |
-
import
|
15 |
from difflib import SequenceMatcher
|
16 |
|
17 |
# Configure logging
|
@@ -25,16 +24,12 @@ os.environ['TRANSFORMERS_CACHE'] = '/tmp/.gliner_models/cache'
|
|
25 |
def initialize_models():
|
26 |
"""Initialize models with error handling and retries"""
|
27 |
try:
|
28 |
-
# Initialize EasyOCR
|
29 |
logger.info("Initializing EasyOCR...")
|
30 |
reader = easyocr.Reader(['en', 'ar'],
|
31 |
download_enabled=True,
|
32 |
model_storage_directory='/tmp/.easyocr_models')
|
33 |
-
|
34 |
-
# Initialize GLiNER
|
35 |
logger.info("Initializing GLiNER...")
|
36 |
model_path = Path(os.environ['GLINER_HOME']) / 'gliner_large-v2.1'
|
37 |
-
|
38 |
if not model_path.exists():
|
39 |
logger.info("Downloading GLiNER model...")
|
40 |
model_path.parent.mkdir(parents=True, exist_ok=True)
|
@@ -42,10 +37,8 @@ def initialize_models():
|
|
42 |
model.save_pretrained(str(model_path))
|
43 |
else:
|
44 |
model = GLiNER.from_pretrained(str(model_path))
|
45 |
-
|
46 |
logger.info("Models initialized successfully")
|
47 |
return reader, model
|
48 |
-
|
49 |
except Exception as e:
|
50 |
logger.error(f"Model initialization failed: {str(e)}")
|
51 |
raise
|
@@ -59,29 +52,31 @@ except Exception as e:
|
|
59 |
def clean_extracted_text(text):
|
60 |
"""Clean the extracted text with proper error handling"""
|
61 |
try:
|
62 |
-
# Preserve Arabic and basic Latin characters along with digits and common punctuation
|
63 |
cleaned = re.sub(
|
64 |
r'[^\u0600-\u06FF\u0750-\u077F\u08A0-\u08FFA-Za-z0-9\s@.,-]',
|
65 |
'',
|
66 |
text
|
67 |
)
|
68 |
-
# Normalize whitespace
|
69 |
return re.sub(r'\s+', ' ', cleaned).strip()
|
70 |
except Exception as e:
|
71 |
logger.error(f"Text cleaning failed: {traceback.format_exc()}")
|
72 |
-
return text
|
73 |
|
74 |
-
def preprocess_image(image):
|
75 |
-
"""Image preprocessing with validation"""
|
76 |
try:
|
77 |
if not isinstance(image, np.ndarray):
|
78 |
image = np.array(image)
|
79 |
-
|
80 |
-
|
|
|
|
|
|
|
|
|
|
|
81 |
gray = image
|
82 |
else:
|
83 |
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
84 |
-
|
85 |
denoised = cv2.medianBlur(gray, 3)
|
86 |
_, thresh = cv2.threshold(denoised, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
87 |
return thresh
|
@@ -103,47 +98,35 @@ def clean_and_deduplicate(entities):
|
|
103 |
for label, values in entities.items():
|
104 |
unique = []
|
105 |
for val in values:
|
106 |
-
# Validate and clean specific fields
|
107 |
if label.lower() == "email":
|
108 |
match = re.search(r'[\w\.-]+@[\w\.-]+', val)
|
109 |
val = match.group(0) if match else val
|
110 |
elif label.lower() == "phone":
|
111 |
match = re.search(r'\+?\d[\d\s\-]{7,}\d', val)
|
112 |
val = match.group(0) if match else val
|
113 |
-
|
114 |
-
# Avoid adding duplicates or near-duplicates
|
115 |
if not any(similar(val, exist) for exist in unique):
|
116 |
unique.append(val)
|
117 |
cleaned_results[label] = unique
|
118 |
return cleaned_results
|
119 |
|
120 |
def process_single_image(image, threshold=0.3, nested_ner=True, progress=gr.Progress()):
|
121 |
-
"""Process single image with detailed error handling and
|
122 |
try:
|
123 |
-
# Validate input
|
124 |
if image is None:
|
125 |
raise ValueError("No image provided")
|
126 |
-
|
127 |
progress(0.1, "Validating input...")
|
128 |
if not isinstance(image, (Image.Image, np.ndarray)):
|
129 |
raise TypeError(f"Invalid image type: {type(image)}")
|
130 |
-
|
131 |
-
# Preprocessing
|
132 |
progress(0.2, "Preprocessing image...")
|
133 |
preprocessed = preprocess_image(image)
|
134 |
-
|
135 |
-
# OCR
|
136 |
progress(0.4, "Performing OCR...")
|
137 |
try:
|
138 |
ocr_results = reader.readtext(preprocessed, detail=0, paragraph=True)
|
139 |
except Exception as e:
|
140 |
logger.error(f"OCR failed: {traceback.format_exc()}")
|
141 |
raise RuntimeError("OCR processing failed") from e
|
142 |
-
|
143 |
raw_text = " ".join(ocr_results)
|
144 |
clean_text = clean_extracted_text(raw_text)
|
145 |
-
|
146 |
-
# Entity extraction
|
147 |
progress(0.6, "Extracting entities...")
|
148 |
try:
|
149 |
labels = ["person name", "company name", "job title", "phone", "email", "address"]
|
@@ -156,38 +139,32 @@ def process_single_image(image, threshold=0.3, nested_ner=True, progress=gr.Prog
|
|
156 |
except Exception as e:
|
157 |
logger.error(f"Entity extraction failed: {traceback.format_exc()}")
|
158 |
raise RuntimeError("Entity extraction failed") from e
|
159 |
-
|
160 |
-
# Format raw results into a dictionary by label
|
161 |
results = {label.title(): [] for label in labels}
|
162 |
for entity in entities:
|
163 |
label = entity["label"].title()
|
164 |
if label in results:
|
165 |
results[label].append(entity["text"])
|
166 |
-
|
167 |
-
# Post-process the extracted entities for deduplication and validation
|
168 |
cleaned_entities = clean_and_deduplicate(results)
|
169 |
-
|
170 |
-
#
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
csv_path = tmp_file.name
|
176 |
-
|
177 |
return (
|
178 |
clean_text, # Text output (str)
|
179 |
{k: "; ".join(v) for k, v in cleaned_entities.items()}, # JSON output (dict)
|
180 |
-
|
181 |
"" # Empty error message (str)
|
182 |
)
|
183 |
|
184 |
except Exception as e:
|
185 |
logger.error(f"Processing failed: {traceback.format_exc()}")
|
186 |
return (
|
187 |
-
"",
|
188 |
-
{},
|
189 |
-
None,
|
190 |
-
f"Error: {str(e)}\n{traceback.format_exc()}"
|
191 |
)
|
192 |
|
193 |
# Gradio Interface setup
|
@@ -207,7 +184,6 @@ with gr.Blocks() as app:
|
|
207 |
error_output = gr.Textbox(label="Error Details", visible=False)
|
208 |
csv_download_single = gr.File(label="Download Results")
|
209 |
|
210 |
-
# Update click handler to show errors
|
211 |
submit_single.click(
|
212 |
fn=process_single_image,
|
213 |
inputs=[single_image, threshold_single, nested_ner_single],
|
@@ -223,4 +199,4 @@ app.launch(
|
|
223 |
debug=True,
|
224 |
show_error=True,
|
225 |
share=False
|
226 |
-
)
|
|
|
6 |
import logging
|
7 |
from pathlib import Path
|
8 |
from gliner import GLiNER
|
|
|
9 |
import cv2
|
10 |
import re
|
11 |
from PIL import Image
|
12 |
import traceback
|
13 |
+
import io # For in-memory file handling
|
14 |
from difflib import SequenceMatcher
|
15 |
|
16 |
# Configure logging
|
|
|
24 |
def initialize_models():
|
25 |
"""Initialize models with error handling and retries"""
|
26 |
try:
|
|
|
27 |
logger.info("Initializing EasyOCR...")
|
28 |
reader = easyocr.Reader(['en', 'ar'],
|
29 |
download_enabled=True,
|
30 |
model_storage_directory='/tmp/.easyocr_models')
|
|
|
|
|
31 |
logger.info("Initializing GLiNER...")
|
32 |
model_path = Path(os.environ['GLINER_HOME']) / 'gliner_large-v2.1'
|
|
|
33 |
if not model_path.exists():
|
34 |
logger.info("Downloading GLiNER model...")
|
35 |
model_path.parent.mkdir(parents=True, exist_ok=True)
|
|
|
37 |
model.save_pretrained(str(model_path))
|
38 |
else:
|
39 |
model = GLiNER.from_pretrained(str(model_path))
|
|
|
40 |
logger.info("Models initialized successfully")
|
41 |
return reader, model
|
|
|
42 |
except Exception as e:
|
43 |
logger.error(f"Model initialization failed: {str(e)}")
|
44 |
raise
|
|
|
52 |
def clean_extracted_text(text):
|
53 |
"""Clean the extracted text with proper error handling"""
|
54 |
try:
|
|
|
55 |
cleaned = re.sub(
|
56 |
r'[^\u0600-\u06FF\u0750-\u077F\u08A0-\u08FFA-Za-z0-9\s@.,-]',
|
57 |
'',
|
58 |
text
|
59 |
)
|
|
|
60 |
return re.sub(r'\s+', ' ', cleaned).strip()
|
61 |
except Exception as e:
|
62 |
logger.error(f"Text cleaning failed: {traceback.format_exc()}")
|
63 |
+
return text
|
64 |
|
65 |
+
def preprocess_image(image, max_dim=1024):
|
66 |
+
"""Image preprocessing with validation and optional resizing"""
|
67 |
try:
|
68 |
if not isinstance(image, np.ndarray):
|
69 |
image = np.array(image)
|
70 |
+
# Optional: Resize if the image is too large (keeping aspect ratio)
|
71 |
+
h, w = image.shape[:2]
|
72 |
+
if max(h, w) > max_dim:
|
73 |
+
scaling = max_dim / float(max(h, w))
|
74 |
+
image = cv2.resize(image, (int(w * scaling), int(h * scaling)))
|
75 |
+
# Convert to grayscale if needed
|
76 |
+
if len(image.shape) == 2:
|
77 |
gray = image
|
78 |
else:
|
79 |
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
|
|
80 |
denoised = cv2.medianBlur(gray, 3)
|
81 |
_, thresh = cv2.threshold(denoised, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
82 |
return thresh
|
|
|
98 |
for label, values in entities.items():
|
99 |
unique = []
|
100 |
for val in values:
|
|
|
101 |
if label.lower() == "email":
|
102 |
match = re.search(r'[\w\.-]+@[\w\.-]+', val)
|
103 |
val = match.group(0) if match else val
|
104 |
elif label.lower() == "phone":
|
105 |
match = re.search(r'\+?\d[\d\s\-]{7,}\d', val)
|
106 |
val = match.group(0) if match else val
|
|
|
|
|
107 |
if not any(similar(val, exist) for exist in unique):
|
108 |
unique.append(val)
|
109 |
cleaned_results[label] = unique
|
110 |
return cleaned_results
|
111 |
|
112 |
def process_single_image(image, threshold=0.3, nested_ner=True, progress=gr.Progress()):
|
113 |
+
"""Process single image with detailed error handling, optimized I/O, and entity cleanup"""
|
114 |
try:
|
|
|
115 |
if image is None:
|
116 |
raise ValueError("No image provided")
|
|
|
117 |
progress(0.1, "Validating input...")
|
118 |
if not isinstance(image, (Image.Image, np.ndarray)):
|
119 |
raise TypeError(f"Invalid image type: {type(image)}")
|
|
|
|
|
120 |
progress(0.2, "Preprocessing image...")
|
121 |
preprocessed = preprocess_image(image)
|
|
|
|
|
122 |
progress(0.4, "Performing OCR...")
|
123 |
try:
|
124 |
ocr_results = reader.readtext(preprocessed, detail=0, paragraph=True)
|
125 |
except Exception as e:
|
126 |
logger.error(f"OCR failed: {traceback.format_exc()}")
|
127 |
raise RuntimeError("OCR processing failed") from e
|
|
|
128 |
raw_text = " ".join(ocr_results)
|
129 |
clean_text = clean_extracted_text(raw_text)
|
|
|
|
|
130 |
progress(0.6, "Extracting entities...")
|
131 |
try:
|
132 |
labels = ["person name", "company name", "job title", "phone", "email", "address"]
|
|
|
139 |
except Exception as e:
|
140 |
logger.error(f"Entity extraction failed: {traceback.format_exc()}")
|
141 |
raise RuntimeError("Entity extraction failed") from e
|
|
|
|
|
142 |
results = {label.title(): [] for label in labels}
|
143 |
for entity in entities:
|
144 |
label = entity["label"].title()
|
145 |
if label in results:
|
146 |
results[label].append(entity["text"])
|
|
|
|
|
147 |
cleaned_entities = clean_and_deduplicate(results)
|
148 |
+
|
149 |
+
# Generate CSV output in-memory to reduce disk I/O
|
150 |
+
csv_io = io.BytesIO()
|
151 |
+
pd.DataFrame([{k: "; ".join(v) for k, v in cleaned_entities.items()}]).to_csv(csv_io, index=False)
|
152 |
+
csv_io.seek(0)
|
153 |
+
|
|
|
|
|
154 |
return (
|
155 |
clean_text, # Text output (str)
|
156 |
{k: "; ".join(v) for k, v in cleaned_entities.items()}, # JSON output (dict)
|
157 |
+
csv_io, # In-memory file (BytesIO)
|
158 |
"" # Empty error message (str)
|
159 |
)
|
160 |
|
161 |
except Exception as e:
|
162 |
logger.error(f"Processing failed: {traceback.format_exc()}")
|
163 |
return (
|
164 |
+
"",
|
165 |
+
{},
|
166 |
+
None,
|
167 |
+
f"Error: {str(e)}\n{traceback.format_exc()}"
|
168 |
)
|
169 |
|
170 |
# Gradio Interface setup
|
|
|
184 |
error_output = gr.Textbox(label="Error Details", visible=False)
|
185 |
csv_download_single = gr.File(label="Download Results")
|
186 |
|
|
|
187 |
submit_single.click(
|
188 |
fn=process_single_image,
|
189 |
inputs=[single_image, threshold_single, nested_ner_single],
|
|
|
199 |
debug=True,
|
200 |
show_error=True,
|
201 |
share=False
|
202 |
+
)
|