|
import gradio as gr |
|
import easyocr |
|
import pandas as pd |
|
import numpy as np |
|
import os |
|
import logging |
|
from pathlib import Path |
|
from gliner import GLiNER |
|
import cv2 |
|
import re |
|
from PIL import Image |
|
import traceback |
|
import io |
|
from difflib import SequenceMatcher |
|
import tempfile |
|
import io |
|
|
|
|
|
# Route INFO-and-above records to the default handler and give this module
# its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model storage locations. GLINER_HOME is read back by initialize_models().
# NOTE(review): these are assigned after the third-party imports at the top
# of the file have already run -- confirm that whatever consumes
# TRANSFORMERS_CACHE reads it lazily rather than at import time.
os.environ.update(
    GLINER_HOME='/tmp/.gliner_models',
    TRANSFORMERS_CACHE='/tmp/.gliner_models/cache',
)
|
|
|
def initialize_models():
    """Create the EasyOCR reader and load the GLiNER model.

    The reader is built for English and Arabic with its weights stored under
    ``/tmp/.easyocr_models``. The GLiNER checkpoint is loaded from
    ``$GLINER_HOME/gliner_large-v2.1`` when that path exists; otherwise it is
    fetched from ``urchade/gliner_large-v2.1`` and saved to that path.

    Returns:
        A ``(reader, model)`` tuple.

    Raises:
        Exception: Any failure is logged and re-raised unchanged. No retry
            is attempted.
    """
    try:
        logger.info("Initializing EasyOCR...")
        ocr_reader = easyocr.Reader(
            ['en', 'ar'],
            download_enabled=True,
            model_storage_directory='/tmp/.easyocr_models',
        )

        logger.info("Initializing GLiNER...")
        local_copy = Path(os.environ['GLINER_HOME']) / 'gliner_large-v2.1'
        if local_copy.exists():
            # A previously saved checkpoint is present; load it from disk.
            ner_model = GLiNER.from_pretrained(str(local_copy))
        else:
            logger.info("Downloading GLiNER model...")
            local_copy.parent.mkdir(parents=True, exist_ok=True)
            ner_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
            ner_model.save_pretrained(str(local_copy))

        logger.info("Models initialized successfully")
        return ocr_reader, ner_model
    except Exception as e:
        logger.error(f"Model initialization failed: {str(e)}")
        raise
|
|
|
# Load both models once at import time. The rest of the module uses the
# resulting `reader` and `model` globals, so a failure here is fatal.
try:
    reader, model = initialize_models()
except Exception as init_error:
    logger.error(f"Critical failure: {traceback.format_exc()}")
    raise RuntimeError("Failed to initialize models") from init_error
|
|
|
# Everything outside this whitelist is dropped: Arabic blocks, ASCII letters
# and digits, whitespace, and the punctuation "@ . , + _ -". "+" and "_" are
# kept because the phone and email patterns used later in this file
# (`\+?\d...` and `[\w\.-]+@...`) rely on them.
_DISALLOWED_CHARS_RE = re.compile(
    r'[^\u0600-\u06FF\u0750-\u077F\u08A0-\u08FFA-Za-z0-9\s@.,+_-]'
)
_WHITESPACE_RUN_RE = re.compile(r'\s+')


def clean_extracted_text(text):
    """Strip unsupported characters from OCR text and normalise whitespace.

    Args:
        text: Raw text to clean.

    Returns:
        The text with every character outside the whitelist removed, runs of
        whitespace collapsed to a single space, and leading/trailing
        whitespace trimmed. If *text* is not a string, the failure is logged
        and *text* is returned unchanged (best-effort behaviour).
    """
    try:
        cleaned = _DISALLOWED_CHARS_RE.sub('', text)
        return _WHITESPACE_RUN_RE.sub(' ', cleaned).strip()
    except TypeError:
        # Only a non-string argument can make the substitutions fail.
        logger.error(f"Text cleaning failed: {traceback.format_exc()}")
        return text
