imadcat committed on
Commit
8ff1f8a
·
1 Parent(s): 0cf2063
Files changed (1) hide show
  1. o3_mini_solver_generic_streamlit.py +566 -0
o3_mini_solver_generic_streamlit.py ADDED
@@ -0,0 +1,566 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import streamlit as st
3
+ import pandas as pd
4
+ from functools import lru_cache
5
+ from ortools.sat.python import cp_model
6
+
7
# Regexes used by sanitize_token(), compiled once at import time.
# Filler phrases that carry no attribute information in clue text.
_NOISE_WORDS = [
    r'\bthe person\b',
    r'\bperson\b',
    r'\bsmoker\b',
    r'\buses?\b',
    r'\bloves?\b',
    r'\bpartial to\b',
    r'\bbouquet of\b',
    r'\bboquet of\b',
    r'\bvase of\b',
    r'\bmany unique\b',
    r'\bbouquet\b',
    r'\bvase\b',
    r'\barrangement\b',
]
NOISE_PATTERNS = [re.compile(word, re.IGNORECASE) for word in _NOISE_WORDS]
# Articles and role words stripped after the noise phrases.
ARTICLE_PATTERN = re.compile(r'\b(a|an|the)\b', re.IGNORECASE)
EXTRA_WORDS_PATTERN = re.compile(r'\b(owner|lover|enthusiast)\b', re.IGNORECASE)
# Anything that is not a lowercase alphanumeric or space becomes a space.
NON_ALNUM_PATTERN = re.compile(r'[^a-z0-9 ]')
MULTISPACE_PATTERN = re.compile(r'\s+')
27
+
28
# Ordinal word -> 1-based house index ("the third house" -> 3).
ordinal_map = dict(
    zip(("first", "second", "third", "fourth", "fifth", "sixth"), range(1, 7))
)
37
+
38
# Counting word -> integer, for "There are two houses between ..." clues.
word_to_num = dict(
    zip(("one", "two", "three", "four", "five", "six"), range(1, 7))
)
47
+
48
def sanitize_token(text):
    """Lower-case *text* and strip filler phrases, articles, role words and
    punctuation, leaving only the attribute-bearing words for matching."""
    cleaned = text.lower()
    # Remove domain filler phrases first, then articles and role words.
    for noise in NOISE_PATTERNS:
        cleaned = noise.sub('', cleaned)
    for extra in (ARTICLE_PATTERN, EXTRA_WORDS_PATTERN):
        cleaned = extra.sub('', cleaned)
    # Turn punctuation into spaces, then collapse whitespace runs.
    cleaned = NON_ALNUM_PATTERN.sub(' ', cleaned)
    return MULTISPACE_PATTERN.sub(' ', cleaned).strip()
57
+
58
def normalize_token(token, candidate_key=None):
    """Return *token* lower-cased, rewriting month names or nationality
    adjectives into the abbreviated/demonym forms used by attribute lists.

    Categories other than "month"/"nationalities" are only lower-cased.
    """
    substitutions = {
        "month": {
            "january": "jan", "february": "feb", "march": "mar",
            "april": "april", "may": "may", "june": "jun",
            "july": "jul", "august": "aug", "september": "sept",
            "october": "oct", "november": "nov", "december": "dec",
        },
        "nationalities": {
            "swedish": "swede", "british": "brit", "danish": "dane",
        },
    }
    normalized = token.lower()
    for long_form, short_form in substitutions.get(candidate_key, {}).items():
        normalized = normalized.replace(long_form, short_form)
    return normalized
86
+
87
@lru_cache(maxsize=1024)
def lemmatize_text_cached(text):
    """Lemmatize *text* with the module-level spaCy pipeline, memoized.

    Degrades to returning the input unchanged when spaCy is unavailable
    (i.e. when the module-level ``nlp`` is None).
    """
    if nlp is None:
        return text
    return " ".join(tok.lemma_ for tok in nlp(text))
93
+
94
def lemmatize_text(text):
    """Lemmatize *text*; thin wrapper around the lru_cache-backed helper."""
    return lemmatize_text_cached(text)
