codic committed on
Commit
2857c36
·
verified ·
1 Parent(s): 1d35187
Files changed (1) hide show
  1. app.py +101 -177
app.py CHANGED
@@ -1,68 +1,35 @@
1
- import streamlit as st
2
  import easyocr
3
  import pandas as pd
4
- from io import BytesIO
5
- from PIL import Image
6
  import numpy as np
7
  import os
8
  from pathlib import Path
9
  from gliner import GLiNER
10
  import cv2
11
  import re
 
 
12
 
13
  # Set environment variables for model storage
14
# Both locations hang off one directory in the user's home.
# NOTE(review): gliner is imported above this point; confirm the cache
# variable is still honoured when it is set after that import.
_MODEL_ROOT = Path.home() / '.gliner_models'
os.environ['GLINER_HOME'] = str(_MODEL_ROOT)
os.environ['TRANSFORMERS_CACHE'] = str(_MODEL_ROOT / 'cache')

# Initialize EasyOCR reader with English and Arabic support
reader = easyocr.Reader(['en', 'ar'])
19
 
20
def get_model_path():
    """Return the local GLiNER model directory inside the user's home."""
    return Path.home() / '.gliner_models' / 'gliner_large-v2.1'
25
-
26
def download_model():
    """Fetch the GLiNER model and save it locally when no local copy exists.

    Returns the freshly downloaded model, or None when the model directory
    is already present. Any download/save failure is reported through
    Streamlit and then re-raised.
    """
    target_dir = get_model_path()
    if target_dir.exists():
        return None

    st.info("Downloading GLiNER model for the first time... This may take a few minutes.")
    try:
        target_dir.parent.mkdir(parents=True, exist_ok=True)
        fetched = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
        fetched.save_pretrained(str(target_dir))
        st.success("Model downloaded successfully!")
        return fetched
    except Exception as e:
        st.error(f"Error downloading model: {str(e)}")
        raise
41
-
42
@st.cache_resource
def load_gliner_model():
    """Return a GLiNER model, preferring the local copy and downloading otherwise.

    If the local directory exists but cannot be loaded, it is deleted and
    the model is downloaded again.
    """
    local_dir = get_model_path()
    if local_dir.exists():
        try:
            return GLiNER.from_pretrained(str(local_dir))
        except Exception:
            st.warning("Error loading existing model. Attempting to redownload...")
            import shutil
            shutil.rmtree(local_dir, ignore_errors=True)

    downloaded = download_model()
    # download_model() returns None when the directory already exists,
    # in which case the model is loaded from disk.
    return downloaded if downloaded else GLiNER.from_pretrained(str(local_dir))
 
 
58
 
59
  def preprocess_image(image):
60
- """
61
- Preprocess the image using OpenCV:
62
- - Convert to grayscale
63
- - Apply median blur for denoising
64
- - Apply thresholding (Otsu) for binarization
65
- """
66
  img_array = np.array(image)
67
  gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
68
  denoised = cv2.medianBlur(gray, 3)
@@ -70,145 +37,102 @@ def preprocess_image(image):
70
  return thresh
71
 
72
  def clean_extracted_text(text):
73
- """
74
- Clean the extracted text:
75
- - Remove unwanted characters while preserving Arabic Unicode blocks,
76
- English letters, digits, spaces, and common punctuation.
77
- - Normalize extra spaces.
78
- """
79
  cleaned = re.sub(r'[^\u0600-\u06FF\u0750-\u077F\u08A0-\u08FFA-Za-z0-9\s@.,-]', '', text)
80
- cleaned = re.sub(r'\s+', ' ', cleaned).strip()
81
- return cleaned
82
 
83
def extract_text_from_image(image):
    """Run EasyOCR on the preprocessed image and return its paragraph output."""
    binarized = preprocess_image(image)
    paragraphs = reader.readtext(binarized, detail=0, paragraph=True)
    return paragraphs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
