import gradio as gr
import easyocr
import pandas as pd
import numpy as np
import os
import logging
from pathlib import Path
from gliner import GLiNER
import cv2
import re
from PIL import Image
import traceback
import io  # For in-memory file handling
from difflib import SequenceMatcher
import tempfile  # For writing the CSV download to disk
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Set environment variables for model storage
os.environ['GLINER_HOME'] = '/tmp/.gliner_models'
os.environ['TRANSFORMERS_CACHE'] = '/tmp/.gliner_models/cache'
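
# These caches live under /tmp so the app can also run where the working
# directory is read-only (e.g. a hosted Space); point them elsewhere if the
# downloaded models should persist across restarts.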
def initialize_models():
    """Initialize EasyOCR and GLiNER with error handling"""
    try:
        logger.info("Initializing EasyOCR...")
        reader = easyocr.Reader(['en', 'ar'],
                                download_enabled=True,
                                model_storage_directory='/tmp/.easyocr_models')
        logger.info("Initializing GLiNER...")
        model_path = Path(os.environ['GLINER_HOME']) / 'gliner_large-v2.1'
        if not model_path.exists():
            logger.info("Downloading GLiNER model...")
            model_path.parent.mkdir(parents=True, exist_ok=True)
            model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
            model.save_pretrained(str(model_path))
        else:
            model = GLiNER.from_pretrained(str(model_path))
        logger.info("Models initialized successfully")
        return reader, model
    except Exception as e:
        logger.error(f"Model initialization failed: {str(e)}")
        raise
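
# Load both models once at import time; the request handler below reuses these
# module-level globals instead of re-initializing per call.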
try:
    reader, model = initialize_models()
except Exception as e:
    logger.error(f"Critical failure: {traceback.format_exc()}")
    raise RuntimeError("Failed to initialize models") from e

def clean_extracted_text(text):
    """Clean the extracted text with proper error handling"""
    try:
        cleaned = re.sub(
            r'[^\u0600-\u06FF\u0750-\u077F\u08A0-\u08FFA-Za-z0-9\s@.,-]',
            '',
            text
        )
        return re.sub(r'\s+', ' ', cleaned).strip()
    except Exception:
        logger.error(f"Text cleaning failed: {traceback.format_exc()}")
        return text
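
# Example: clean_extracted_text("John  Doe ★ CEO") -> "John Doe CEO"; the
# character class keeps Arabic script ranges, Latin letters, digits, and the
# @ . , - characters needed for emails and phone numbers, then collapses
# whitespace runs.
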
def preprocess_image(image, max_dim=1024):
    """Image preprocessing with validation and optional resizing"""
    try:
        if not isinstance(image, np.ndarray):
            image = np.array(image)
        # Optional: resize if the image is too large (keeping aspect ratio)
        h, w = image.shape[:2]
        if max(h, w) > max_dim:
            scaling = max_dim / float(max(h, w))
            image = cv2.resize(image, (int(w * scaling), int(h * scaling)))
        # Convert to grayscale if needed
        if len(image.shape) == 2:
            gray = image
        else:
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        denoised = cv2.medianBlur(gray, 3)
        _, thresh = cv2.threshold(denoised, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        return thresh
    except Exception:
        logger.error(f"Preprocessing failed: {traceback.format_exc()}")
        raise
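
# Note: with cv2.THRESH_OTSU the fixed threshold argument (0 above) is ignored
# and Otsu's method picks the binarization cutoff from the image histogram,
# which suits the high-contrast text typical of business cards.
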
def similar(a, b, threshold=0.8):
    """Return True if two strings are similar above the given threshold"""
    return SequenceMatcher(None, a, b).ratio() > threshold
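
# Example: similar("Acme Corp", "Acme Corp.") -> True (ratio ~0.95), while
# similar("Acme Corp", "Apex Ltd") -> False.
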
def clean_and_deduplicate(entities):
    """
    Post-process entity extraction:
    - Validate emails and phone numbers using regex.
    - Remove duplicates and near-duplicates.
    """
    cleaned_results = {}
    for label, values in entities.items():
        unique = []
        for val in values:
            if label.lower() == "email":
                match = re.search(r'[\w\.-]+@[\w\.-]+', val)
                val = match.group(0) if match else val
            elif label.lower() == "phone":
                match = re.search(r'\+?\d[\d\s\-]{7,}\d', val)
                val = match.group(0) if match else val
            if not any(similar(val, exist) for exist in unique):
                unique.append(val)
        cleaned_results[label] = unique
    return cleaned_results
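
# Example: {"Phone": ["+1 555 123 4567", "Tel: +1 555 123 4567"]} collapses to
# {"Phone": ["+1 555 123 4567"]}: the regex strips the "Tel: " prefix and the
# near-duplicate check then drops the repeated number.
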
def process_single_image(image, threshold=0.3, nested_ner=True, progress=gr.Progress()):
    """Process a single image with detailed error handling and entity cleanup"""
    try:
        if image is None:
            raise ValueError("No image provided")
        progress(0.1, "Validating input...")
        if not isinstance(image, (Image.Image, np.ndarray)):
            raise TypeError(f"Invalid image type: {type(image)}")
        progress(0.2, "Preprocessing image...")
        preprocessed = preprocess_image(image)
        progress(0.4, "Performing OCR...")
        try:
            ocr_results = reader.readtext(preprocessed, detail=0, paragraph=True)
        except Exception as e:
            logger.error(f"OCR failed: {traceback.format_exc()}")
            raise RuntimeError("OCR processing failed") from e
        raw_text = " ".join(ocr_results)
        clean_text = clean_extracted_text(raw_text)
        progress(0.6, "Extracting entities...")
        try:
            labels = ["person name", "company name", "job title", "phone", "email", "address"]
            entities = model.predict_entities(
                clean_text,
                labels,
                threshold=threshold,
                flat_ner=not nested_ner
            )
        except Exception as e:
            logger.error(f"Entity extraction failed: {traceback.format_exc()}")
            raise RuntimeError("Entity extraction failed") from e
        results = {label.title(): [] for label in labels}
        for entity in entities:
            label = entity["label"].title()
            if label in results:
                results[label].append(entity["text"])
        cleaned_entities = clean_and_deduplicate(results)
        # Generate CSV content in memory, then persist it to a temp file so
        # Gradio's File component can serve it by path
        csv_io = io.BytesIO()
        pd.DataFrame([{k: "; ".join(v) for k, v in cleaned_entities.items()}]).to_csv(csv_io, index=False)
        with tempfile.NamedTemporaryFile(suffix=".csv", delete=False, mode="wb") as tmp_file:
            tmp_file.write(csv_io.getvalue())
            csv_path = tmp_file.name
        return (
            clean_text,                                               # Text output (str)
            {k: "; ".join(v) for k, v in cleaned_entities.items()},  # JSON output (dict)
            csv_path,                                                # File path (str)
            ""                                                       # Empty error message (str)
        )
    except Exception as e:
        logger.error(f"Processing failed: {traceback.format_exc()}")
        return (
            "",
            {},
            None,
            f"Error: {str(e)}\n{traceback.format_exc()}"
        )
# Gradio Interface setup
with gr.Blocks() as app:
    gr.Markdown("# Business Card Information Extractor")
    with gr.Tab("Single File"):
        with gr.Row():
            with gr.Column():
                single_image = gr.Image(label="Upload Business Card", type="pil")
                threshold_single = gr.Slider(0.0, 1.0, value=0.3, label="Detection Threshold")
                nested_ner_single = gr.Checkbox(True, label="Enable Nested NER")
                submit_single = gr.Button("Process")
            with gr.Column():
                text_output = gr.Textbox(label="Extracted Text")
                json_output = gr.JSON(label="Entities")
                error_output = gr.Textbox(label="Error Details", visible=False)
                csv_download_single = gr.File(label="Download Results")
    # The .then() callback shows the error box only when a non-empty error
    # message was returned
    submit_single.click(
        fn=process_single_image,
        inputs=[single_image, threshold_single, nested_ner_single],
        outputs=[text_output, json_output, csv_download_single, error_output],
        api_name="process_single"
    ).then(
        lambda x: gr.update(visible=bool(x)),
        inputs=[error_output],
        outputs=[error_output]
    )
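
# A sketch of calling the exposed endpoint programmatically with gradio_client
# (URL and file name are hypothetical; assumes a recent gradio_client that
# provides handle_file):
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860/")
#   text, entities, csv_path, err = client.predict(
#       handle_file("card.jpg"), 0.3, True, api_name="/process_single"
#   )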
if __name__ == "__main__":
    app.launch(
        debug=True,
        show_error=True,
        share=False
    )