96
+
97
def get_category_key(category):
    """Map a free-form category label to a canonical key.

    Matching is by substring, checked in a fixed priority order; when no
    rule matches, the last word of the lower-cased label is used (or the
    label itself when it has no words).
    """
    label = category.lower()
    # "favorite color" must win over the plain hair/color rules below.
    if "favorite" in label and "color" in label:
        return "favorite_color"
    # (needles, key): first rule whose any-needle appears in the label wins.
    rules = (
        (("hair",), "hair_color"),
        (("name",), "name"),
        (("vacation",), "vacation"),
        (("occupation",), "occupation"),
        (("flower",), "flower"),
        (("lunch",), "lunch"),
        (("smoothie",), "smoothie"),
        (("hobby",), "hobby"),
        (("pet", "animal"), "animals"),
        (("birthday", "month"), "month"),
        (("nationalities",), "nationalities"),
    )
    for needles, key in rules:
        if any(needle in label for needle in needles):
            return key
    words = label.split()
    return words[-1] if words else label
125
+
126
def shorten_category(category):
    """Turn a verbose category label into a short, display-friendly heading."""
    return get_category_key(category).replace('_', ' ')
129
+
130
# Try loading spaCy with the transformer-based model.
# On any failure (missing package or missing model) `nlp` is set to None;
# lemmatize_text_cached() and spacy_equality_extraction() both check for
# None and degrade gracefully, so the solver still works without spaCy.
try:
    import spacy
    nlp = spacy.load("en_core_web_trf")
except Exception as e:
    st.warning("spaCy model could not be loaded; proceeding without it: " + str(e))
    nlp = None
