ahmadsalahudin commited on
Commit
7b67189
·
verified ·
1 Parent(s): aa23d05

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +190 -0
app.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import easyocr
3
+ import pandas as pd
4
+ from io import BytesIO
5
+ from PIL import Image
6
+ import numpy as np
7
+ import os
8
+ from pathlib import Path
9
+ from gliner import GLiNER
10
+
11
+ # Set environment variables for model storage
12
+ os.environ['GLINER_HOME'] = str(Path.home() / '.gliner_models')
13
+ os.environ['TRANSFORMERS_CACHE'] = str(Path.home() / '.gliner_models' / 'cache')
14
+
15
+ # Initialize EasyOCR reader
16
+ reader = easyocr.Reader(['en'])
17
+
18
+ def get_model_path():
19
+ """Get the path to the local model directory."""
20
+ base_dir = Path.home() / '.gliner_models'
21
+ model_dir = base_dir / 'gliner_large-v2.1'
22
+ return model_dir
23
+
24
+ def download_model():
25
+ """Download the model if it doesn't exist locally."""
26
+ model_dir = get_model_path()
27
+ if not model_dir.exists():
28
+ st.info("Downloading GLiNER model for the first time... This may take a few minutes.")
29
+ try:
30
+ model_dir.parent.mkdir(parents=True, exist_ok=True)
31
+ temp_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
32
+ temp_model.save_pretrained(str(model_dir))
33
+ st.success("Model downloaded successfully!")
34
+ return temp_model
35
+ except Exception as e:
36
+ st.error(f"Error downloading model: {str(e)}")
37
+ raise e
38
+ return None
39
+
40
+ @st.cache_resource
41
+ def load_gliner_model():
42
+ """Load the GLiNER model, downloading it if necessary."""
43
+ model_dir = get_model_path()
44
+ if model_dir.exists():
45
+ try:
46
+ return GLiNER.from_pretrained(str(model_dir))
47
+ except Exception as e:
48
+ st.warning("Error loading existing model. Attempting to redownload...")
49
+ import shutil
50
+ shutil.rmtree(model_dir, ignore_errors=True)
51
+
52
+ model = download_model()
53
+ if model:
54
+ return model
55
+ return GLiNER.from_pretrained(str(model_dir))
56
+
57
+ def extract_text_from_image(image):
58
+ """Extracts text from a single image using EasyOCR."""
59
+ image_array = np.array(image)
60
+ return reader.readtext(image_array, detail=0, paragraph=True)
61
+
62
+ def process_entities(text: str, model, threshold: float, nested_ner: bool) -> dict:
63
+ """Process text with GLiNER model - matching app.py implementation."""
64
+ # Define our business card labels
65
+ labels = "person name, company name, job title, phone, email, address"
66
+ labels = [label.strip() for label in labels.split(",")]
67
+
68
+ # Get predictions
69
+ entities = model.predict_entities(
70
+ text,
71
+ labels,
72
+ flat_ner=not nested_ner,
73
+ threshold=threshold
74
+ )
75
+
76
+ # Format results matching app.py structure
77
+ formatted_entities = []
78
+ for entity in entities:
79
+ formatted_entities.append({
80
+ "entity": entity["label"],
81
+ "word": entity["text"],
82
+ "start": entity["start"],
83
+ "end": entity["end"]
84
+ })
85
+
86
+ # Organize results by category
87
+ results = {
88
+ "Person Name": [],
89
+ "Company Name": [],
90
+ "Job Title": [],
91
+ "Phone": [],
92
+ "Email": [],
93
+ "Address": []
94
+ }
95
+
96
+ for entity in formatted_entities:
97
+ category = entity["entity"].title()
98
+ if category in results:
99
+ results[category].append(entity["word"])
100
+
101
+ # Join multiple entries with semicolons
102
+ return {k: "; ".join(set(v)) if v else "" for k, v in results.items()}
103
+
104
+ def main():
105
+ st.title("Business Card Information Extractor")
106
+
107
+ # Model settings in sidebar
108
+ st.sidebar.title("Settings")
109
+
110
+ threshold = st.sidebar.slider(
111
+ "Detection Threshold",
112
+ min_value=0.0,
113
+ max_value=1.0,
114
+ value=0.3,
115
+ step=0.05,
116
+ help="Lower values will detect more entities (as in app.py example)"
117
+ )
118
+
119
+ nested_ner = st.sidebar.checkbox(
120
+ "Enable Nested NER",
121
+ value=True,
122
+ help="Allow detection of nested entities"
123
+ )
124
+
125
+ # Upload options
126
+ upload_type = st.sidebar.radio("Upload Type", ("Single", "Batch"))
127
+
128
+ # File uploader
129
+ uploaded_files = st.file_uploader(
130
+ "Upload Business Card Image(s)",
131
+ type=["png", "jpg", "jpeg"],
132
+ accept_multiple_files=(upload_type == "Batch")
133
+ )
134
+
135
+ if uploaded_files:
136
+ # Load model
137
+ model = load_gliner_model()
138
+
139
+ # Process files
140
+ results = []
141
+ files_to_process = uploaded_files if isinstance(uploaded_files, list) else [uploaded_files]
142
+
143
+ progress_bar = st.progress(0)
144
+ for idx, file in enumerate(files_to_process):
145
+ with st.expander(f"Processing {file.name}"):
146
+ # Load and extract text
147
+ image = Image.open(file)
148
+ extracted_text = extract_text_from_image(image)
149
+ clean_text = " ".join(extracted_text)
150
+
151
+ # Show extracted text
152
+ st.text("Extracted Text:")
153
+ st.text(clean_text)
154
+
155
+ # Process with GLiNER
156
+ result = process_entities(clean_text, model, threshold, nested_ner)
157
+ result["File Name"] = file.name
158
+ results.append(result)
159
+
160
+ # Show individual results
161
+ st.json(result)
162
+
163
+ progress_bar.progress((idx + 1) / len(files_to_process))
164
+
165
+ # Show final results
166
+ if results:
167
+ st.success("Processing Complete!")
168
+
169
+ # Convert to DataFrame
170
+ df = pd.DataFrame(results)
171
+
172
+ # Reorder columns to put filename first
173
+ cols = ["File Name"] + [col for col in df.columns if col != "File Name"]
174
+ df = df[cols]
175
+
176
+ # Display results
177
+ st.dataframe(df, use_container_width=True)
178
+
179
+ # Provide download option
180
+ csv = df.to_csv(index=False)
181
+ st.download_button(
182
+ "Download Results CSV",
183
+ csv,
184
+ "business_card_results.csv",
185
+ "text/csv",
186
+ key='download-csv'
187
+ )
188
+
189
+ if __name__ == "__main__":
190
+ main()