Spaces:
Sleeping
Sleeping
imadcat
committed on
Commit
·
8ff1f8a
1
Parent(s):
0cf2063
init
Browse files
o3_mini_solver_generic_streamlit.py
ADDED
@@ -0,0 +1,566 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import streamlit as st
|
3 |
+
import pandas as pd
|
4 |
+
from functools import lru_cache
|
5 |
+
from ortools.sat.python import cp_model
|
6 |
+
|
7 |
+
# Precompiled regex patterns shared by the token sanitizers below.
# Filler phrases that carry no attribute information in clue text; each is
# compiled word-bounded and case-insensitive.
_NOISE_PHRASES = (
    r'the person',
    r'person',
    r'smoker',
    r'uses?',
    r'loves?',
    r'partial to',
    r'bouquet of',
    r'boquet of',
    r'vase of',
    r'many unique',
    r'bouquet',
    r'vase',
    r'arrangement',
)
NOISE_PATTERNS = [
    re.compile(rf'\b{phrase}\b', re.IGNORECASE) for phrase in _NOISE_PHRASES
]
ARTICLE_PATTERN = re.compile(r'\b(a|an|the)\b', re.IGNORECASE)
EXTRA_WORDS_PATTERN = re.compile(r'\b(owner|lover|enthusiast)\b', re.IGNORECASE)
NON_ALNUM_PATTERN = re.compile(r'[^a-z0-9 ]')
MULTISPACE_PATTERN = re.compile(r'\s+')

# Ordinal words mapped to house positions ("first" -> 1, ... "sixth" -> 6).
ordinal_map = {
    word: position
    for position, word in enumerate(
        ("first", "second", "third", "fourth", "fifth", "sixth"), start=1
    )
}

# Number words used in "there are N houses between ..." clues.
word_to_num = {
    word: value
    for value, word in enumerate(
        ("one", "two", "three", "four", "five", "six"), start=1
    )
}

def sanitize_token(text):
    """Lower-case *text* and strip filler phrases, articles, role words and
    punctuation, leaving only the words that can identify an attribute."""
    cleaned = text.lower()
    # Multi-word noise phrases go first so e.g. "the person" is removed
    # before the bare article pass runs.
    for noise in NOISE_PATTERNS:
        cleaned = noise.sub('', cleaned)
    cleaned = EXTRA_WORDS_PATTERN.sub('', ARTICLE_PATTERN.sub('', cleaned))
    # Punctuation becomes whitespace, then runs of whitespace collapse.
    cleaned = NON_ALNUM_PATTERN.sub(' ', cleaned)
    return MULTISPACE_PATTERN.sub(' ', cleaned).strip()
|
57 |
+
|
58 |
+
def normalize_token(token, candidate_key=None):
    """Lower-case *token* and, for month / nationality categories, rewrite
    spelled-out forms to the abbreviated spellings used in attribute lists."""
    replacements = {}
    if candidate_key == "month":
        replacements = {
            "january": "jan", "february": "feb", "march": "mar",
            "april": "april", "may": "may", "june": "jun",
            "july": "jul", "august": "aug", "september": "sept",
            "october": "oct", "november": "nov", "december": "dec",
        }
    elif candidate_key == "nationalities":
        replacements = {
            "swedish": "swede", "british": "brit", "danish": "dane",
        }
    normalized = token.lower()
    # Plain substring replacement, applied in insertion order.
    for spelled_out, abbreviation in replacements.items():
        normalized = normalized.replace(spelled_out, abbreviation)
    return normalized
|
86 |
+
|
87 |
+
@lru_cache(maxsize=1024)
def lemmatize_text_cached(text):
    """Return *text* with each token replaced by its spaCy lemma.

    When no spaCy model is loaded (module-global `nlp` is None) the text is
    passed through unchanged. Results are memoized per input string.
    """
    if nlp is None:
        return text
    return " ".join(token.lemma_ for token in nlp(text))
