File size: 8,171 Bytes
2857c36
7b67189
 
 
 
34218ac
7b67189
 
6f67ddd
 
2857c36
34218ac
93a1d27
5b1ad96
93a1d27
 
34218ac
 
 
 
7b67189
 
2857c36
 
7b67189
34218ac
93a1d27
34218ac
 
5b1ad96
 
93a1d27
34218ac
 
 
 
 
 
 
 
 
 
 
 
 
 
2857c36
34218ac
 
 
 
 
082c048
bc7ce7e
93a1d27
bc7ce7e
01e43c3
93a1d27
01e43c3
 
 
 
bc7ce7e
 
bcfd2c7
082c048
bcfd2c7
93a1d27
34218ac
 
 
93a1d27
bcfd2c7
 
 
93a1d27
 
 
 
 
 
 
 
 
34218ac
 
 
6f67ddd
5b1ad96
93a1d27
5b1ad96
 
 
93a1d27
 
 
 
 
 
5b1ad96
93a1d27
5b1ad96
93a1d27
 
 
 
 
 
 
 
 
 
082c048
93a1d27
 
2857c36
34218ac
 
 
 
 
 
2857c36
34218ac
 
93a1d27
34218ac
 
16bbce4
5b1ad96
 
2857c36
34218ac
93a1d27
34218ac
93a1d27
 
 
 
5b1ad96
34218ac
 
16bbce4
93a1d27
 
 
 
 
5b1ad96
bcfd2c7
93a1d27
bcfd2c7
 
 
 
93a1d27
25c0851
 
 
 
94b2d10
93a1d27
 
 
 
94b2d10
34218ac
2857c36
34218ac
94b2d10
bcfd2c7
 
 
 
94b2d10
7b67189
93a1d27
16bbce4
5b1ad96
7b67189
2857c36
 
 
 
93a1d27
2857c36
5b1ad96
2857c36
 
5b1ad96
34218ac
2857c36
7b67189
2857c36
 
 
34218ac
 
 
 
 
 
2857c36
7b67189
34218ac
 
 
16bbce4
93a1d27
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import gradio as gr
import easyocr
import pandas as pd
import numpy as np
import os
import logging
from pathlib import Path
from gliner import GLiNER
import cv2
import re
from PIL import Image
import traceback
import io  # For in-memory file handling
from difflib import SequenceMatcher
import tempfile  # Ensure tempfile is imported
import io        # For in-memory file handling

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set environment variables for model storage.
# /tmp is used presumably because the deploy target (e.g. a hosted Space)
# has a read-only or ephemeral home directory — TODO confirm.
os.environ['GLINER_HOME'] = '/tmp/.gliner_models'
os.environ['TRANSFORMERS_CACHE'] = '/tmp/.gliner_models/cache'

def initialize_models():
    """Load the EasyOCR reader and the GLiNER NER model.

    The GLiNER weights are cached under ``GLINER_HOME`` after the first
    download so subsequent startups load from disk.

    Returns:
        tuple: ``(easyocr.Reader, GLiNER)`` ready for inference.

    Raises:
        Exception: re-raised after logging if either model fails to load.
    """
    try:
        logger.info("Initializing EasyOCR...")
        ocr_reader = easyocr.Reader(
            ['en', 'ar'],
            download_enabled=True,
            model_storage_directory='/tmp/.easyocr_models',
        )

        logger.info("Initializing GLiNER...")
        cache_dir = Path(os.environ['GLINER_HOME']) / 'gliner_large-v2.1'
        if cache_dir.exists():
            # Reuse the locally cached copy instead of re-downloading.
            ner_model = GLiNER.from_pretrained(str(cache_dir))
        else:
            logger.info("Downloading GLiNER model...")
            cache_dir.parent.mkdir(parents=True, exist_ok=True)
            ner_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
            ner_model.save_pretrained(str(cache_dir))

        logger.info("Models initialized successfully")
        return ocr_reader, ner_model
    except Exception as e:
        logger.error(f"Model initialization failed: {str(e)}")
        raise

