Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -74,29 +74,27 @@ def ocr_with_tesseract(img):
|
|
74 |
return extracted_text, confidences
|
75 |
|
76 |
# OCR & Classification Function
|
77 |
-
def generate_ocr(method,
|
78 |
-
if
|
79 |
raise gr.Error("Please upload an image!")
|
80 |
|
81 |
# Convert PIL Image to OpenCV format
|
82 |
-
|
83 |
|
84 |
# Select OCR method
|
85 |
if method == "PaddleOCR":
|
86 |
-
|
87 |
elif method == "EasyOCR":
|
88 |
-
|
89 |
elif method == "KerasOCR":
|
90 |
-
|
91 |
elif method == "TesseractOCR":
|
92 |
-
|
93 |
else:
|
94 |
return "Invalid OCR method", "N/A"
|
95 |
|
96 |
-
#
|
97 |
-
text_output =
|
98 |
-
|
99 |
-
# If no text detected, return early
|
100 |
if len(text_output) == 0:
|
101 |
return "No text detected!", "Cannot classify"
|
102 |
|
@@ -108,19 +106,15 @@ def generate_ocr(method, image):
|
|
108 |
outputs = model(**inputs)
|
109 |
logits = outputs.logits # Get raw logits
|
110 |
|
111 |
-
# Print raw logits
|
112 |
print(f"Raw logits: {logits}")
|
113 |
|
114 |
-
#
|
115 |
predicted_class = torch.argmax(logits, dim=1).item()
|
116 |
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
if predicted_class == 1:
|
121 |
-
label = "Spam"
|
122 |
-
else:
|
123 |
-
label = "Not Spam"
|
124 |
|
125 |
# Save results
|
126 |
save_results_to_repo(text_output, label)
|
|
|
74 |
return extracted_text, confidences
|
75 |
|
76 |
# OCR & Classification Function
|
77 |
+
def generate_ocr(method, img):
|
78 |
+
if img is None:
|
79 |
raise gr.Error("Please upload an image!")
|
80 |
|
81 |
# Convert PIL Image to OpenCV format
|
82 |
+
img = np.array(img)
|
83 |
|
84 |
# Select OCR method
|
85 |
if method == "PaddleOCR":
|
86 |
+
text_output = ocr_with_paddle(img)
|
87 |
elif method == "EasyOCR":
|
88 |
+
text_output = ocr_with_easy(img)
|
89 |
elif method == "KerasOCR":
|
90 |
+
text_output = ocr_with_keras(img)
|
91 |
elif method == "TesseractOCR":
|
92 |
+
text_output, _ = ocr_with_tesseract(img) # Ignore confidence values
|
93 |
else:
|
94 |
return "Invalid OCR method", "N/A"
|
95 |
|
96 |
+
# Clean and truncate the extracted text
|
97 |
+
text_output = text_output.strip()
|
|
|
|
|
98 |
if len(text_output) == 0:
|
99 |
return "No text detected!", "Cannot classify"
|
100 |
|
|
|
106 |
outputs = model(**inputs)
|
107 |
logits = outputs.logits # Get raw logits
|
108 |
|
109 |
+
# Debugging: Print raw logits
|
110 |
print(f"Raw logits: {logits}")
|
111 |
|
112 |
+
# Use raw logits directly instead of softmax
|
113 |
predicted_class = torch.argmax(logits, dim=1).item()
|
114 |
|
115 |
+
# Map class index to labels
|
116 |
+
label_map = {0: "Not Spam", 1: "Spam"}
|
117 |
+
label = label_map.get(predicted_class, "Unknown")
|
|
|
|
|
|
|
|
|
118 |
|
119 |
# Save results
|
120 |
save_results_to_repo(text_output, label)
|