|
93 |
+
|
94 |
+
def lemmatize_text(text):
    """Lemmatize *text*; thin wrapper that funnels every call through the
    lru-cached lemmatize_text_cached."""
    return lemmatize_text_cached(text)
|
96 |
+
|
97 |
+
def get_category_key(category):
    """Map a free-form category label to a canonical key.

    Checks run in a fixed order (most specific first); when no keyword
    matches, the label's last word is used as the key.
    """
    label = category.lower()
    # The only two-keyword rule; tested before the generic table.
    if "favorite" in label and "color" in label:
        return "favorite_color"
    keyword_rules = (
        (("hair",), "hair_color"),
        (("name",), "name"),
        (("vacation",), "vacation"),
        (("occupation",), "occupation"),
        (("flower",), "flower"),
        (("lunch",), "lunch"),
        (("smoothie",), "smoothie"),
        (("hobby",), "hobby"),
        (("pet", "animal"), "animals"),
        (("birthday", "month"), "month"),
        (("nationalities",), "nationalities"),
    )
    for keywords, key in keyword_rules:
        if any(keyword in label for keyword in keywords):
            return key
    words = label.split()
    return words[-1] if words else label
|
125 |
+
|
126 |
+
def shorten_category(category):
    """Return a human-readable column header for *category*: its canonical
    key with underscores turned into spaces."""
    return get_category_key(category).replace('_', ' ')
|
129 |
+
|
130 |
+
# Try loading spaCy with the transformer-based model. The app degrades
# gracefully: when the model cannot be loaded, `nlp` stays None and the
# solver falls back to pure-regex clue parsing (see lemmatize_text_cached
# and spacy_equality_extraction).
try:
    import spacy
    nlp = spacy.load("en_core_web_trf")
except Exception as e:
    st.warning("spaCy model could not be loaded; proceeding without it: " + str(e))
    nlp = None
|
137 |
+
|
138 |
+
class PuzzleSolver:
    """Parses a ZebraLogic puzzle description and solves it with CP-SAT."""

    def __init__(self, puzzle_text, debug=False):
        self.puzzle_text = puzzle_text
        self.num_houses = None   # set by parse_puzzle(); defaults to 6 there
        self.categories = {}     # category label -> list of attribute strings
        self.category_keys = {}  # category label -> canonical key (get_category_key)
        self.clues = []          # raw clue lines, in puzzle order
        self.var = {}            # category -> attribute -> IntVar (house number)
        self.model = cp_model.CpModel()
        self.debug = debug       # when True, progress is written via st.write
        # Keyword hints used by find_attribute() to guess which category a
        # clue token refers to before searching attribute names.
        self.category_keywords = {
            "nationalities": ["swede", "norwegian", "german", "chinese", "dane", "brit", "danish", "swedish", "british"],
            "name": ["name"],
            "vacation": ["vacation", "trip", "break"],
            "occupation": ["occupation", "job"],
            "lunch": ["lunch", "soup", "stew", "grilled", "cheese", "spaghetti", "pizza", "stir"],
            "smoothie": ["smoothie", "cherry", "dragonfruit", "watermelon", "lime", "blueberry", "desert"],
            "models": ["phone", "model", "iphone", "pixel", "oneplus", "samsung", "xiaomi", "huawei"],
            "hair_color": ["hair"],
            "month": ["month", "birthday", "birth"],
            "hobby": ["photography", "cooking", "knitting", "woodworking", "paints", "painting", "gardening"],
            "pet": ["rabbit", "hamster", "fish", "cat", "bird", "dog"],
            "animals": ["rabbit", "dog", "horse", "fish", "bird", "cat"]
        }