# Load both models once at import time so every Gradio request reuses them;
# a failure here is fatal for the whole app.
try:
    reader, model = initialize_models()
except Exception as e:
    logger.error(f"Critical failure: {traceback.format_exc()}")
    raise RuntimeError("Failed to initialize models") from e

def clean_extracted_text(text):
    """Strip unsupported characters from OCR output and collapse whitespace.

    Keeps Arabic blocks (U+0600-06FF, U+0750-077F, U+08A0-08FF), Latin
    letters, digits, whitespace and the punctuation ``@ . , -``; every
    other character is dropped. Runs of whitespace become a single space.
    On any failure the input is returned unchanged.
    """
    try:
        filtered = re.sub(
            r'[^\u0600-\u06FF\u0750-\u077F\u08A0-\u08FFA-Za-z0-9\s@.,-]',
            '',
            text,
        )
        collapsed = re.sub(r'\s+', ' ', filtered)
        return collapsed.strip()
    except Exception as e:
        logger.error(f"Text cleaning failed: {traceback.format_exc()}")
        return text

def preprocess_image(image, max_dim=1024):
    """Prepare an image for OCR: resize, grayscale, denoise, binarize.

    Args:
        image: PIL.Image or numpy array (grayscale, RGB, or RGBA).
        max_dim: longest side is scaled down to at most this many pixels.

    Returns:
        np.ndarray: binary (Otsu-thresholded) single-channel image.

    Raises:
        Exception: re-raised after logging on any OpenCV/numpy failure.
    """
    try:
        if not isinstance(image, np.ndarray):
            image = np.array(image)
        # Downscale oversized images, preserving aspect ratio.
        h, w = image.shape[:2]
        if max(h, w) > max_dim:
            scaling = max_dim / float(max(h, w))
            image = cv2.resize(image, (int(w * scaling), int(h * scaling)))
        # Convert to single-channel grayscale.
        if image.ndim == 2:
            gray = image
        elif image.shape[2] == 4:
            # Fix: an RGBA input (e.g. a PNG with alpha) has 4 channels and
            # makes cv2.cvtColor(..., COLOR_RGB2GRAY) raise; convert via the
            # RGBA-aware code instead.
            gray = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
        else:
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        # Remove salt-and-pepper noise, then binarize with Otsu's method
        # (threshold value 0 is ignored when THRESH_OTSU is set).
        denoised = cv2.medianBlur(gray, 3)
        _, thresh = cv2.threshold(denoised, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        return thresh
    except Exception as e:
        logger.error(f"Preprocessing failed: {traceback.format_exc()}")
        raise

def similar(a, b, threshold=0.8):
    """Check whether two strings' SequenceMatcher ratio exceeds *threshold*."""
    ratio = SequenceMatcher(None, a, b).ratio()
    return ratio > threshold

def clean_and_deduplicate(entities):
    """Normalize extracted entity values and drop near-duplicates.

    For "email"/"phone" labels the value is tightened to the first regex
    match when one exists; values similar to an already-kept value
    (SequenceMatcher ratio > 0.8 via ``similar``) are discarded.

    Args:
        entities: dict mapping label -> list of raw value strings.

    Returns:
        dict: same labels mapped to cleaned, de-duplicated value lists.
    """
    deduped = {}
    for label, values in entities.items():
        kept = []
        key = label.lower()
        for value in values:
            if key == "email":
                m = re.search(r'[\w\.-]+@[\w\.-]+', value)
                if m:
                    value = m.group(0)
            elif key == "phone":
                m = re.search(r'\+?\d[\d\s\-]{7,}\d', value)
                if m:
                    value = m.group(0)
            # Keep only values not close to anything already accepted.
            if all(not similar(value, seen) for seen in kept):
                kept.append(value)
        deduped[label] = kept
    return deduped

def process_single_image(image, threshold=0.3, nested_ner=True, progress=gr.Progress()):
    """Run the full OCR + NER pipeline on one business-card image.

    Args:
        image: PIL.Image or numpy array of the card.
        threshold: GLiNER confidence threshold (0-1).
        nested_ner: allow overlapping (nested) entity spans.
        progress: Gradio progress tracker; the instance default is the
            Gradio idiom for progress reporting, not a shared mutable.

    Returns:
        tuple: ``(clean_text, entities_dict, csv_path, error_message)``.
        On failure the first two items are empty, the path is None, and
        the last item carries the error details for the UI.
    """
    try:
        if image is None:
            raise ValueError("No image provided")
        progress(0.1, "Validating input...")
        if not isinstance(image, (Image.Image, np.ndarray)):
            raise TypeError(f"Invalid image type: {type(image)}")

        progress(0.2, "Preprocessing image...")
        preprocessed = preprocess_image(image)

        progress(0.4, "Performing OCR...")
        try:
            # detail=0 returns plain strings; paragraph=True merges lines.
            ocr_results = reader.readtext(preprocessed, detail=0, paragraph=True)
        except Exception as e:
            logger.error(f"OCR failed: {traceback.format_exc()}")
            raise RuntimeError("OCR processing failed") from e
        raw_text = " ".join(ocr_results)
        clean_text = clean_extracted_text(raw_text)

        progress(0.6, "Extracting entities...")
        try:
            labels = ["person name", "company name", "job title", "phone", "email", "address"]
            entities = model.predict_entities(
                clean_text,
                labels,
                threshold=threshold,
                flat_ner=not nested_ner
            )
        except Exception as e:
            logger.error(f"Entity extraction failed: {traceback.format_exc()}")
            raise RuntimeError("Entity extraction failed") from e

        # Group entity texts under their title-cased labels, then clean up.
        results = {label.title(): [] for label in labels}
        for entity in entities:
            label = entity["label"].title()
            if label in results:
                results[label].append(entity["text"])
        cleaned_entities = clean_and_deduplicate(results)
        joined = {k: "; ".join(v) for k, v in cleaned_entities.items()}

        # Fix: write the one-row CSV straight to the temp file instead of
        # building it in a BytesIO buffer and copying the bytes over —
        # the in-memory round-trip was redundant.
        with tempfile.NamedTemporaryFile(suffix=".csv", delete=False,
                                         mode="w", newline="") as tmp_file:
            pd.DataFrame([joined]).to_csv(tmp_file, index=False)
            csv_path = tmp_file.name

        return (
            clean_text,   # extracted text (str)
            joined,       # entities (dict for the JSON component)
            csv_path,     # downloadable CSV path (str)
            ""            # empty error message (str)
        )

    except Exception as e:
        logger.error(f"Processing failed: {traceback.format_exc()}")
        return (
            "",
            {},
            None,
            f"Error: {str(e)}\n{traceback.format_exc()}"
        )

# Gradio Interface setup: single-tab UI wired to process_single_image.
with gr.Blocks() as app:
    gr.Markdown("# Business Card Information Extractor")
    
    with gr.Tab("Single File"):
        with gr.Row():
            with gr.Column():
                # Input controls: card image plus the NER tuning knobs.
                single_image = gr.Image(label="Upload Business Card", type="pil")
                threshold_single = gr.Slider(0.0, 1.0, value=0.3, label="Detection Threshold")
                nested_ner_single = gr.Checkbox(True, label="Enable Nested NER")
                submit_single = gr.Button("Process")
            with gr.Column():
                # Outputs: raw text, structured entities, error box
                # (hidden until an error occurs), and the CSV download.
                text_output = gr.Textbox(label="Extracted Text")
                json_output = gr.JSON(label="Entities")
                error_output = gr.Textbox(label="Error Details", visible=False)
                csv_download_single = gr.File(label="Download Results")

    # Run the pipeline, then toggle the error box's visibility based on
    # whether the returned error message is non-empty.
    submit_single.click(
        fn=process_single_image,
        inputs=[single_image, threshold_single, nested_ner_single],
        outputs=[text_output, json_output, csv_download_single, error_output],
        api_name="process_single"
    ).then(
        lambda x: gr.update(visible=bool(x)),
        inputs=[error_output],
        outputs=[error_output]
    )

app.launch(
    debug=True,
    show_error=True,
    share=False
)