gradio
app.py CHANGED
@@ -1,68 +1,35 @@
-import streamlit as st
 import easyocr
 import pandas as pd
-from io import BytesIO
-from PIL import Image
 import numpy as np
 import os
 from pathlib import Path
 from gliner import GLiNER
 import cv2
 import re

 # Set environment variables for model storage
-os.environ['GLINER_HOME'] =
-os.environ['TRANSFORMERS_CACHE'] =

 # Initialize EasyOCR reader with English and Arabic support
 reader = easyocr.Reader(['en', 'ar'])

-def get_model_path():
-    """Get the path to the local model directory."""
-    base_dir = Path.home() / '.gliner_models'
-    model_dir = base_dir / 'gliner_large-v2.1'
-    return model_dir
-
-def download_model():
-    """Download the model if it doesn't exist locally."""
-    model_dir = get_model_path()
-    if not model_dir.exists():
-        st.info("Downloading GLiNER model for the first time... This may take a few minutes.")
-        try:
-            model_dir.parent.mkdir(parents=True, exist_ok=True)
-            temp_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
-            temp_model.save_pretrained(str(model_dir))
-            st.success("Model downloaded successfully!")
-            return temp_model
-        except Exception as e:
-            st.error(f"Error downloading model: {str(e)}")
-            raise e
-    return None
-
-@st.cache_resource
 def load_gliner_model():
-
-
-
-
-
-        except Exception as e:
-            st.warning("Error loading existing model. Attempting to redownload...")
-            import shutil
-            shutil.rmtree(model_dir, ignore_errors=True)
-
-    model = download_model()
-    if model:
         return model
-    return GLiNER.from_pretrained(str(model_dir))

 def preprocess_image(image):
-    """
-    Preprocess the image using OpenCV:
-    - Convert to grayscale
-    - Apply median blur for denoising
-    - Apply thresholding (Otsu) for binarization
-    """
     img_array = np.array(image)
     gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
     denoised = cv2.medianBlur(gray, 3)
@@ -70,145 +37,102 @@ def preprocess_image(image):
     return thresh

 def clean_extracted_text(text):
-    """
-    Clean the extracted text:
-    - Remove unwanted characters while preserving Arabic Unicode blocks,
-      English letters, digits, spaces, and common punctuation.
-    - Normalize extra spaces.
-    """
     cleaned = re.sub(r'[^\u0600-\u06FF\u0750-\u077F\u08A0-\u08FFA-Za-z0-9\s@.,-]', '', text)
-
-    return cleaned

-def extract_text_from_image(image):
-
-
-
-
-

-def process_entities(text, model, threshold, nested_ner):
-
-
-
-
-
-
-
-
-
-
-        flat_ner=not nested_ner,
-        threshold=threshold
-    )

-    #
-
-
-
-
-
-
-            "end": entity["end"]
-        })

-
-    results = {
-        "Person Name": [],
-        "Company Name": [],
-        "Job Title": [],
-        "Phone": [],
-        "Email": [],
-        "Address": []
     }
-
-    for entity in formatted_entities:
-        category = entity["entity"].title()
-        if category in results:
-            results[category].append(entity["word"])
-
-    # Join multiple entries with semicolons
-    return {k: "; ".join(set(v)) if v else "" for k, v in results.items()}

-
-

-
-
-
-
-
-
-
-
-
-
-
-
-    nested_ner = st.sidebar.checkbox(
-        "Enable Nested NER",
-        value=True,
-        help="Allow detection of nested entities"
-    )

-
-
-
-
-
-
-
-
     )

-
-
-
-
-
-
-    progress_bar = st.progress(0)
-
-    for idx, file in enumerate(files_to_process):
-        with st.expander(f"Processing {file.name}"):
-            image = Image.open(file)
-            # Extract text using OCR after preprocessing
-            extracted_text_list = extract_text_from_image(image)
-            raw_text = " ".join(extracted_text_list)
-            # Clean the extracted text
-            clean_text = clean_extracted_text(raw_text)
-
-            st.text("Extracted Text:")
-            st.text(clean_text)
-
-            # Process extracted text with GLiNER for entity recognition
-            result = process_entities(clean_text, model, threshold, nested_ner)
-            result["File Name"] = file.name
-            results.append(result)
-
-            st.json(result)
-
-        progress_bar.progress((idx + 1) / len(files_to_process))
-
-    if results:
-        st.success("Processing Complete!")
-
-        # Convert results to a DataFrame
-        df = pd.DataFrame(results)
-        cols = ["File Name"] + [col for col in df.columns if col != "File Name"]
-        df = df[cols]
-
-        st.dataframe(df, use_container_width=True)
-
-        csv = df.to_csv(index=False)
-        st.download_button(
-            "Download Results CSV",
-            csv,
-            "business_card_results.csv",
-            "text/csv",
-            key='download-csv'
-        )

-
-
+import gradio as gr
 import easyocr
 import pandas as pd
 import numpy as np
 import os
 from pathlib import Path
 from gliner import GLiNER
 import cv2
 import re
+from PIL import Image
+import time

 # Set environment variables for model storage
+os.environ['GLINER_HOME'] = '/tmp/.gliner_models'
+os.environ['TRANSFORMERS_CACHE'] = '/tmp/.gliner_models/cache'

 # Initialize EasyOCR reader with English and Arabic support
 reader = easyocr.Reader(['en', 'ar'])

+# Initialize GLiNER model
 def load_gliner_model():
+    model_path = Path(os.environ['GLINER_HOME']) / 'gliner_large-v2.1'
+    if not model_path.exists():
+        model_path.parent.mkdir(parents=True, exist_ok=True)
+        model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
+        model.save_pretrained(str(model_path))
         return model
+    return GLiNER.from_pretrained(str(model_path))
+
+model = load_gliner_model()

 def preprocess_image(image):
     img_array = np.array(image)
     gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
     denoised = cv2.medianBlur(gray, 3)
     return thresh

 def clean_extracted_text(text):
     cleaned = re.sub(r'[^\u0600-\u06FF\u0750-\u077F\u08A0-\u08FFA-Za-z0-9\s@.,-]', '', text)
+    return re.sub(r'\s+', ' ', cleaned).strip()

+def process_single_image(image, threshold=0.3, nested_ner=True, progress=gr.Progress()):
+    try:
+        # Preprocess and extract text
+        progress(0.2, "Processing image...")
+        preprocessed = preprocess_image(image)
+        ocr_results = reader.readtext(preprocessed, detail=0, paragraph=True)
+        clean_text = clean_extracted_text(" ".join(ocr_results))
+
+        # Entity extraction
+        progress(0.6, "Extracting entities...")
+        labels = ["person name", "company name", "job title", "phone", "email", "address"]
+        entities = model.predict_entities(clean_text, labels, threshold=threshold, flat_ner=not nested_ner)
+
+        # Format results
+        results = {label.title(): [] for label in labels}
+        for entity in entities:
+            label = entity["label"].title()
+            if label in results:
+                results[label].append(entity["text"])
+
+        return {
+            "text": clean_text,
+            "entities": {k: "; ".join(set(v)) for k, v in results.items()},
+            "csv": pd.DataFrame([results])
+        }
+    except Exception as e:
+        return {"error": str(e)}

+def process_batch(files, threshold, nested_ner, progress=gr.Progress()):
+    results = []
+    for i, file in enumerate(files):
+        progress(i/len(files), f"Processing {file.name}...")
+        try:
+            image = Image.open(file)
+            result = process_single_image(image, threshold, nested_ner)
+            if "error" not in result:
+                result["filename"] = file.name
+                results.append(result)
+        except Exception as e:
+            results.append({"filename": file.name, "error": str(e)})

+    # Create CSV
+    df = pd.DataFrame([{
+        "Filename": r["filename"],
+        **r.get("entities", {}),
+        "Raw Text": r.get("text", ""),
+        "Error": r.get("error", "")
+    } for r in results])

+    return {
+        "batch_results": results,
+        "csv": df
     }

+with gr.Blocks() as app:
+    gr.Markdown("# Business Card Information Extractor")

+    with gr.Tab("Single File"):
+        with gr.Row():
+            with gr.Column():
+                single_image = gr.Image(label="Upload Business Card", type="pil")
+                threshold_single = gr.Slider(0.0, 1.0, value=0.3, label="Detection Threshold")
+                nested_ner_single = gr.Checkbox(True, label="Enable Nested NER")
+                submit_single = gr.Button("Process")
+            with gr.Column():
+                text_output = gr.Textbox(label="Extracted Text")
+                json_output = gr.JSON(label="Entities")
+                csv_download_single = gr.File(label="Download Results")

+    with gr.Tab("Batch Processing"):
+        with gr.Row():
+            with gr.Column():
+                batch_files = gr.Files(label="Upload Business Cards", file_types=["image"])
+                threshold_batch = gr.Slider(0.0, 1.0, value=0.3, label="Detection Threshold")
+                nested_ner_batch = gr.Checkbox(True, label="Enable Nested NER")
+                submit_batch = gr.Button("Process Batch")
+            with gr.Column():
+                batch_results = gr.JSON(label="Processing Results")
+                csv_download_batch = gr.File(label="Download CSV")
+
+    # Single file processing
+    submit_single.click(
+        fn=process_single_image,
+        inputs=[single_image, threshold_single, nested_ner_single],
+        outputs=[text_output, json_output, csv_download_single]
     )

+    # Batch processing
+    submit_batch.click(
+        fn=process_batch,
+        inputs=[batch_files, threshold_batch, nested_ner_batch],
+        outputs=[batch_results, csv_download_batch]
+    )

+# For API access
+app.launch(enable_queue=True, share=True)
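
Note on wiring: in gr.Blocks an event handler returns one value per component listed in `outputs`, and a gr.File output expects a file path rather than a DataFrame. Since process_single_image returns a single dict, feeding the three single-file outputs would need a thin adapter along the following lines. This is a minimal sketch, not part of the commit; the helper name and the temporary-file handling are illustrative assumptions.

```python
# Hypothetical adapter mapping process_single_image's dict onto the three
# output components wired to submit_single.click (sketch, not in the commit).
import tempfile

def process_single_for_ui(image, threshold, nested_ner):
    result = process_single_image(image, threshold, nested_ner)
    if "error" in result:
        # One value per output component: Textbox, JSON, File
        return f"Error: {result['error']}", {}, None
    # gr.File serves a path, so materialize the one-row DataFrame as a CSV file.
    csv_path = tempfile.NamedTemporaryFile(suffix=".csv", delete=False).name
    result["csv"].to_csv(csv_path, index=False)
    return result["text"], result["entities"], csv_path

# submit_single.click(fn=process_single_for_ui,
#                     inputs=[single_image, threshold_single, nested_ner_single],
#                     outputs=[text_output, json_output, csv_download_single])
```

A similar adapter would map process_batch's dict onto batch_results and csv_download_batch.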