|
162 |
+
|
163 |
+
    def parse_puzzle(self):
        """Extract the house count, the category/attribute lists, and the
        clue lines from the raw puzzle text."""
        m = re.search(r"There are (\d+) houses", self.puzzle_text, re.IGNORECASE)
        # Default to 6 houses when the count is not stated explicitly.
        self.num_houses = int(m.group(1)) if m else 6
        # Category lines look like "- Label: attr1, attr2, ...".
        cat_pattern = re.compile(r"^[-*]\s*(.*?):\s*(.+)$")
        for line in self.puzzle_text.splitlines():
            line = line.strip()
            m = cat_pattern.match(line)
            if m:
                cat_label = m.group(1).strip()
                attr_line = m.group(2).strip()
                attrs = [x.strip() for x in attr_line.split(",") if x.strip()]
                self.categories[cat_label] = attrs
                self.category_keys[cat_label] = get_category_key(cat_label)
                if self.debug:
                    st.write(f"Parsed category: '{cat_label}' with attributes {attrs}")
                    st.write(f"Assigned key for category: {self.category_keys[cat_label]}")
        # Every non-empty line after the "### Clues:" marker is a clue.
        clues_section = False
        for line in self.puzzle_text.splitlines():
            if "### Clues:" in line:
                clues_section = True
                continue
            if clues_section:
                clean = line.strip()
                if clean:
                    self.clues.append(clean)
                    if self.debug:
                        st.write(f"Parsed clue: {clean}")
