Skip to content

Commit

Permalink
allow repeats
Browse files Browse the repository at this point in the history
  • Loading branch information
baberabb committed Jan 21, 2025
1 parent f198094 commit 80b2244
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
5 changes: 3 additions & 2 deletions lm_eval/tasks/math500/math500.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ test_split: test
doc_to_text: "Solve the following math problem efficiently and clearly:\n\n- For simple problems (2 steps or fewer):\nProvide a concise solution with minimal explanation.\n\n- For complex problems (3 steps or more):\nUse this step-by-step format:\n\n## Step 1: [Concise description]\n[Brief explanation and calculations]\n\n## Step 2: [Concise description]\n[Brief explanation and calculations]\n\n...\n\nRegardless of the approach, always conclude with:\n\nTherefore, the final answer is: $\\\\boxed{answer}$. I hope it is correct.\n\nWhere [answer] is just the final number or expression that solves the problem.\n\nProblem: {{ problem }}"
process_results: !function utils.process_results
doc_to_target: "{{answer if few_shot is undefined else solution}}"
repeats: 2
generation_kwargs:
until: []
max_gen_toks: 5120
do_sample: false
temperature: 0
do_sample: true
temperature: 0.6
metric_list:
- metric: exact_match
aggregation: mean
Expand Down
25 changes: 13 additions & 12 deletions lm_eval/tasks/math500/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def _process_doc(doc: dict) -> dict:
# ]


# calculate pass@1 for all results
def process_results(doc: dict, results: List[str]) -> Dict[str, int]:
candidates = results[0]

Expand Down Expand Up @@ -184,18 +185,18 @@ def is_equiv(x1: str, x2: str) -> bool:
return False


def get_unnormalized_answer(text: str) -> str:
    """Extract the raw final-answer string from a model completion.

    Searches for the sentence
    ``Final Answer: The final answer is <answer>. I hope it is correct.``
    and returns ``<answer>`` with surrounding whitespace stripped.

    The closing phrase ``I hope it is correct.`` is appended to ``text``
    before searching, so a completion that ends right after
    ``<answer>. `` (without the closing phrase) still matches.

    Args:
        text: The completion to search.

    Returns:
        The captured answer, or ``"[invalidanswer]"`` when the expected
        sentence is not found.
    """
    INVALID_ANSWER = "[invalidanswer]"
    end_seq = "I hope it is correct."
    # The periods are escaped on purpose: an unescaped "." matches any
    # character, which would swallow the last character of the answer when
    # the sentence-ending period is missing (e.g. "42 " -> "4").
    match = re.search(
        r"Final Answer: The final answer is(.*?)\. I hope it is correct\.",
        text + end_seq,
    )
    if match is None:
        return INVALID_ANSWER
    return match.group(1).strip()
# def get_unnormalized_answer(text: str) -> str:
# INVALID_ANSWER = "[invalidanswer]"
# end_seq = "I hope it is correct."
# text += end_seq
# match = re.search(
# r"Final Answer: The final answer is(.*?). I hope it is correct.",
# text,
# )
# if match:
# return match.group(1).strip()
# else:
# return INVALID_ANSWER


SUBSTITUTIONS = [
Expand Down

0 comments on commit 80b2244

Please sign in to comment.