winamnd committed on
Commit
bf26c19
·
verified ·
1 Parent(s): 44a4a1e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -20
app.py CHANGED
@@ -74,29 +74,27 @@ def ocr_with_tesseract(img):
74
  return extracted_text, confidences
75
 
76
  # OCR & Classification Function
77
- def generate_ocr(method, image):
78
- if image is None:
79
  raise gr.Error("Please upload an image!")
80
 
81
  # Convert PIL Image to OpenCV format
82
- img_cv = preprocess_image(image)
83
 
84
  # Select OCR method
85
  if method == "PaddleOCR":
86
- extracted_text, confidences = ocr_with_paddle(img_cv)
87
  elif method == "EasyOCR":
88
- extracted_text, confidences = ocr_with_easy(img_cv)
89
  elif method == "KerasOCR":
90
- extracted_text, confidences = ocr_with_keras(img_cv)
91
  elif method == "TesseractOCR":
92
- extracted_text, confidences = ocr_with_tesseract(img_cv)
93
  else:
94
  return "Invalid OCR method", "N/A"
95
 
96
- # Join extracted text into a single string
97
- text_output = " ".join(extracted_text).strip()
98
-
99
- # If no text detected, return early
100
  if len(text_output) == 0:
101
  return "No text detected!", "Cannot classify"
102
 
@@ -108,19 +106,15 @@ def generate_ocr(method, image):
108
  outputs = model(**inputs)
109
  logits = outputs.logits # Get raw logits
110
 
111
- # Print raw logits for debugging
112
  print(f"Raw logits: {logits}")
113
 
114
- # Compare raw logits instead of using softmax
115
  predicted_class = torch.argmax(logits, dim=1).item()
116
 
117
- print(f"Predicted Class Index: {predicted_class}") # Debugging output
118
-
119
- # Ensure correct label mapping
120
- if predicted_class == 1:
121
- label = "Spam"
122
- else:
123
- label = "Not Spam"
124
 
125
  # Save results
126
  save_results_to_repo(text_output, label)
 
74
  return extracted_text, confidences
75
 
76
  # OCR & Classification Function
77
+ def generate_ocr(method, img):
78
+ if img is None:
79
  raise gr.Error("Please upload an image!")
80
 
81
  # Convert PIL Image to OpenCV format
82
+ img = np.array(img)
83
 
84
  # Select OCR method
85
  if method == "PaddleOCR":
86
+ text_output = ocr_with_paddle(img)
87
  elif method == "EasyOCR":
88
+ text_output = ocr_with_easy(img)
89
  elif method == "KerasOCR":
90
+ text_output = ocr_with_keras(img)
91
  elif method == "TesseractOCR":
92
+ text_output, _ = ocr_with_tesseract(img) # Ignore confidence values
93
  else:
94
  return "Invalid OCR method", "N/A"
95
 
96
+ # Clean and truncate the extracted text
97
+ text_output = text_output.strip()
 
 
98
  if len(text_output) == 0:
99
  return "No text detected!", "Cannot classify"
100
 
 
106
  outputs = model(**inputs)
107
  logits = outputs.logits # Get raw logits
108
 
109
+ # Debugging: Print raw logits
110
  print(f"Raw logits: {logits}")
111
 
112
+ # Use raw logits directly instead of softmax
113
  predicted_class = torch.argmax(logits, dim=1).item()
114
 
115
+ # Map class index to labels
116
+ label_map = {0: "Not Spam", 1: "Spam"}
117
+ label = label_map.get(predicted_class, "Unknown")
 
 
 
 
118
 
119
  # Save results
120
  save_results_to_repo(text_output, label)