def process_entities(text: str, model, threshold: float, nested_ner: bool) -> dict:
    """Extract business card fields from text with a GLiNER model.

    Args:
        text: Cleaned OCR text.
        model: Object exposing ``predict_entities(text, labels, flat_ner=...,
            threshold=...)`` that returns dicts with "label" and "text" keys.
        threshold: Detection threshold forwarded to the model.
        nested_ner: When True the model is called with ``flat_ner=False``.

    Returns:
        A dict keyed by "Person Name", "Company Name", "Job Title", "Phone",
        "Email" and "Address". Each value is the distinct matches joined by
        "; " in first-seen order, or "" when nothing matched.
    """
    labels = ["person name", "company name", "job title", "phone", "email", "address"]

    entities = model.predict_entities(
        text,
        labels,
        flat_ner=not nested_ner,
        threshold=threshold,
    )

    # One bucket per label; title-casing the label yields the display key.
    results = {label.title(): [] for label in labels}
    for entity in entities:
        bucket = results.get(entity["label"].title())
        if bucket is not None:  # ignore labels that were not requested
            bucket.append(entity["text"])

    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # joined string is deterministic (a set would not guarantee that).
    return {key: "; ".join(dict.fromkeys(words)) for key, words in results.items()}
133
 
134
def main():
    """Streamlit entry point.

    Renders the settings sidebar and the uploader, runs OCR and entity
    extraction on every uploaded card, then shows the combined table with
    a CSV download button.
    """
    st.title("Business Card Information Extractor")

    # Model settings in sidebar
    st.sidebar.title("Settings")
    threshold = st.sidebar.slider(
        "Detection Threshold",
        min_value=0.0,
        max_value=1.0,
        value=0.3,
        step=0.05,
        help="Lower values will detect more entities",
    )
    nested_ner = st.sidebar.checkbox(
        "Enable Nested NER",
        value=True,
        help="Allow detection of nested entities",
    )

    # Upload options: the radio choice decides whether several files are accepted
    upload_type = st.sidebar.radio("Upload Type", ("Single", "Batch"))
    uploaded_files = st.file_uploader(
        "Upload Business Card Image(s)",
        type=["png", "jpg", "jpeg"],
        accept_multiple_files=(upload_type == "Batch"),
    )

    if not uploaded_files:
        return

    model = load_gliner_model()

    # Single mode yields one file object, batch mode a list; normalise to a list.
    pending = uploaded_files if isinstance(uploaded_files, list) else [uploaded_files]
    progress_bar = st.progress(0)
    rows = []

    for position, card in enumerate(pending, start=1):
        with st.expander(f"Processing {card.name}"):
            image = Image.open(card)
            # OCR on the preprocessed image, then clean the joined text
            raw_text = " ".join(extract_text_from_image(image))
            clean_text = clean_extracted_text(raw_text)

            st.text("Extracted Text:")
            st.text(clean_text)

            # Entity recognition on the cleaned text
            row = process_entities(clean_text, model, threshold, nested_ner)
            row["File Name"] = card.name
            rows.append(row)

            st.json(row)

        progress_bar.progress(position / len(pending))

    if rows:
        st.success("Processing Complete!")

        # Put the file name first, keep the remaining columns in their order
        df = pd.DataFrame(rows)
        other_columns = [col for col in df.columns if col != "File Name"]
        df = df[["File Name"] + other_columns]

        st.dataframe(df, use_container_width=True)

        st.download_button(
            "Download Results CSV",
            df.to_csv(index=False),
            "business_card_results.csv",
            "text/csv",
            key='download-csv',
        )


if __name__ == "__main__":
    main()
 
1
+ import gradio as gr
2
  import easyocr
3
  import pandas as pd
 
 
4
  import numpy as np
5
  import os
6
  from pathlib import Path
7
  from gliner import GLiNER
8
  import cv2
9
  import re
10
+ from PIL import Image
11
+ import time
12
 
13
  # Set environment variables for model storage
14
# Both locations hang off one directory under /tmp.
# NOTE(review): gliner is imported above this point; confirm the cache
# variable is still honoured when it is set after that import.
_MODEL_ROOT = '/tmp/.gliner_models'
os.environ['GLINER_HOME'] = _MODEL_ROOT
os.environ['TRANSFORMERS_CACHE'] = f'{_MODEL_ROOT}/cache'