|
190 |
+
|
191 |
+
def build_variables(self):
|
192 |
+
for cat, attrs in self.categories.items():
|
193 |
+
self.var[cat] = {}
|
194 |
+
for attr in attrs:
|
195 |
+
self.var[cat][attr] = self.model.NewIntVar(1, self.num_houses, f"{cat}_{attr}")
|
196 |
+
self.model.AddAllDifferent(list(self.var[cat].values()))
|
197 |
+
if self.debug:
|
198 |
+
st.write(f"Added all-different constraint for category '{cat}'.")
|
199 |
+
|
200 |
+
    def find_attribute(self, token):
        """Resolve a clue token to a (category, attribute) pair, or None.

        Strategy: sanitize the token, guess its category from keyword hits,
        lemmatize, then pick the longest attribute name that appears
        word-bounded in either the sanitized or lemmatized token. Falls back
        to singular/plural variants, and for months/nationalities to an
        explicit abbreviation scan.
        """
        token_san = sanitize_token(token)
        # Guess which category the token belongs to from keyword matches;
        # first hit wins (dict insertion order).
        candidate_key = None
        for key, kws in self.category_keywords.items():
            if any(kw in token_san for kw in kws):
                candidate_key = key
                if self.debug:
                    st.write(f"Debug: Token '{token}' suggests category key '{candidate_key}' based on keywords {kws}.")
                break
        # "pet" and "animals" are the same category under different labels.
        if candidate_key == "pet":
            candidate_key = "animals"
        token_lemmatized = lemmatize_text(token_san)
        if self.debug:
            st.write(f"Debug: Lemmatized token for '{token}': '{token_lemmatized}'")
        # spaCy lemmatizes "painting" to "paint"; undo that so the hobby
        # attribute ("painting") still matches.
        if candidate_key == "hobby" and "paint" in token_lemmatized:
            token_lemmatized = token_lemmatized.replace("paint", "painting")
            if self.debug:
                st.write(f"Debug: Adjusted hobby token to '{token_lemmatized}' for proper matching.")
        if candidate_key in ["month", "nationalities"]:
            token_san = normalize_token(token_san, candidate_key)
            if self.debug:
                st.write(f"Debug: Normalized token for {candidate_key}: '{token_san}'")
        # Restrict the search to the guessed category when we have one.
        if candidate_key:
            categories_to_search = [(cat, attrs) for cat, attrs in self.categories.items() if self.category_keys.get(cat) == candidate_key]
            if self.debug:
                st.write(f"Debug: Restricted search to categories: {[cat for cat, _ in categories_to_search]}")
        else:
            categories_to_search = self.categories.items()
        # Prefer the longest matching attribute name (most specific match).
        best = None
        best_len = 0
        for cat, attrs in categories_to_search:
            for attr in attrs:
                attr_san = sanitize_token(attr)
                if candidate_key in ["month", "nationalities"]:
                    attr_san = normalize_token(attr_san, candidate_key)
                pattern = rf'\b{re.escape(attr_san)}\b'
                if re.search(pattern, token_san) or re.search(pattern, token_lemmatized):
                    if len(attr_san) > best_len:
                        best = (cat, attr)
                        best_len = len(attr_san)
                else:
                    # Retry with the opposite plurality ("rose" <-> "roses").
                    alt = attr_san[:-1] if attr_san.endswith('s') else attr_san + 's'
                    if re.search(rf'\b{re.escape(alt)}\b', token_san) or re.search(rf'\b{re.escape(alt)}\b', token_lemmatized):
                        if len(attr_san) > best_len:
                            best = (cat, attr)
                            best_len = len(attr_san)
        # Last-ditch fallback for months/nationalities: scan for any known
        # abbreviation in the token, then find the attribute that normalizes
        # to exactly that abbreviation.
        if best is None and candidate_key in ["month", "nationalities"]:
            if self.debug:
                st.write(f"Debug: Fallback for {candidate_key}: no match found in token '{token_san}'. Trying explicit substrings.")
            mapping = {}
            if candidate_key == "month":
                mapping = {"jan": "jan", "feb": "feb", "mar": "mar",
                           "april": "april", "may": "may", "jun": "jun",
                           "jul": "jul", "aug": "aug", "sept": "sept", "oct": "oct", "nov": "nov", "dec": "dec"}
            elif candidate_key == "nationalities":
                mapping = {"swede": "swede", "norwegian": "norwegian", "german": "german",
                           "chinese": "chinese", "dane": "dane", "brit": "brit"}
            for key_abbr in mapping.values():
                if re.search(rf'\b{re.escape(key_abbr)}\b', token_san):
                    for cat, attrs in categories_to_search:
                        for attr in attrs:
                            attr_san = normalize_token(sanitize_token(attr), candidate_key)
                            if attr_san == key_abbr:
                                best = (cat, attr)
                                best_len = len(attr_san)
                                if self.debug:
                                    st.write(f"Debug: Found fallback match: '{attr_san}' in token '{token_san}'.")
                                break
                        if best is not None:
                            break
                if best is not None:
                    break
        if best is None and self.debug:
            st.write(f"DEBUG: No attribute found for token '{token}' (sanitized: '{token_san}', lemmatized: '{token_lemmatized}').")
        return best
|
275 |
+
|
276 |
+
def find_all_attributes_in_text(self, text):
|
277 |
+
found = []
|
278 |
+
text_san = sanitize_token(text)
|
279 |
+
for cat, attrs in self.categories.items():
|
280 |
+
for attr in attrs:
|
281 |
+
attr_san = sanitize_token(attr)
|
282 |
+
if re.search(rf'\b{re.escape(attr_san)}\b', text_san):
|
283 |
+
found.append((cat, attr))
|
284 |
+
unique = []
|
285 |
+
seen = set()
|
286 |
+
for pair in found:
|
287 |
+
if pair not in seen:
|
288 |
+
unique.append(pair)
|
289 |
+
seen.add(pair)
|
290 |
+
return unique
|
291 |
+
|
292 |
+
    def spacy_equality_extraction(self, text):
        """Use spaCy's dependency parse to split an "X is Y" sentence into
        its subject and attribute spans; returns (None, None) on failure."""
        if nlp is None:
            return None, None
        doc = nlp(text)
        for token in doc:
            # Look for the copula ("is"/"are") at the root of the sentence.
            if token.lemma_ == "be" and token.dep_ == "ROOT":
                subj = None
                attr = None
                for child in token.children:
                    if child.dep_ in ["nsubj", "nsubjpass"]:
                        subj = child
                    if child.dep_ in ["attr", "acomp"]:
                        attr = child
                if subj and attr:
                    # Expand each head token to its full subtree span so
                    # modifiers ("the person who ...") come along.
                    subject_span = doc[subj.left_edge.i : subj.right_edge.i+1].text
                    attr_span = doc[attr.left_edge.i : attr.right_edge.i+1].text
                    return subject_span, attr_span
        # Fallback: use the first two named entities, if present.
        ents = list(doc.ents)
        if len(ents) >= 2:
            return ents[0].text, ents[1].text
        return None, None