137
+
138
class PuzzleSolver:
    """Solve a ZebraLogic-style grid puzzle with CP-SAT.

    Workflow: parse_puzzle() extracts categories/attributes and clue
    sentences from the pasted text, build_variables() creates one IntVar
    (the house number) per attribute, process_all_clues() translates each
    clue sentence into constraints via regex matching, and solve() runs
    the CP-SAT solver and returns the house -> category -> attribute map.
    """

    def __init__(self, puzzle_text, debug=False):
        # Raw puzzle description as pasted by the user.
        self.puzzle_text = puzzle_text
        # Number of houses; set by parse_puzzle() (defaults to 6 there).
        self.num_houses = None
        # category label -> list of attribute strings, as parsed.
        self.categories = {}
        # category label -> canonical key from get_category_key().
        self.category_keys = {}
        # Clue sentences collected after the "### Clues:" marker.
        self.clues = []
        # var[category][attribute] -> IntVar holding that attribute's house.
        self.var = {}
        self.model = cp_model.CpModel()
        # When True, every parsing/constraint step is echoed via st.write.
        self.debug = debug
        # Keyword hints used by find_attribute() to guess which category
        # a clue token refers to, so matching can be restricted to it.
        self.category_keywords = {
            "nationalities": ["swede", "norwegian", "german", "chinese", "dane", "brit", "danish", "swedish", "british"],
            "name": ["name"],
            "vacation": ["vacation", "trip", "break"],
            "occupation": ["occupation", "job"],
            "lunch": ["lunch", "soup", "stew", "grilled", "cheese", "spaghetti", "pizza", "stir"],
            "smoothie": ["smoothie", "cherry", "dragonfruit", "watermelon", "lime", "blueberry", "desert"],
            "models": ["phone", "model", "iphone", "pixel", "oneplus", "samsung", "xiaomi", "huawei"],
            "hair_color": ["hair"],
            "month": ["month", "birthday", "birth"],
            "hobby": ["photography", "cooking", "knitting", "woodworking", "paints", "painting", "gardening"],
            "pet": ["rabbit", "hamster", "fish", "cat", "bird", "dog"],
            "animals": ["rabbit", "dog", "horse", "fish", "bird", "cat"]
        }

    def parse_puzzle(self):
        """Extract house count, category/attribute lists and clue lines.

        Categories are lines of the form "- Label: a, b, c" (or "* ...");
        clues are every non-blank line after a line containing "### Clues:".
        """
        m = re.search(r"There are (\d+) houses", self.puzzle_text, re.IGNORECASE)
        # Default to 6 houses when the preamble sentence is missing.
        self.num_houses = int(m.group(1)) if m else 6
        cat_pattern = re.compile(r"^[-*]\s*(.*?):\s*(.+)$")
        for line in self.puzzle_text.splitlines():
            line = line.strip()
            m = cat_pattern.match(line)
            if m:
                cat_label = m.group(1).strip()
                attr_line = m.group(2).strip()
                attrs = [x.strip() for x in attr_line.split(",") if x.strip()]
                self.categories[cat_label] = attrs
                self.category_keys[cat_label] = get_category_key(cat_label)
                if self.debug:
                    st.write(f"Parsed category: '{cat_label}' with attributes {attrs}")
                    st.write(f"Assigned key for category: {self.category_keys[cat_label]}")
        clues_section = False
        for line in self.puzzle_text.splitlines():
            if "### Clues:" in line:
                clues_section = True
                continue
            if clues_section:
                clean = line.strip()
                if clean:
                    self.clues.append(clean)
                    if self.debug:
                        st.write(f"Parsed clue: {clean}")

    def build_variables(self):
        """Create one IntVar in [1, num_houses] per attribute and force all
        attributes of a category to occupy distinct houses."""
        for cat, attrs in self.categories.items():
            self.var[cat] = {}
            for attr in attrs:
                self.var[cat][attr] = self.model.NewIntVar(1, self.num_houses, f"{cat}_{attr}")
            self.model.AddAllDifferent(list(self.var[cat].values()))
            if self.debug:
                st.write(f"Added all-different constraint for category '{cat}'.")

    def find_attribute(self, token):
        """Resolve a clue fragment to a (category, attribute) pair.

        Strategy: sanitize the token, guess a category via keyword hints,
        then look for the longest attribute whose sanitized form appears as
        a whole word in the token (or its lemmatized form), also trying a
        naive singular/plural variant. Months and nationalities get an
        extra abbreviation-based fallback. Returns None when nothing matches.
        """
        token_san = sanitize_token(token)
        candidate_key = None
        for key, kws in self.category_keywords.items():
            if any(kw in token_san for kw in kws):
                candidate_key = key
                if self.debug:
                    st.write(f"Debug: Token '{token}' suggests category key '{candidate_key}' based on keywords {kws}.")
                break
        # "pet" and "animals" share keywords; categories are keyed "animals".
        if candidate_key == "pet":
            candidate_key = "animals"
        token_lemmatized = lemmatize_text(token_san)
        if self.debug:
            st.write(f"Debug: Lemmatized token for '{token}': '{token_lemmatized}'")
        # Lemmatization turns "painting" into "paint"; restore the hobby form.
        if candidate_key == "hobby" and "paint" in token_lemmatized:
            token_lemmatized = token_lemmatized.replace("paint", "painting")
            if self.debug:
                st.write(f"Debug: Adjusted hobby token to '{token_lemmatized}' for proper matching.")
        if candidate_key in ["month", "nationalities"]:
            token_san = normalize_token(token_san, candidate_key)
            if self.debug:
                st.write(f"Debug: Normalized token for {candidate_key}: '{token_san}'")
        if candidate_key:
            categories_to_search = [(cat, attrs) for cat, attrs in self.categories.items() if self.category_keys.get(cat) == candidate_key]
            if self.debug:
                st.write(f"Debug: Restricted search to categories: {[cat for cat, _ in categories_to_search]}")
        else:
            categories_to_search = self.categories.items()
        # Keep the longest matching attribute to prefer more specific matches.
        best = None
        best_len = 0
        for cat, attrs in categories_to_search:
            for attr in attrs:
                attr_san = sanitize_token(attr)
                if candidate_key in ["month", "nationalities"]:
                    attr_san = normalize_token(attr_san, candidate_key)
                pattern = rf'\b{re.escape(attr_san)}\b'
                if re.search(pattern, token_san) or re.search(pattern, token_lemmatized):
                    if len(attr_san) > best_len:
                        best = (cat, attr)
                        best_len = len(attr_san)
                else:
                    # Try a naive singular/plural variant of the attribute.
                    alt = attr_san[:-1] if attr_san.endswith('s') else attr_san + 's'
                    if re.search(rf'\b{re.escape(alt)}\b', token_san) or re.search(rf'\b{re.escape(alt)}\b', token_lemmatized):
                        if len(attr_san) > best_len:
                            best = (cat, attr)
                            best_len = len(attr_san)
        # Fallback: match months/nationalities by their abbreviated forms.
        if best is None and candidate_key in ["month", "nationalities"]:
            if self.debug:
                st.write(f"Debug: Fallback for {candidate_key}: no match found in token '{token_san}'. Trying explicit substrings.")
            mapping = {}
            if candidate_key == "month":
                mapping = {"jan": "jan", "feb": "feb", "mar": "mar",
                           "april": "april", "may": "may", "jun": "jun",
                           "jul": "jul", "aug": "aug", "sept": "sept", "oct": "oct", "nov": "nov", "dec": "dec"}
            elif candidate_key == "nationalities":
                mapping = {"swede": "swede", "norwegian": "norwegian", "german": "german",
                           "chinese": "chinese", "dane": "dane", "brit": "brit"}
            for key_abbr in mapping.values():
                if re.search(rf'\b{re.escape(key_abbr)}\b', token_san):
                    for cat, attrs in categories_to_search:
                        for attr in attrs:
                            attr_san = normalize_token(sanitize_token(attr), candidate_key)
                            if attr_san == key_abbr:
                                best = (cat, attr)
                                best_len = len(attr_san)
                                if self.debug:
                                    st.write(f"Debug: Found fallback match: '{attr_san}' in token '{token_san}'.")
                                break
                        if best is not None:
                            break
                    if best is not None:
                        break
        if best is None and self.debug:
            st.write(f"DEBUG: No attribute found for token '{token}' (sanitized: '{token_san}', lemmatized: '{token_lemmatized}').")
        return best

    def find_all_attributes_in_text(self, text):
        """Return all (category, attribute) pairs whose sanitized attribute
        occurs as a whole word in *text*, de-duplicated in first-seen order."""
        found = []
        text_san = sanitize_token(text)
        for cat, attrs in self.categories.items():
            for attr in attrs:
                attr_san = sanitize_token(attr)
                if re.search(rf'\b{re.escape(attr_san)}\b', text_san):
                    found.append((cat, attr))
        unique = []
        seen = set()
        for pair in found:
            if pair not in seen:
                unique.append(pair)
                seen.add(pair)
        return unique

    def spacy_equality_extraction(self, text):
        """Use spaCy's dependency parse to split an "X is Y" sentence into
        (subject span, attribute span); falls back to the first two named
        entities, and returns (None, None) when nothing can be extracted."""
        if nlp is None:
            return None, None
        doc = nlp(text)
        for token in doc:
            # Look for a copular root ("... is ...") with subject + attribute.
            if token.lemma_ == "be" and token.dep_ == "ROOT":
                subj = None
                attr = None
                for child in token.children:
                    if child.dep_ in ["nsubj", "nsubjpass"]:
                        subj = child
                    if child.dep_ in ["attr", "acomp"]:
                        attr = child
                if subj and attr:
                    # Expand each head token to its full subtree span.
                    subject_span = doc[subj.left_edge.i : subj.right_edge.i+1].text
                    attr_span = doc[attr.left_edge.i : attr.right_edge.i+1].text
                    return subject_span, attr_span
        ents = list(doc.ents)
        if len(ents) >= 2:
            return ents[0].text, ents[1].text
        return None, None

    def apply_constraint_equality(self, token1, token2):
        """Constrain both tokens' attributes to the same house (X == Y)."""
        a1 = self.find_attribute(token1)
        a2 = self.find_attribute(token2)
        if a1 and a2:
            cat1, attr1 = a1
            cat2, attr2 = a2
            self.model.Add(self.var[cat1][attr1] == self.var[cat2][attr2])
            if self.debug:
                st.write(f"Added constraint: [{cat1}][{attr1}] == [{cat2}][{attr2}]")
        else:
            if self.debug:
                st.write(f"Warning: could not apply equality between '{token1}' and '{token2}'")

    def apply_constraint_inequality(self, token, house_number):
        """Forbid the token's attribute from being in *house_number*."""
        a1 = self.find_attribute(token)
        if a1:
            cat, attr = a1
            self.model.Add(self.var[cat][attr] != house_number)
            if self.debug:
                st.write(f"Added constraint: [{cat}][{attr}] != {house_number}")
        else:
            if self.debug:
                st.write(f"Warning: could not apply inequality for '{token}' at house {house_number}")

    def apply_constraint_position(self, token1, op, token2):
        """Add a relative-position constraint between two attributes.

        *op* is one of "==", "<", ">", "+1" (token1 directly left of token2)
        or "-1" (token1 directly right of token2).
        """
        a1 = self.find_attribute(token1)
        a2 = self.find_attribute(token2)
        if a1 and a2:
            cat1, attr1 = a1
            cat2, attr2 = a2
            if op == "==":
                self.model.Add(self.var[cat1][attr1] == self.var[cat2][attr2])
                if self.debug:
                    st.write(f"Added constraint: [{cat1}][{attr1}] == [{cat2}][{attr2}]")
            elif op == "<":
                self.model.Add(self.var[cat1][attr1] < self.var[cat2][attr2])
                if self.debug:
                    st.write(f"Added constraint: [{cat1}][{attr1}] < [{cat2}][{attr2}]")
            elif op == ">":
                self.model.Add(self.var[cat1][attr1] > self.var[cat2][attr2])
                if self.debug:
                    st.write(f"Added constraint: [{cat1}][{attr1}] > [{cat2}][{attr2}]")
            elif op == "+1":
                self.model.Add(self.var[cat1][attr1] + 1 == self.var[cat2][attr2])
                if self.debug:
                    st.write(f"Added constraint: [{cat1}][{attr1}] + 1 == [{cat2}][{attr2}]")
            elif op == "-1":
                self.model.Add(self.var[cat1][attr1] - 1 == self.var[cat2][attr2])
                if self.debug:
                    st.write(f"Added constraint: [{cat1}][{attr1}] - 1 == [{cat2}][{attr2}]")
        else:
            if self.debug:
                st.write(f"Warning: could not apply position constraint between '{token1}' and '{token2}' with op '{op}'")

    def apply_constraint_next_to(self, token1, token2):
        """Constrain two attributes to adjacent houses (|pos1 - pos2| == 1)."""
        a1 = self.find_attribute(token1)
        a2 = self.find_attribute(token2)
        if a1 and a2:
            cat1, attr1 = a1
            cat2, attr2 = a2
            # Auxiliary variable holding the absolute distance between houses.
            diff = self.model.NewIntVar(0, self.num_houses, f"diff_{attr1}_{attr2}")
            self.model.AddAbsEquality(diff, self.var[cat1][attr1] - self.var[cat2][attr2])
            self.model.Add(diff == 1)
            if self.debug:
                st.write(f"Added next-to constraint: |[{cat1}][{attr1}] - [{cat2}][{attr2}]| == 1")
        else:
            if self.debug:
                st.write(f"Warning: could not apply next-to constraint between '{token1}' and '{token2}'")

    def apply_constraint_between(self, token1, token2, houses_between):
        """Constrain two attributes so exactly *houses_between* houses lie
        between them (distance == houses_between + 1, either direction)."""
        a1 = self.find_attribute(token1)
        a2 = self.find_attribute(token2)
        if a1 and a2:
            cat1, attr1 = a1
            cat2, attr2 = a2
            diff = self.model.NewIntVar(0, self.num_houses, f"between_{attr1}_{attr2}")
            self.model.AddAbsEquality(diff, self.var[cat1][attr1] - self.var[cat2][attr2])
            self.model.Add(diff == houses_between + 1)
            if self.debug:
                st.write(f"Added between constraint: |[{cat1}][{attr1}] - [{cat2}][{attr2}]| == {houses_between + 1}")
        else:
            if self.debug:
                st.write(f"Warning: could not apply between constraint for '{token1}' and '{token2}' with {houses_between} houses in between")

    def apply_constraint_fixed(self, token, house_number):
        """Pin the token's attribute to exactly *house_number*."""
        a1 = self.find_attribute(token)
        if a1:
            cat, attr = a1
            self.model.Add(self.var[cat][attr] == house_number)
            if self.debug:
                st.write(f"Added fixed constraint: [{cat}][{attr}] == {house_number}")
        else:
            if self.debug:
                st.write(f"Warning: could not apply fixed constraint for '{token}' at house {house_number}")

    def process_clue(self, clue):
        """Translate one clue sentence into model constraints.

        Clue templates are tried in order (fixed house, negated house,
        directly left/right, somewhere left/right, next-to, N-houses-between,
        generic "X is Y" equality); the first matching template wins. A
        generic equality that fails token cleaning falls back to spaCy
        extraction; clues matching nothing are logged (in debug) and skipped.
        """
        # Strip the leading "N. " clue number.
        text = re.sub(r'^\d+\.\s*', '', clue).strip()
        if self.debug:
            st.write(f"Processing clue: {text}")
        ordinal_numbers = r"(?:\d+|first|second|third|fourth|fifth|sixth)"
        m_fixed = re.search(rf"(.+?) is in the ({ordinal_numbers}) house", text, re.IGNORECASE)
        if m_fixed:
            token = m_fixed.group(1).strip()
            num_str = m_fixed.group(2).strip().lower()
            house_num = int(num_str) if num_str.isdigit() else ordinal_map.get(num_str)
            if house_num is not None:
                self.apply_constraint_fixed(token, house_num)
                return
        m_not = re.search(rf"(.+?) is not in the ({ordinal_numbers}) house", text, re.IGNORECASE)
        if m_not:
            token = m_not.group(1).strip()
            num_str = m_not.group(2).strip().lower()
            house_num = int(num_str) if num_str.isdigit() else ordinal_map.get(num_str)
            if house_num is not None:
                self.apply_constraint_inequality(token, house_num)
                return
        m_left = re.search(r"(.+?) is directly left of (.+)", text, re.IGNORECASE)
        if m_left:
            token1 = m_left.group(1).strip()
            token2 = m_left.group(2).strip()
            self.apply_constraint_position(token1, "+1", token2)
            return
        m_right = re.search(r"(.+?) is directly right of (.+)", text, re.IGNORECASE)
        if m_right:
            token1 = m_right.group(1).strip()
            token2 = m_right.group(2).strip()
            self.apply_constraint_position(token1, "-1", token2)
            return
        m_sl = re.search(r"(.+?) is somewhere to the left of (.+)", text, re.IGNORECASE)
        if m_sl:
            token1 = m_sl.group(1).strip()
            token2 = m_sl.group(2).strip()
            self.apply_constraint_position(token1, "<", token2)
            return
        m_sr = re.search(r"(.+?) is somewhere to the right of (.+)", text, re.IGNORECASE)
        if m_sr:
            token1 = m_sr.group(1).strip()
            token2 = m_sr.group(2).strip()
            self.apply_constraint_position(token1, ">", token2)
            return
        m_next = re.search(r"(.+?) and (.+?) are next to each other", text, re.IGNORECASE)
        if m_next:
            token1 = m_next.group(1).strip()
            token2 = m_next.group(2).strip()
            self.apply_constraint_next_to(token1, token2)
            return
        m_between = re.search(rf"There (?:are|is) (\d+|one|two|three|four|five|six) house(?:s)? between (.+?) and (.+)", text, re.IGNORECASE)
        if m_between:
            num_str = m_between.group(1).strip().lower()
            houses_between = int(num_str) if num_str.isdigit() else word_to_num.get(num_str)
            token1 = m_between.group(2).strip()
            token2 = m_between.group(3).strip()
            self.apply_constraint_between(token1, token2, houses_between)
            return
        # Generic "X is (the) Y" equality, tried last because it is broad.
        m_eq = re.search(r"(.+)\sis(?: the)?\s(.+)", text, re.IGNORECASE)
        if m_eq:
            token1 = m_eq.group(1).strip()
            token2 = m_eq.group(2).strip()
            token1 = re.sub(r"^(the person who\s+|who\s+)", "", token1, flags=re.IGNORECASE).strip()
            token2 = re.sub(r"^(a\s+|an\s+|the\s+)", "", token2, flags=re.IGNORECASE).strip()
            a1 = self.find_attribute(token1)
            a2 = self.find_attribute(token2)
            if a1 and a2:
                self.apply_constraint_equality(token1, token2)
                return
            else:
                if self.debug:
                    st.write("Equality regex failed to extract valid attributes using token cleaning.")
                if nlp is not None:
                    left, right = self.spacy_equality_extraction(text)
                    if left and right:
                        if self.debug:
                            st.write(f"spaCy extracted equality: '{left}' == '{right}'")
                        self.apply_constraint_equality(left, right)
                        return
        if self.debug:
            st.write(f"Unprocessed clue: {text}")

    def process_all_clues(self):
        """Run process_clue() over every parsed clue."""
        for clue in self.clues:
            self.process_clue(clue)

    def solve(self):
        """Solve the model; return {house: {category: attribute}} or None.

        Returns None when the model is infeasible (contradictory or
        under-specified clues).
        """
        solver = cp_model.CpSolver()
        # Use all available cores (0 means all available, 1 means single core for deployment to streamlit community cloud)
        solver.parameters.num_search_workers = 1
        status = solver.Solve(self.model)
        if status in (cp_model.OPTIMAL, cp_model.FEASIBLE):
            solution = {}
            for house in range(1, self.num_houses + 1):
                solution[house] = {}
                for cat, attr_dict in self.var.items():
                    for attr, var in attr_dict.items():
                        if solver.Value(var) == house:
                            solution[house][cat] = attr
            return solution
        else:
            if self.debug:
                st.write("No solution found. The clues may be contradictory or incomplete.")
            return None

    def print_solution(self, solution):
        """Despite the name, build and return a pandas DataFrame (one row
        per house, one column per category) from *solution*, or None when
        *solution* is None."""
        if solution:
            headers = ["House"] + [shorten_category(cat) for cat in self.categories.keys()]
            table = []
            for house in sorted(solution.keys()):
                row = [str(house)]
                for cat in self.categories.keys():
                    row.append(solution[house].get(cat, ""))
                table.append(row)
            df = pd.DataFrame(table, columns=headers)
            return df
        else:
            return None