# Initialize EasyOCR reader with English and Arabic support
reader = easyocr.Reader(['en', 'ar'])
19
 
20
+ # Initialize GLiNER model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def load_gliner_model():
22
+ model_path = Path(os.environ['GLINER_HOME']) / 'gliner_large-v2.1'
23
+ if not model_path.exists():
24
+ model_path.parent.mkdir(parents=True, exist_ok=True)
25
+ model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
26
+ model.save_pretrained(str(model_path))
 
 
 
 
 
 
 
27
  return model
28
+ return GLiNER.from_pretrained(str(model_path))
29
+
30
+ model = load_gliner_model()
31
 
32
  def preprocess_image(image):
 
 
 
 
 
 
33
  img_array = np.array(image)
34
  gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
35
  denoised = cv2.medianBlur(gray, 3)
 
37
  return thresh
38
 
39
  def clean_extracted_text(text):
 
 
 
 
 
 
40
  cleaned = re.sub(r'[^\u0600-\u06FF\u0750-\u077F\u08A0-\u08FFA-Za-z0-9\s@.,-]', '', text)
41
+ return re.sub(r'\s+', ' ', cleaned).strip()
 
42
 
43
def process_single_image(image, threshold=0.3, nested_ner=True, progress=gr.Progress()):
    """Run OCR and entity extraction on one business card image.

    Args:
        image: The card image; it is converted with ``np.array`` during
            preprocessing. ``None`` is reported as an error.
        threshold: Detection threshold forwarded to the GLiNER model.
        nested_ner: When True the model is called with ``flat_ner=False``.
        progress: Gradio progress tracker, updated before each stage.

    Returns:
        On success, a dict with:
          "text"     - the cleaned OCR text,
          "entities" - label -> distinct matches joined by "; ",
          "csv"      - a one-row DataFrame holding the same joined values.
        On failure, ``{"error": <message>}``.
    """
    if image is None:
        return {"error": "No image provided"}

    try:
        # Preprocess and extract text
        progress(0.2, "Processing image...")
        preprocessed = preprocess_image(image)
        ocr_results = reader.readtext(preprocessed, detail=0, paragraph=True)
        clean_text = clean_extracted_text(" ".join(ocr_results))

        # Entity extraction
        progress(0.6, "Extracting entities...")
        labels = ["person name", "company name", "job title", "phone", "email", "address"]
        entities = model.predict_entities(
            clean_text, labels, threshold=threshold, flat_ner=not nested_ner
        )

        # Bucket matches under the title-cased label
        results = {label.title(): [] for label in labels}
        for entity in entities:
            bucket = results.get(entity["label"].title())
            if bucket is not None:
                bucket.append(entity["text"])

        # dict.fromkeys de-duplicates in first-seen order, which keeps the
        # joined strings deterministic (a set would not).
        joined = {key: "; ".join(dict.fromkeys(words)) for key, words in results.items()}

        return {
            "text": clean_text,
            "entities": joined,
            # Built from the joined strings so the table matches "entities"
            # instead of holding raw Python lists in its cells.
            "csv": pd.DataFrame([joined]),
        }
    except Exception as e:
        # Report failures to the caller as data rather than raising.
        return {"error": str(e)}
70
 