|
313 |
+
|
314 |
+
def apply_constraint_equality(self, token1, token2):
|
315 |
+
a1 = self.find_attribute(token1)
|
316 |
+
a2 = self.find_attribute(token2)
|
317 |
+
if a1 and a2:
|
318 |
+
cat1, attr1 = a1
|
319 |
+
cat2, attr2 = a2
|
320 |
+
self.model.Add(self.var[cat1][attr1] == self.var[cat2][attr2])
|
321 |
+
if self.debug:
|
322 |
+
st.write(f"Added constraint: [{cat1}][{attr1}] == [{cat2}][{attr2}]")
|
323 |
+
else:
|
324 |
+
if self.debug:
|
325 |
+
st.write(f"Warning: could not apply equality between '{token1}' and '{token2}'")
|
326 |
+
|
327 |
+
def apply_constraint_inequality(self, token, house_number):
|
328 |
+
a1 = self.find_attribute(token)
|
329 |
+
if a1:
|
330 |
+
cat, attr = a1
|
331 |
+
self.model.Add(self.var[cat][attr] != house_number)
|
332 |
+
if self.debug:
|
333 |
+
st.write(f"Added constraint: [{cat}][{attr}] != {house_number}")
|
334 |
+
else:
|
335 |
+
if self.debug:
|
336 |
+
st.write(f"Warning: could not apply inequality for '{token}' at house {house_number}")
|
337 |
+
|
338 |
+
    def apply_constraint_position(self, token1, op, token2):
        """Relate the house positions of two attributes.

        op is one of: "==" same house, "<" / ">" strict ordering, "+1"
        token1 is directly left of token2, "-1" token1 is directly right.
        """
        a1 = self.find_attribute(token1)
        a2 = self.find_attribute(token2)
        if a1 and a2:
            cat1, attr1 = a1
            cat2, attr2 = a2
            if op == "==":
                self.model.Add(self.var[cat1][attr1] == self.var[cat2][attr2])
                if self.debug:
                    st.write(f"Added constraint: [{cat1}][{attr1}] == [{cat2}][{attr2}]")
            elif op == "<":
                self.model.Add(self.var[cat1][attr1] < self.var[cat2][attr2])
                if self.debug:
                    st.write(f"Added constraint: [{cat1}][{attr1}] < [{cat2}][{attr2}]")
            elif op == ">":
                self.model.Add(self.var[cat1][attr1] > self.var[cat2][attr2])
                if self.debug:
                    st.write(f"Added constraint: [{cat1}][{attr1}] > [{cat2}][{attr2}]")
            elif op == "+1":
                # Directly left: token1's house is one less than token2's.
                self.model.Add(self.var[cat1][attr1] + 1 == self.var[cat2][attr2])
                if self.debug:
                    st.write(f"Added constraint: [{cat1}][{attr1}] + 1 == [{cat2}][{attr2}]")
            elif op == "-1":
                # Directly right: token1's house is one more than token2's.
                self.model.Add(self.var[cat1][attr1] - 1 == self.var[cat2][attr2])
                if self.debug:
                    st.write(f"Added constraint: [{cat1}][{attr1}] - 1 == [{cat2}][{attr2}]")
        else:
            if self.debug:
                st.write(f"Warning: could not apply position constraint between '{token1}' and '{token2}' with op '{op}'")
