import gradio as gr
import easyocr
import pandas as pd
import numpy as np
import os
import logging
from pathlib import Path
from gliner import GLiNER
import cv2
import re
from PIL import Image
import traceback
import io # For in-memory file handling
from difflib import SequenceMatcher
import tempfile # For writing the CSV download to a temporary file
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Set environment variables for model storage
os.environ['GLINER_HOME'] = '/tmp/.gliner_models'
os.environ['TRANSFORMERS_CACHE'] = '/tmp/.gliner_models/cache'
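# NOTE: '/tmp' paths are used on the assumption that (as on Hugging Face Spaces
# containers) only /tmp is reliably writable; point these somewhere persistent
# when running locally so models are not re-downloaded on every start.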
def initialize_models():
    """Initialize models with error handling"""
    try:
        logger.info("Initializing EasyOCR...")
        reader = easyocr.Reader(['en', 'ar'],
                                download_enabled=True,
                                model_storage_directory='/tmp/.easyocr_models')
        logger.info("Initializing GLiNER...")
        model_path = Path(os.environ['GLINER_HOME']) / 'gliner_large-v2.1'
        if not model_path.exists():
            logger.info("Downloading GLiNER model...")
            model_path.parent.mkdir(parents=True, exist_ok=True)
            model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
            model.save_pretrained(str(model_path))
        else:
            model = GLiNER.from_pretrained(str(model_path))
        logger.info("Models initialized successfully")
        return reader, model
    except Exception as e:
        logger.error(f"Model initialization failed: {str(e)}")
        raise
try:
    reader, model = initialize_models()
except Exception as e:
    logger.error(f"Critical failure: {traceback.format_exc()}")
    raise RuntimeError("Failed to initialize models") from e
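# The reader and model are created once at import time, so every request handled
# by the Gradio app below reuses the same instances instead of reloading them.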
def clean_extracted_text(text):
    """Clean the extracted text with proper error handling"""
    try:
        # Keep Arabic script blocks, Latin letters, digits, whitespace, and
        # punctuation that commonly appears in contact details (@ . , -)
        cleaned = re.sub(
            r'[^\u0600-\u06FF\u0750-\u077F\u08A0-\u08FFA-Za-z0-9\s@.,-]',
            '',
            text
        )
        return re.sub(r'\s+', ' ', cleaned).strip()
    except Exception as e:
        logger.error(f"Text cleaning failed: {traceback.format_exc()}")
        return text
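# A hypothetical before/after for the whitelist above:
#   clean_extracted_text("Ahmed ★ Ali | CEO ✦ ahmed@acme.com")
#   -> "Ahmed Ali CEO ahmed@acme.com"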
def preprocess_image(image, max_dim=1024):
    """Image preprocessing with validation and optional resizing"""
    try:
        if not isinstance(image, np.ndarray):
            image = np.array(image)
        # Resize if the image is too large (keeping aspect ratio)
        h, w = image.shape[:2]
        if max(h, w) > max_dim:
            scaling = max_dim / float(max(h, w))
            image = cv2.resize(image, (int(w * scaling), int(h * scaling)))
        # Convert to grayscale if needed
        if len(image.shape) == 2:
            gray = image
        else:
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        denoised = cv2.medianBlur(gray, 3)
        _, thresh = cv2.threshold(denoised, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        return thresh
    except Exception as e:
        logger.error(f"Preprocessing failed: {traceback.format_exc()}")
        raise
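# Otsu's method (cv2.THRESH_OTSU) picks the binarization threshold from the
# image histogram, which generally suits high-contrast print such as business
# cards; the median blur beforehand suppresses salt-and-pepper scanner noise.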
def similar(a, b, threshold=0.8):
    """Return True if two strings are similar above the given threshold"""
    return SequenceMatcher(None, a, b).ratio() > threshold
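# e.g. SequenceMatcher gives "Acme Corp" vs. "Acme Corp." a ratio of about 0.95,
# so the two are treated as duplicates, while clearly different strings fall
# well below the 0.8 cutoff.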
def clean_and_deduplicate(entities):
    """
    Post-process entity extraction:
    - Validate emails and phone numbers using regex.
    - Remove duplicates and near-duplicates.
    """
    cleaned_results = {}
    for label, values in entities.items():
        unique = []
        for val in values:
            if label.lower() == "email":
                match = re.search(r'[\w\.-]+@[\w\.-]+', val)
                val = match.group(0) if match else val
            elif label.lower() == "phone":
                match = re.search(r'\+?\d[\d\s\-]{7,}\d', val)
                val = match.group(0) if match else val
            if not any(similar(val, exist) for exist in unique):
                unique.append(val)
        cleaned_results[label] = unique
    return cleaned_results
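# Hypothetical example: {"Phone": ["+1 555 123 4567", "Tel: +1 555 123 4567"]}
# collapses to {"Phone": ["+1 555 123 4567"]}: the phone regex strips the
# "Tel: " prefix and the fuzzy match then drops the exact duplicate.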
def process_single_image(image, threshold=0.3, nested_ner=True, progress=gr.Progress()):
    """Process single image with detailed error handling, optimized I/O, and entity cleanup"""
    try:
        if image is None:
            raise ValueError("No image provided")
        progress(0.1, "Validating input...")
        if not isinstance(image, (Image.Image, np.ndarray)):
            raise TypeError(f"Invalid image type: {type(image)}")
        progress(0.2, "Preprocessing image...")
        preprocessed = preprocess_image(image)
        progress(0.4, "Performing OCR...")
        try:
            ocr_results = reader.readtext(preprocessed, detail=0, paragraph=True)
        except Exception as e:
            logger.error(f"OCR failed: {traceback.format_exc()}")
            raise RuntimeError("OCR processing failed") from e
        raw_text = " ".join(ocr_results)
        clean_text = clean_extracted_text(raw_text)
        progress(0.6, "Extracting entities...")
        try:
            labels = ["person name", "company name", "job title", "phone", "email", "address"]
            entities = model.predict_entities(
                clean_text,
                labels,
                threshold=threshold,
                flat_ner=not nested_ner
            )
        except Exception as e:
            logger.error(f"Entity extraction failed: {traceback.format_exc()}")
            raise RuntimeError("Entity extraction failed") from e
        results = {label.title(): [] for label in labels}
        for entity in entities:
            label = entity["label"].title()
            if label in results:
                results[label].append(entity["text"])
        cleaned_entities = clean_and_deduplicate(results)
        # Generate CSV content in memory using BytesIO
        csv_io = io.BytesIO()
        pd.DataFrame([{k: "; ".join(v) for k, v in cleaned_entities.items()}]).to_csv(csv_io, index=False)
        csv_io.seek(0)
        # Write the CSV content to a temporary file and return its path
        with tempfile.NamedTemporaryFile(suffix=".csv", delete=False, mode="wb") as tmp_file:
            tmp_file.write(csv_io.getvalue())
            csv_path = tmp_file.name
        return (
            clean_text,  # Text output (str)
            {k: "; ".join(v) for k, v in cleaned_entities.items()},  # JSON output (dict)
            csv_path,  # File path (str)
            ""  # Empty error message (str)
        )
    except Exception as e:
        logger.error(f"Processing failed: {traceback.format_exc()}")
        return (
            "",
            {},
            None,
            f"Error: {str(e)}\n{traceback.format_exc()}"
        )
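# The four-tuple returned above must stay aligned with the `outputs` list wired
# up below: extracted text, entity dict, CSV file path, and an error string that
# the `.then` callback uses to toggle the error box's visibility.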
# Gradio Interface setup
with gr.Blocks() as app:
    gr.Markdown("# Business Card Information Extractor")
    with gr.Tab("Single File"):
        with gr.Row():
            with gr.Column():
                single_image = gr.Image(label="Upload Business Card", type="pil")
                threshold_single = gr.Slider(0.0, 1.0, value=0.3, label="Detection Threshold")
                nested_ner_single = gr.Checkbox(True, label="Enable Nested NER")
                submit_single = gr.Button("Process")
            with gr.Column():
                text_output = gr.Textbox(label="Extracted Text")
                json_output = gr.JSON(label="Entities")
                error_output = gr.Textbox(label="Error Details", visible=False)
                csv_download_single = gr.File(label="Download Results")
    submit_single.click(
        fn=process_single_image,
        inputs=[single_image, threshold_single, nested_ner_single],
        outputs=[text_output, json_output, csv_download_single, error_output],
        api_name="process_single"
    ).then(
        lambda x: gr.update(visible=bool(x)),
        inputs=[error_output],
        outputs=[error_output]
    )
app.launch(
debug=True,
show_error=True,
share=False
)
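# Because the click handler is exposed with api_name="process_single", it can
# also be called programmatically. A minimal sketch using the gradio_client
# package (the URL is a placeholder; depending on your gradio_client version the
# image argument may need to be wrapped with gradio_client.handle_file):
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860")
#   text, entities, csv_path, err = client.predict(
#       "card.jpg",  # business card image
#       0.3,         # detection threshold
#       True,        # nested NER
#       api_name="/process_single",
#   )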