David Pomerenke committed
Commit 0384b92 · 1 Parent(s): b0c61ed

Shorter classification prompt + error handling

Files changed (1)
  1. evals/tasks.py +34 -23
evals/tasks.py CHANGED
@@ -90,48 +90,59 @@ async def classify_and_evaluate(model, bcp_47, nr):
     paragraphs = paragraphs[paragraphs["topic"].isin(top_topics)]
     examples = pd.concat(
         [
-            paragraphs[paragraphs["topic"] == t].sample(n=5, random_state=42)
+            paragraphs[paragraphs["topic"] == t].sample(n=1, random_state=42)
             for t in top_topics
         ]
-    ).sample(frac=1, random_state=42)
+    ).sample(frac=1, random_state=nr)
     test_paragraphs = paragraphs[~paragraphs["URL"].isin(examples["URL"])].sample(
         frac=1, random_state=42
     )
     test_paragraph = test_paragraphs.iloc[nr]
 
-    def topic_to_number(topic):
-        return top_topics.get_loc(topic)
+    def format_prompt(text):
+        return f"{text}\n\nTopic: {'|'.join(top_topics)}?"
 
     messages = []
     for example in examples.itertuples():
         messages += [
-            {"role": "user", "content": example.text},
-            {"role": "assistant", "content": str(topic_to_number(example.topic))},
+            {"role": "user", "content": format_prompt(example.text)},
+            {"role": "assistant", "content": example.topic},
         ]
-    reply = await complete(
-        model=model,
-        messages=[
-            *messages,
-            {
-                "role": "user",
-                "content": test_paragraph.text,
-            },
-        ],
-        temperature=0,
-        max_tokens=5,
-    )
+    # some models have poor tokenization for some languages, and the prompt for
+    # this task is relatively long, so it sometimes exceeds the context window;
+    # this is mostly to blame on the model's tokenization rather than the context
+    # window itself, so we assign 0 accuracy in this case
     try:
-        pred = int(reply.choices[0].message.content.strip())
-    except ValueError:
-        pred = -1
-    true = topic_to_number(test_paragraph.topic)
+        reply = await complete(
+            model=model,
+            messages=[
+                *messages,
+                {
+                    "role": "user",
+                    "content": format_prompt(test_paragraph.text),
+                },
+            ],
+            temperature=0,
+            max_tokens=30,
+        )
+        response = reply.choices[0].message.content.strip().lower()
+        true = test_paragraph.topic
+        others = [t for t in top_topics if t != true]
+        acc = int(
+            response.startswith(true)
+            or (true in response and not any(o in response for o in others))
+        )
+    except Exception as e:
+        if "`inputs` tokens + `max_new_tokens` must be <= 4097" in str(e):
+            print(f"Max tokens exceeded for {model} in {bcp_47}")
+            acc = 0
+        else:
+            raise e
     return [
         {
             "model": model,
             "bcp_47": bcp_47,
             "task": "classification",
             "metric": "accuracy",
-            "score": int(pred == true),
+            "score": acc,
             "sentence_nr": nr,
         }
     ]
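
To illustrate the sampling change: the few-shot set now draws one example per topic instead of five, and the shuffle is seeded with nr rather than a fixed seed, so each test sentence sees the examples in a different order. A minimal, self-contained sketch with a toy dataframe (column names mirror evals/tasks.py; the data itself is made up):

import pandas as pd

# Toy stand-in for the real paragraphs dataframe.
paragraphs = pd.DataFrame(
    {
        "text": [f"paragraph {i}" for i in range(9)],
        "topic": ["politics", "sports", "science"] * 3,
        "URL": [f"https://example.org/{i}" for i in range(9)],
    }
)
top_topics = pd.Index(["politics", "sports", "science"])
nr = 0  # index of the test sentence; also seeds the shuffle below

# One few-shot example per topic; example order varies with nr.
examples = pd.concat(
    [paragraphs[paragraphs["topic"] == t].sample(n=1, random_state=42) for t in top_topics]
).sample(frac=1, random_state=nr)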
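
And what the shorter prompt looks like in practice: instead of mapping topics to numbers, each user message now ends with the candidate topics joined by "|", and the assistant answers with the topic name itself. A sketch with hypothetical topic names:

top_topics = ["politics", "sports", "science"]  # hypothetical topic names

def format_prompt(text):
    return f"{text}\n\nTopic: {'|'.join(top_topics)}?"

print(format_prompt("The match ended in a draw."))
# The match ended in a draw.
#
# Topic: politics|sports|science?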
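
The reply is no longer parsed as an integer; the match is deliberately lenient, counting the answer as correct if it starts with the true topic, or if it mentions the true topic and no other candidate. A standalone sketch of the same rule (the function name is illustrative, not from the repo):

def lenient_accuracy(response, true, top_topics):
    # Correct if the reply starts with the true topic, or mentions it
    # without mentioning any of the other candidate topics.
    response = response.strip().lower()
    others = [t for t in top_topics if t != true]
    return int(
        response.startswith(true)
        or (true in response and not any(o in response for o in others))
    )

topics = ["politics", "sports", "science"]
assert lenient_accuracy("Sports", "sports", topics) == 1
assert lenient_accuracy("This is sports, not politics.", "sports", topics) == 0

Note that the new except branch only swallows the specific context-window error, matched via the provider's error-message substring, and scores it 0; any other exception is still re-raised.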