|
367 |
+
|
368 |
+
def apply_constraint_next_to(self, token1, token2):
|
369 |
+
a1 = self.find_attribute(token1)
|
370 |
+
a2 = self.find_attribute(token2)
|
371 |
+
if a1 and a2:
|
372 |
+
cat1, attr1 = a1
|
373 |
+
cat2, attr2 = a2
|
374 |
+
diff = self.model.NewIntVar(0, self.num_houses, f"diff_{attr1}_{attr2}")
|
375 |
+
self.model.AddAbsEquality(diff, self.var[cat1][attr1] - self.var[cat2][attr2])
|
376 |
+
self.model.Add(diff == 1)
|
377 |
+
if self.debug:
|
378 |
+
st.write(f"Added next-to constraint: |[{cat1}][{attr1}] - [{cat2}][{attr2}]| == 1")
|
379 |
+
else:
|
380 |
+
if self.debug:
|
381 |
+
st.write(f"Warning: could not apply next-to constraint between '{token1}' and '{token2}'")
|
382 |
+
|
383 |
+
def apply_constraint_between(self, token1, token2, houses_between):
|
384 |
+
a1 = self.find_attribute(token1)
|
385 |
+
a2 = self.find_attribute(token2)
|
386 |
+
if a1 and a2:
|
387 |
+
cat1, attr1 = a1
|
388 |
+
cat2, attr2 = a2
|
389 |
+
diff = self.model.NewIntVar(0, self.num_houses, f"between_{attr1}_{attr2}")
|
390 |
+
self.model.AddAbsEquality(diff, self.var[cat1][attr1] - self.var[cat2][attr2])
|
391 |
+
self.model.Add(diff == houses_between + 1)
|
392 |
+
if self.debug:
|
393 |
+
st.write(f"Added between constraint: |[{cat1}][{attr1}] - [{cat2}][{attr2}]| == {houses_between + 1}")
|
394 |
+
else:
|
395 |
+
if self.debug:
|
396 |
+
st.write(f"Warning: could not apply between constraint for '{token1}' and '{token2}' with {houses_between} houses in between")
|
397 |
+
|
398 |
+
def apply_constraint_fixed(self, token, house_number):
|
399 |
+
a1 = self.find_attribute(token)
|
400 |
+
if a1:
|
401 |
+
cat, attr = a1
|
402 |
+
self.model.Add(self.var[cat][attr] == house_number)
|
403 |
+
if self.debug:
|
404 |
+
st.write(f"Added fixed constraint: [{cat}][{attr}] == {house_number}")
|
405 |
+
else:
|
406 |
+
if self.debug:
|
407 |
+
st.write(f"Warning: could not apply fixed constraint for '{token}' at house {house_number}")
|
408 |
+
|
409 |
+
    def process_clue(self, clue):
        """Translate one clue sentence into CP-SAT constraints.

        A cascade of regexes is tried in order of specificity: fixed house,
        negated house, directly left/right, somewhere left/right, next-to,
        between, and finally a generic "X is Y" equality with an optional
        spaCy dependency-parse fallback. The first pattern that matches wins.
        """
        # Drop the leading clue number ("12. ...").
        text = re.sub(r'^\d+\.\s*', '', clue).strip()
        if self.debug:
            st.write(f"Processing clue: {text}")
        ordinal_numbers = r"(?:\d+|first|second|third|fourth|fifth|sixth)"
        # "<X> is in the third house"
        m_fixed = re.search(rf"(.+?) is in the ({ordinal_numbers}) house", text, re.IGNORECASE)
        if m_fixed:
            token = m_fixed.group(1).strip()
            num_str = m_fixed.group(2).strip().lower()
            house_num = int(num_str) if num_str.isdigit() else ordinal_map.get(num_str)
            if house_num is not None:
                self.apply_constraint_fixed(token, house_num)
            return
        # "<X> is not in the third house"
        m_not = re.search(rf"(.+?) is not in the ({ordinal_numbers}) house", text, re.IGNORECASE)
        if m_not:
            token = m_not.group(1).strip()
            num_str = m_not.group(2).strip().lower()
            house_num = int(num_str) if num_str.isdigit() else ordinal_map.get(num_str)
            if house_num is not None:
                self.apply_constraint_inequality(token, house_num)
            return
        # "<X> is directly left of <Y>"  ->  pos(X) + 1 == pos(Y)
        m_left = re.search(r"(.+?) is directly left of (.+)", text, re.IGNORECASE)
        if m_left:
            token1 = m_left.group(1).strip()
            token2 = m_left.group(2).strip()
            self.apply_constraint_position(token1, "+1", token2)
            return
        # "<X> is directly right of <Y>"  ->  pos(X) - 1 == pos(Y)
        m_right = re.search(r"(.+?) is directly right of (.+)", text, re.IGNORECASE)
        if m_right:
            token1 = m_right.group(1).strip()
            token2 = m_right.group(2).strip()
            self.apply_constraint_position(token1, "-1", token2)
            return
        # "<X> is somewhere to the left of <Y>"  ->  pos(X) < pos(Y)
        m_sl = re.search(r"(.+?) is somewhere to the left of (.+)", text, re.IGNORECASE)
        if m_sl:
            token1 = m_sl.group(1).strip()
            token2 = m_sl.group(2).strip()
            self.apply_constraint_position(token1, "<", token2)
            return
        # "<X> is somewhere to the right of <Y>"  ->  pos(X) > pos(Y)
        m_sr = re.search(r"(.+?) is somewhere to the right of (.+)", text, re.IGNORECASE)
        if m_sr:
            token1 = m_sr.group(1).strip()
            token2 = m_sr.group(2).strip()
            self.apply_constraint_position(token1, ">", token2)
            return
        # "<X> and <Y> are next to each other"
        m_next = re.search(r"(.+?) and (.+?) are next to each other", text, re.IGNORECASE)
        if m_next:
            token1 = m_next.group(1).strip()
            token2 = m_next.group(2).strip()
            self.apply_constraint_next_to(token1, token2)
            return
        # "There are two houses between <X> and <Y>"
        m_between = re.search(rf"There (?:are|is) (\d+|one|two|three|four|five|six) house(?:s)? between (.+?) and (.+)", text, re.IGNORECASE)
        if m_between:
            num_str = m_between.group(1).strip().lower()
            houses_between = int(num_str) if num_str.isdigit() else word_to_num.get(num_str)
            token1 = m_between.group(2).strip()
            token2 = m_between.group(3).strip()
            self.apply_constraint_between(token1, token2, houses_between)
            return
        # Generic "X is (the) Y" equality.
        m_eq = re.search(r"(.+)\sis(?: the)?\s(.+)", text, re.IGNORECASE)
        if m_eq:
            token1 = m_eq.group(1).strip()
            token2 = m_eq.group(2).strip()
            # Trim boilerplate prefixes before the attribute lookup.
            token1 = re.sub(r"^(the person who\s+|who\s+)", "", token1, flags=re.IGNORECASE).strip()
            token2 = re.sub(r"^(a\s+|an\s+|the\s+)", "", token2, flags=re.IGNORECASE).strip()
            a1 = self.find_attribute(token1)
            a2 = self.find_attribute(token2)
            if a1 and a2:
                self.apply_constraint_equality(token1, token2)
                return
            else:
                if self.debug:
                    st.write("Equality regex failed to extract valid attributes using token cleaning.")
        # Last resort: let spaCy's dependency parse split the sentence.
        if nlp is not None:
            left, right = self.spacy_equality_extraction(text)
            if left and right:
                if self.debug:
                    st.write(f"spaCy extracted equality: '{left}' == '{right}'")
                self.apply_constraint_equality(left, right)
                return
        if self.debug:
            st.write(f"Unprocessed clue: {text}")
