Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions tests/unit/rl_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,129 @@ def test_score_multiple_completions(self):
self.assertEqual(scores[1], -2.0)


class TestCheckNumbers(unittest.TestCase):
    """Exercise utils_rl.check_numbers through a thin scoring helper.

    The cases fall into two groups:
      * completions from which an answer can or cannot be extracted, and
      * extracted answers that agree or disagree with the reference answer.
    """

    def setUp(self):
        # Every test scores against the same freshly built config.
        self.config = _make_config()

    def _check(self, completions, answer):
        """Return the check_numbers scores for `completions` versus `answer`.

        Prompts are passed as None and every completion is paired with the
        same placeholder question string.
        """
        placeholder_questions = ["test question"] * len(completions)
        return utils_rl.check_numbers(
            prompts=None,
            completions=completions,
            answer=answer,
            tmvp_config=self.config,
            question=placeholder_questions,
        )

    # ---------------------------------------------------------------
    # Group 1: can an answer be extracted from the completion at all?
    # ---------------------------------------------------------------

    @pytest.mark.cpu_only
    def test_extraction_succeeds_full_format(self):
        """Reasoning block followed by an answer block with the right value scores 1.5."""
        completion = "<reasoning>40 + 2 = 42</reasoning><answer>42</answer>"
        rewards = self._check(completions=[completion], answer=["42"])
        self.assertEqual(rewards[0], 1.5)

    @pytest.mark.cpu_only
    def test_extraction_fails_no_tags(self):
        """Untagged plain text scores 0 even though it mentions the right number."""
        completion = "The answer is 42."
        rewards = self._check(completions=[completion], answer=["42"])
        self.assertEqual(rewards[0], 0)

    @pytest.mark.cpu_only
    def test_extraction_fails_answer_tags_only(self):
        """An <answer> block with no preceding <reasoning> block scores 0."""
        completion = "<answer>42</answer>"
        rewards = self._check(completions=[completion], answer=["42"])
        self.assertEqual(rewards[0], 0)

    @pytest.mark.cpu_only
    def test_extraction_fails_reasoning_tags_only(self):
        """A <reasoning> block with no <answer> block scores 0."""
        completion = "<reasoning>The answer is 42.</reasoning>"
        rewards = self._check(completions=[completion], answer=["42"])
        self.assertEqual(rewards[0], 0)

    @pytest.mark.cpu_only
    def test_extraction_batch_mixed(self):
        """In a two-item batch each completion is scored independently."""
        tagged = "<reasoning>thinking</reasoning><answer>7</answer>"  # extractable
        untagged = "just 7"  # not extractable
        rewards = self._check(completions=[tagged, untagged], answer=["7", "7"])
        self.assertEqual(rewards[0], 1.5)
        self.assertEqual(rewards[1], 0)

    # ---------------------------------------------------------------
    # Group 2: answer extracted; does it agree with the reference?
    # ---------------------------------------------------------------

    @pytest.mark.cpu_only
    def test_extracted_matches_integer_answer(self):
        """An extracted integer identical to the reference scores 1.5."""
        completion = "<reasoning>simple</reasoning><answer>100</answer>"
        rewards = self._check(completions=[completion], answer=["100"])
        self.assertEqual(rewards[0], 1.5)

    @pytest.mark.cpu_only
    def test_extracted_does_not_match_answer(self):
        """An extracted number different from the reference scores 0.0."""
        completion = "<reasoning>wrong path</reasoning><answer>99</answer>"
        rewards = self._check(completions=[completion], answer=["42"])
        self.assertEqual(rewards[0], 0.0)

    @pytest.mark.cpu_only
    def test_extracted_matches_comma_formatted_number(self):
        """A guess written as '1,000' scores 1.5 against the reference '1000'."""
        completion = "<reasoning>cost calculation</reasoning><answer>1,000</answer>"
        rewards = self._check(completions=[completion], answer=["1000"])
        self.assertEqual(rewards[0], 1.5)

    @pytest.mark.cpu_only
    def test_extracted_matches_with_currency_prefix(self):
        """A guess written as '$16' scores 1.5 against the reference '16'."""
        completion = "<reasoning>price is $16</reasoning><answer>$16</answer>"
        rewards = self._check(completions=[completion], answer=["16"])
        self.assertEqual(rewards[0], 1.5)

    @pytest.mark.cpu_only
    def test_extracted_non_numeric_no_match(self):
        """A non-numeric guess that differs from a non-numeric reference scores 0.0."""
        completion = "<reasoning>thinking</reasoning><answer>blue</answer>"
        rewards = self._check(completions=[completion], answer=["red"])
        self.assertEqual(rewards[0], 0.0)


class TestExtractHashAnswer(unittest.TestCase):
"""Tests for utils_rl.extract_hash_answer."""

Expand Down
Loading