71
def process_batch(files, threshold, nested_ner, progress=gr.Progress()):
    """Process several business card images and collect the results.

    Args:
        files: Uploaded files, either objects with a ``.name`` path or plain
            path strings. ``None`` or an empty list gives an empty result.
        threshold: Detection threshold forwarded to ``process_single_image``.
        nested_ner: Nested-NER flag forwarded to ``process_single_image``.
        progress: Gradio progress tracker, updated once per file.

    Returns:
        A dict with:
          "batch_results" - one dict per file, always carrying "filename";
            failed files carry "error" instead of "text"/"entities".
          "csv" - a DataFrame with Filename, the entity columns, Raw Text
            and Error for every file.
    """
    files = files or []
    results = []

    for i, file in enumerate(files):
        # Accept both file-like uploads (with .name) and plain path strings.
        source = getattr(file, "name", file)
        progress(i / len(files), f"Processing {source}...")
        try:
            with Image.open(source) as image:
                result = dict(process_single_image(image, threshold, nested_ner))
        except Exception as e:
            result = {"error": str(e)}
        # Every entry is kept, failed or not, so that no file disappears
        # from the report and the CSV can show its error.
        result["filename"] = source
        results.append(result)

    # Create CSV
    df = pd.DataFrame([{
        "Filename": r["filename"],
        **r.get("entities", {}),
        "Raw Text": r.get("text", ""),
        "Error": r.get("error", ""),
    } for r in results])

    return {
        "batch_results": results,
        "csv": df,
    }
 
 
 
 
 
 
 
 
96
 
97
def _write_csv(frame, stem):
    """Save *frame* as a temporary CSV file and return its path.

    The download components below are gr.File outputs, which are given a
    file path here rather than a DataFrame.
    """
    import tempfile

    handle = tempfile.NamedTemporaryFile(
        mode="w", prefix=f"{stem}_", suffix=".csv",
        delete=False, newline="", encoding="utf-8",
    )
    with handle:
        frame.to_csv(handle, index=False)
    return handle.name


def _single_for_ui(image, threshold, nested_ner, progress=gr.Progress()):
    """Adapt process_single_image's dict to the three single-file outputs."""
    outcome = process_single_image(image, threshold, nested_ner, progress)
    if "error" in outcome:
        raise gr.Error(outcome["error"])
    return outcome["text"], outcome["entities"], _write_csv(outcome["csv"], "business_card")


def _batch_for_ui(files, threshold, nested_ner, progress=gr.Progress()):
    """Adapt process_batch's dict to the two batch outputs."""
    if not files:
        raise gr.Error("Please upload at least one image.")
    outcome = process_batch(files, threshold, nested_ner, progress)
    # The per-file "csv" DataFrames are not JSON-serialisable; leave them out
    # of the JSON summary.
    summaries = [
        {key: value for key, value in entry.items() if key != "csv"}
        for entry in outcome["batch_results"]
    ]
    return summaries, _write_csv(outcome["csv"], "business_card_results")


with gr.Blocks() as app:
    gr.Markdown("# Business Card Information Extractor")

    with gr.Tab("Single File"):
        with gr.Row():
            with gr.Column():
                single_image = gr.Image(label="Upload Business Card", type="pil")
                threshold_single = gr.Slider(0.0, 1.0, value=0.3, label="Detection Threshold")
                nested_ner_single = gr.Checkbox(True, label="Enable Nested NER")
                submit_single = gr.Button("Process")
            with gr.Column():
                text_output = gr.Textbox(label="Extracted Text")
                json_output = gr.JSON(label="Entities")
                csv_download_single = gr.File(label="Download Results")

    with gr.Tab("Batch Processing"):
        with gr.Row():
            with gr.Column():
                batch_files = gr.Files(label="Upload Business Cards", file_types=["image"])
                threshold_batch = gr.Slider(0.0, 1.0, value=0.3, label="Detection Threshold")
                nested_ner_batch = gr.Checkbox(True, label="Enable Nested NER")
                submit_batch = gr.Button("Process Batch")
            with gr.Column():
                batch_results = gr.JSON(label="Processing Results")
                csv_download_batch = gr.File(label="Download CSV")

    # Single file processing: the adapter returns one value per output component
    submit_single.click(
        fn=_single_for_ui,
        inputs=[single_image, threshold_single, nested_ner_single],
        outputs=[text_output, json_output, csv_download_single],
    )

    # Batch processing
    submit_batch.click(
        fn=_batch_for_ui,
        inputs=[batch_files, threshold_batch, nested_ner_batch],
        outputs=[batch_results, csv_download_batch],
    )

# For API access. Queueing is enabled with queue() because launch() no longer
# accepts enable_queue in current Gradio releases.
app.queue().launch(share=True)