|
491 |
+
|
492 |
+
def process_all_clues(self):
|
493 |
+
for clue in self.clues:
|
494 |
+
self.process_clue(clue)
|
495 |
+
|
496 |
+
def solve(self):
|
497 |
+
solver = cp_model.CpSolver()
|
498 |
+
# Use all available cores (0 means all available, 1 means single core for deployment to streamlit community cloud)
|
499 |
+
solver.parameters.num_search_workers = 1
|
500 |
+
status = solver.Solve(self.model)
|
501 |
+
if status in (cp_model.OPTIMAL, cp_model.FEASIBLE):
|
502 |
+
solution = {}
|
503 |
+
for house in range(1, self.num_houses + 1):
|
504 |
+
solution[house] = {}
|
505 |
+
for cat, attr_dict in self.var.items():
|
506 |
+
for attr, var in attr_dict.items():
|
507 |
+
if solver.Value(var) == house:
|
508 |
+
solution[house][cat] = attr
|
509 |
+
return solution
|
510 |
+
else:
|
511 |
+
if self.debug:
|
512 |
+
st.write("No solution found. The clues may be contradictory or incomplete.")
|
513 |
+
return None
|
514 |
+
|
515 |
+
def print_solution(self, solution):
|
516 |
+
if solution:
|
517 |
+
headers = ["House"] + [shorten_category(cat) for cat in self.categories.keys()]
|
518 |
+
table = []
|
519 |
+
for house in sorted(solution.keys()):
|
520 |
+
row = [str(house)]
|
521 |
+
for cat in self.categories.keys():
|
522 |
+
row.append(solution[house].get(cat, ""))
|
523 |
+
table.append(row)
|
524 |
+
df = pd.DataFrame(table, columns=headers)
|
525 |
+
return df
|
526 |
+
else:
|
527 |
+
return None
|
528 |
+
|
529 |
+
# Streamlit UI
st.title("Zebra Logic Puzzle Solver")
st.subheader("🦓 ZebraLogic: Benchmarking the Logical Reasoning Ability of Language Models")
st.markdown("""
Copy the Zebra Logic Puzzles description [from the huggingface site](https://huggingface.co/spaces/allenai/ZebraLogic), and paste it below.
""")