|
|
|
def preprocess_image(image, max_dim=1024):
    """Convert an image to a denoised, Otsu-binarised array for OCR.

    Args:
        image: A NumPy array, or any object ``np.array`` can convert (the
            caller in this file passes PIL images). Colour input is
            converted with ``cv2.COLOR_RGB2GRAY``.
        max_dim: Longest allowed side in pixels. Larger images are scaled
            down, preserving aspect ratio, before further processing.

    Returns:
        The thresholded image produced by ``cv2.threshold``.

    Raises:
        Exception: Any failure is logged with its traceback and re-raised.
    """
    try:
        if not isinstance(image, np.ndarray):
            image = np.array(image)

        h, w = image.shape[:2]
        longest = max(h, w)
        if longest > max_dim:
            scaling = max_dim / float(longest)
            # Clamp each side to at least 1 px: for very elongated images the
            # short side would otherwise truncate to 0, which cv2.resize
            # rejects.
            new_size = (max(1, int(w * scaling)), max(1, int(h * scaling)))
            image = cv2.resize(image, new_size)

        # A 2-D array is already single-channel; anything else is converted.
        if image.ndim == 2:
            gray = image
        else:
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

        denoised = cv2.medianBlur(gray, 3)
        _, thresh = cv2.threshold(
            denoised, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
        )
        return thresh
    except Exception:
        logger.error(f"Preprocessing failed: {traceback.format_exc()}")
        raise
|
|
|
def similar(a, b, threshold=0.8):
    """Report whether *a* and *b* are near-duplicates.

    The comparison uses ``difflib.SequenceMatcher``; the result is True only
    when the similarity ratio is strictly greater than *threshold*.
    """
    score = SequenceMatcher(isjunk=None, a=a, b=b).ratio()
    return score > threshold
|
|
|
_EMAIL_RE = re.compile(r'[\w\.-]+@[\w\.-]+')
_PHONE_RE = re.compile(r'\+?\d[\d\s\-]{7,}\d')


def _normalize_structured(kind, value):
    """Extract the email/phone part of *value* and a key for comparing it.

    Returns ``(extracted, key)`` when *kind* is ``"email"`` or ``"phone"``
    and the matching pattern is found: the key is the case-folded address
    for emails and the digits only for phones. Otherwise returns
    ``(value, None)``.
    """
    if kind == "email":
        match = _EMAIL_RE.search(value)
        if match:
            return match.group(0), match.group(0).casefold()
    elif kind == "phone":
        match = _PHONE_RE.search(value)
        if match:
            return match.group(0), re.sub(r'\D', '', match.group(0))
    return value, None


def clean_and_deduplicate(entities):
    """Tidy extracted entity values and drop repeats, keeping first-seen order.

    Args:
        entities: Mapping of label -> list of extracted strings. Labels are
            matched case-insensitively against ``"email"`` and ``"phone"``.

    Returns:
        A new mapping with the same labels. For each label:

        - Email/phone values are trimmed to the substring matching the
          corresponding pattern. Values that do not match are kept as-is.
        - Matched emails/phones are de-duplicated on their normalized key
          (case-folded address / digits only). Two values are duplicates
          when one key equals or contains the other, so a nested span such
          as ``555 123 4567`` inside ``+1 555 123 4567`` is dropped while a
          number that differs by a digit is kept.
        - All other values are de-duplicated with the fuzzy ``similar``
          check.
    """
    cleaned_results = {}
    for label, values in entities.items():
        kind = label.lower()
        unique = []
        seen_keys = []
        for raw in values:
            val, key = _normalize_structured(kind, raw)
            if key is not None:
                # Fuzzy matching would merge distinct numbers/addresses that
                # differ by a single character, so compare keys exactly.
                duplicate = any(
                    key in seen or seen in key for seen in seen_keys
                )
            else:
                duplicate = any(similar(val, exist) for exist in unique)
            if duplicate:
                continue
            unique.append(val)
            if key is not None:
                seen_keys.append(key)
        cleaned_results[label] = unique
    return cleaned_results
|
|
|
_ENTITY_LABELS = ["person name", "company name", "job title", "phone", "email", "address"]


def _extract_entities(text, threshold, nested_ner):
    """Run the GLiNER model over *text* and group hits by title-cased label."""
    grouped = {label.title(): [] for label in _ENTITY_LABELS}
    if not text:
        # OCR produced nothing to tag; skip the model call entirely.
        return grouped
    try:
        entities = model.predict_entities(
            text,
            _ENTITY_LABELS,
            threshold=threshold,
            flat_ner=not nested_ner,
        )
    except Exception as e:
        logger.error(f"Entity extraction failed: {traceback.format_exc()}")
        raise RuntimeError("Entity extraction failed") from e
    for entity in entities:
        label = entity["label"].title()
        if label in grouped:
            grouped[label].append(entity["text"])
    return grouped