528
+
529
# Streamlit UI
st.title("Zebra Logic Puzzle Solver")
st.subheader("🦓 ZebraLogic: Benchmarking the Logical Reasoning Ability of Language Models")
st.markdown("""
Copy the Zebra Logic Puzzles description [from the huggingface site](https://huggingface.co/spaces/allenai/ZebraLogic), and paste it below.
""")

puzzle_text = st.text_area("Puzzle Input", height=300)
show_debug = st.checkbox("Show Debug Output", value=False)

# Use session_state to ensure the solution is computed only once per click.
# The flag survives Streamlit reruns so the result stays visible after
# widget interactions that would otherwise reset the button state.
if "puzzle_solved" not in st.session_state:
    st.session_state["puzzle_solved"] = False

if st.button("Solve Puzzle") or st.session_state["puzzle_solved"]:
    # Indicate that we've clicked the button
    st.session_state["puzzle_solved"] = True

    # Full pipeline: parse text -> build CP-SAT vars -> add clue constraints.
    solver_instance = PuzzleSolver(puzzle_text, debug=show_debug)
    solver_instance.parse_puzzle()
    solver_instance.build_variables()
    solver_instance.process_all_clues()

    # st.subheader("Parsed Attributes (Categories & Their Attributes)")
    # for cat, attrs in solver_instance.categories.items():
    #     st.markdown(f"**{cat}**: {', '.join(attrs)}")

    # st.subheader("Parsed Clues")
    # for i, clue in enumerate(solver_instance.clues, start=1):
    #     st.markdown(f"{i}. {clue}")

    solution = solver_instance.solve()
    st.subheader("Solution Table")
    # print_solution returns a DataFrame, or None when solving failed.
    df_solution = solver_instance.print_solution(solution)
    if df_solution is not None:
        st.table(df_solution)
    else:
        st.error("No solution found. The clues may be contradictory or incomplete.")