Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import easyocr
|
3 |
+
import pandas as pd
|
4 |
+
from io import BytesIO
|
5 |
+
from PIL import Image
|
6 |
+
import numpy as np
|
7 |
+
import os
|
8 |
+
from pathlib import Path
|
9 |
+
from gliner import GLiNER
|
10 |
+
|
11 |
+
# Set environment variables for model storage
|
12 |
+
os.environ['GLINER_HOME'] = str(Path.home() / '.gliner_models')
|
13 |
+
os.environ['TRANSFORMERS_CACHE'] = str(Path.home() / '.gliner_models' / 'cache')
|
14 |
+
|
15 |
+
# Initialize EasyOCR reader
|
16 |
+
reader = easyocr.Reader(['en'])
|
17 |
+
|
18 |
+
def get_model_path():
|
19 |
+
"""Get the path to the local model directory."""
|
20 |
+
base_dir = Path.home() / '.gliner_models'
|
21 |
+
model_dir = base_dir / 'gliner_large-v2.1'
|
22 |
+
return model_dir
|
23 |
+
|
24 |
+
def download_model():
|
25 |
+
"""Download the model if it doesn't exist locally."""
|
26 |
+
model_dir = get_model_path()
|
27 |
+
if not model_dir.exists():
|
28 |
+
st.info("Downloading GLiNER model for the first time... This may take a few minutes.")
|
29 |
+
try:
|
30 |
+
model_dir.parent.mkdir(parents=True, exist_ok=True)
|
31 |
+
temp_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
|
32 |
+
temp_model.save_pretrained(str(model_dir))
|
33 |
+
st.success("Model downloaded successfully!")
|
34 |
+
return temp_model
|
35 |
+
except Exception as e:
|
36 |
+
st.error(f"Error downloading model: {str(e)}")
|
37 |
+
raise e
|
38 |
+
return None
|
39 |
+
|
40 |
+
@st.cache_resource
|
41 |
+
def load_gliner_model():
|
42 |
+
"""Load the GLiNER model, downloading it if necessary."""
|
43 |
+
model_dir = get_model_path()
|
44 |
+
if model_dir.exists():
|
45 |
+
try:
|
46 |
+
return GLiNER.from_pretrained(str(model_dir))
|
47 |
+
except Exception as e:
|
48 |
+
st.warning("Error loading existing model. Attempting to redownload...")
|
49 |
+
import shutil
|
50 |
+
shutil.rmtree(model_dir, ignore_errors=True)
|
51 |
+
|
52 |
+
model = download_model()
|
53 |
+
if model:
|
54 |
+
return model
|
55 |
+
return GLiNER.from_pretrained(str(model_dir))
|
56 |
+
|
57 |
+
def extract_text_from_image(image):
|
58 |
+
"""Extracts text from a single image using EasyOCR."""
|
59 |
+
image_array = np.array(image)
|
60 |
+
return reader.readtext(image_array, detail=0, paragraph=True)
|
61 |
+
|
62 |
+
def process_entities(text: str, model, threshold: float, nested_ner: bool) -> dict:
|
63 |
+
"""Process text with GLiNER model - matching app.py implementation."""
|
64 |
+
# Define our business card labels
|
65 |
+
labels = "person name, company name, job title, phone, email, address"
|
66 |
+
labels = [label.strip() for label in labels.split(",")]
|
67 |
+
|
68 |
+
# Get predictions
|
69 |
+
entities = model.predict_entities(
|
70 |
+
text,
|
71 |
+
labels,
|
72 |
+
flat_ner=not nested_ner,
|
73 |
+
threshold=threshold
|
74 |
+
)
|
75 |
+
|
76 |
+
# Format results matching app.py structure
|
77 |
+
formatted_entities = []
|
78 |
+
for entity in entities:
|
79 |
+
formatted_entities.append({
|
80 |
+
"entity": entity["label"],
|
81 |
+
"word": entity["text"],
|
82 |
+
"start": entity["start"],
|
83 |
+
"end": entity["end"]
|
84 |
+
})
|
85 |
+
|
86 |
+
# Organize results by category
|
87 |
+
results = {
|
88 |
+
"Person Name": [],
|
89 |
+
"Company Name": [],
|
90 |
+
"Job Title": [],
|
91 |
+
"Phone": [],
|
92 |
+
"Email": [],
|
93 |
+
"Address": []
|
94 |
+
}
|
95 |
+
|
96 |
+
for entity in formatted_entities:
|
97 |
+
category = entity["entity"].title()
|
98 |
+
if category in results:
|
99 |
+
results[category].append(entity["word"])
|
100 |
+
|
101 |
+
# Join multiple entries with semicolons
|
102 |
+
return {k: "; ".join(set(v)) if v else "" for k, v in results.items()}
|
103 |
+
|
104 |
+
def main():
|
105 |
+
st.title("Business Card Information Extractor")
|
106 |
+
|
107 |
+
# Model settings in sidebar
|
108 |
+
st.sidebar.title("Settings")
|
109 |
+
|
110 |
+
threshold = st.sidebar.slider(
|
111 |
+
"Detection Threshold",
|
112 |
+
min_value=0.0,
|
113 |
+
max_value=1.0,
|
114 |
+
value=0.3,
|
115 |
+
step=0.05,
|
116 |
+
help="Lower values will detect more entities (as in app.py example)"
|
117 |
+
)
|
118 |
+
|
119 |
+
nested_ner = st.sidebar.checkbox(
|
120 |
+
"Enable Nested NER",
|
121 |
+
value=True,
|
122 |
+
help="Allow detection of nested entities"
|
123 |
+
)
|
124 |
+
|
125 |
+
# Upload options
|
126 |
+
upload_type = st.sidebar.radio("Upload Type", ("Single", "Batch"))
|
127 |
+
|
128 |
+
# File uploader
|
129 |
+
uploaded_files = st.file_uploader(
|
130 |
+
"Upload Business Card Image(s)",
|
131 |
+
type=["png", "jpg", "jpeg"],
|
132 |
+
accept_multiple_files=(upload_type == "Batch")
|
133 |
+
)
|
134 |
+
|
135 |
+
if uploaded_files:
|
136 |
+
# Load model
|
137 |
+
model = load_gliner_model()
|
138 |
+
|
139 |
+
# Process files
|
140 |
+
results = []
|
141 |
+
files_to_process = uploaded_files if isinstance(uploaded_files, list) else [uploaded_files]
|
142 |
+
|
143 |
+
progress_bar = st.progress(0)
|
144 |
+
for idx, file in enumerate(files_to_process):
|
145 |
+
with st.expander(f"Processing {file.name}"):
|
146 |
+
# Load and extract text
|
147 |
+
image = Image.open(file)
|
148 |
+
extracted_text = extract_text_from_image(image)
|
149 |
+
clean_text = " ".join(extracted_text)
|
150 |
+
|
151 |
+
# Show extracted text
|
152 |
+
st.text("Extracted Text:")
|
153 |
+
st.text(clean_text)
|
154 |
+
|
155 |
+
# Process with GLiNER
|
156 |
+
result = process_entities(clean_text, model, threshold, nested_ner)
|
157 |
+
result["File Name"] = file.name
|
158 |
+
results.append(result)
|
159 |
+
|
160 |
+
# Show individual results
|
161 |
+
st.json(result)
|
162 |
+
|
163 |
+
progress_bar.progress((idx + 1) / len(files_to_process))
|
164 |
+
|
165 |
+
# Show final results
|
166 |
+
if results:
|
167 |
+
st.success("Processing Complete!")
|
168 |
+
|
169 |
+
# Convert to DataFrame
|
170 |
+
df = pd.DataFrame(results)
|
171 |
+
|
172 |
+
# Reorder columns to put filename first
|
173 |
+
cols = ["File Name"] + [col for col in df.columns if col != "File Name"]
|
174 |
+
df = df[cols]
|
175 |
+
|
176 |
+
# Display results
|
177 |
+
st.dataframe(df, use_container_width=True)
|
178 |
+
|
179 |
+
# Provide download option
|
180 |
+
csv = df.to_csv(index=False)
|
181 |
+
st.download_button(
|
182 |
+
"Download Results CSV",
|
183 |
+
csv,
|
184 |
+
"business_card_results.csv",
|
185 |
+
"text/csv",
|
186 |
+
key='download-csv'
|
187 |
+
)
|
188 |
+
|
189 |
+
if __name__ == "__main__":
|
190 |
+
main()
|