def _write_csv(row):
    """Write *row* as a one-record CSV to a persistent temp file; return its path."""
    # delete=False: the file must outlive this call so it can be downloaded.
    with tempfile.NamedTemporaryFile(
        suffix=".csv", delete=False, mode="w", newline="", encoding="utf-8"
    ) as tmp_file:
        pd.DataFrame([row]).to_csv(tmp_file, index=False)
    return tmp_file.name


def process_single_image(image, threshold=0.3, nested_ner=True, progress=gr.Progress()):
    """Run OCR and entity extraction on one business-card image.

    Args:
        image: A PIL image or NumPy array.
        threshold: Score threshold passed to the entity model.
        nested_ner: When True the model is called with ``flat_ner=False``.
        progress: Gradio progress tracker. The ``gr.Progress()`` default is
            evaluated once at definition time and is kept as in the original
            signature.

    Returns:
        A 4-tuple ``(clean_text, entities, csv_path, error)``:

        - on success: the cleaned OCR text, a dict of label -> "; "-joined
          values, the path of a CSV file holding that dict as one row, and
          an empty error string;
        - on any failure: ``("", {}, None, "Error: <message>\\n<traceback>")``.
          Exceptions are logged and never propagate to the caller.
    """
    try:
        if image is None:
            raise ValueError("No image provided")

        progress(0.1, "Validating input...")
        if not isinstance(image, (Image.Image, np.ndarray)):
            raise TypeError(f"Invalid image type: {type(image)}")

        progress(0.2, "Preprocessing image...")
        preprocessed = preprocess_image(image)

        progress(0.4, "Performing OCR...")
        try:
            ocr_results = reader.readtext(preprocessed, detail=0, paragraph=True)
        except Exception as e:
            logger.error(f"OCR failed: {traceback.format_exc()}")
            raise RuntimeError("OCR processing failed") from e
        clean_text = clean_extracted_text(" ".join(ocr_results))

        progress(0.6, "Extracting entities...")
        cleaned_entities = clean_and_deduplicate(
            _extract_entities(clean_text, threshold, nested_ner)
        )

        # One flattened record feeds both the JSON view and the CSV download.
        summary = {k: "; ".join(v) for k, v in cleaned_entities.items()}
        csv_path = _write_csv(summary)

        return clean_text, summary, csv_path, ""
    except Exception as e:
        logger.error(f"Processing failed: {traceback.format_exc()}")
        return "", {}, None, f"Error: {str(e)}\n{traceback.format_exc()}"
|
|
|
|
|
# Single-tab UI: upload and options on the left, results on the right.
with gr.Blocks() as app:
    gr.Markdown("# Business Card Information Extractor")

    with gr.Tab("Single File"):
        with gr.Row():
            with gr.Column():
                single_image = gr.Image(label="Upload Business Card", type="pil")
                threshold_single = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.3,
                    label="Detection Threshold",
                )
                nested_ner_single = gr.Checkbox(value=True, label="Enable Nested NER")
                submit_single = gr.Button("Process")
            with gr.Column():
                text_output = gr.Textbox(label="Extracted Text")
                json_output = gr.JSON(label="Entities")
                # Hidden until a run reports an error (see the .then() step).
                error_output = gr.Textbox(label="Error Details", visible=False)
                csv_download_single = gr.File(label="Download Results")

        # Run the pipeline, then show the error box only if it has content.
        submit_single.click(
            fn=process_single_image,
            inputs=[single_image, threshold_single, nested_ner_single],
            outputs=[text_output, json_output, csv_download_single, error_output],
            api_name="process_single",
        ).then(
            fn=lambda error_text: gr.update(visible=bool(error_text)),
            inputs=[error_output],
            outputs=[error_output],
        )
|
|
|
# Start the Gradio server locally (no public share link), with debug mode on
# and errors surfaced in the UI.
# NOTE(review): this runs whenever the module is executed or imported;
# consider an `if __name__ == "__main__":` guard if the module is ever
# imported from elsewhere.
app.launch(debug=True, show_error=True, share=False)
|
|