puzzle_text = st.text_area("Puzzle Input", height=300)
show_debug = st.checkbox("Show Debug Output", value=False)

# Use session_state to ensure the solution is computed only once per click.
# Once the button has been pressed, the flag keeps the solution rendered
# across Streamlit reruns (e.g. when the debug checkbox is toggled).
if "puzzle_solved" not in st.session_state:
    st.session_state["puzzle_solved"] = False

if st.button("Solve Puzzle") or st.session_state["puzzle_solved"]:
    # Indicate that we've clicked the button
    st.session_state["puzzle_solved"] = True

    # Parse the puzzle, build the CP-SAT variables, then translate clues
    # into constraints before solving.
    solver_instance = PuzzleSolver(puzzle_text, debug=show_debug)
    solver_instance.parse_puzzle()
    solver_instance.build_variables()
    solver_instance.process_all_clues()

    # st.subheader("Parsed Attributes (Categories & Their Attributes)")
    # for cat, attrs in solver_instance.categories.items():
    #     st.markdown(f"**{cat}**: {', '.join(attrs)}")

    # st.subheader("Parsed Clues")
    # for i, clue in enumerate(solver_instance.clues, start=1):
    #     st.markdown(f"{i}. {clue}")

    solution = solver_instance.solve()
    st.subheader("Solution Table")
    df_solution = solver_instance.print_solution(solution)
    if df_solution is not None:
        st.table(df_solution)
    else:
        st.error("No solution found. The clues may be contradictory or incomplete.")
|