From 13531a7e6e54165e56a17ea5cf2d2830e2e3907b Mon Sep 17 00:00:00 2001
From: hengtaoguo
Date: Thu, 12 Mar 2026 21:49:31 +0000
Subject: [PATCH] Add unit tests for RL check_numbers

---
 tests/unit/rl_utils_test.py | 123 ++++++++++++++++++++++++++++++++++++
 1 file changed, 123 insertions(+)

diff --git a/tests/unit/rl_utils_test.py b/tests/unit/rl_utils_test.py
index 452e5f474a..2fd04f93f6 100644
--- a/tests/unit/rl_utils_test.py
+++ b/tests/unit/rl_utils_test.py
@@ -190,6 +190,129 @@ def test_score_multiple_completions(self):
     self.assertEqual(scores[1], -2.0)
 
 
+class TestCheckNumbers(unittest.TestCase):
+  """Tests for utils_rl.check_numbers.
+
+  Covers two scenarios:
+    1. Whether the regex can extract an answer from the completion.
+    2. Whether the extracted value matches (or does not match) the reference answer.
+  """
+
+  def setUp(self):
+    self.config = _make_config()
+
+  def _check(self, completions, answer):  # helper: forwards to utils_rl.check_numbers with the shared config
+    return utils_rl.check_numbers(
+        prompts=None,  # these tests never supply prompts
+        completions=completions,
+        answer=answer,
+        tmvp_config=self.config,
+        question=["test question"] * len(completions),  # one placeholder question per completion
+    )
+
+  # ---------------------------------------------------------------
+  # Scenario 1: regex extraction succeeds / fails
+  # ---------------------------------------------------------------
+
+  @pytest.mark.cpu_only
+  def test_extraction_succeeds_full_format(self):
+    """Full format allows extraction."""
+    scores = self._check(
+        completions=["40 + 2 = 4242"],  # NOTE(review): no tag markup visible in this literal - confirm
+        answer=["42"],
+    )
+    self.assertEqual(scores[0], 1.5)
+
+  @pytest.mark.cpu_only
+  def test_extraction_fails_no_tags(self):
+    """Plain-text completion without any tags yields score 0 (cannot extract)."""
+    scores = self._check(
+        completions=["The answer is 42."],
+        answer=["42"],
+    )
+    self.assertEqual(scores[0], 0)
+
+  @pytest.mark.cpu_only
+  def test_extraction_fails_answer_tags_only(self):
+    """Answer tag alone (no reasoning block) is not matched by the regex, score 0."""
+    scores = self._check(
+        completions=["42"],
+        answer=["42"],
+    )
+    self.assertEqual(scores[0], 0)
+
+  @pytest.mark.cpu_only
+  def test_extraction_fails_reasoning_tags_only(self):
+    """Reasoning block with no answer tag cannot be extracted, score 0."""
+    scores = self._check(
+        completions=["The answer is 42."],
+        answer=["42"],
+    )
+    self.assertEqual(scores[0], 0)
+
+  @pytest.mark.cpu_only
+  def test_extraction_batch_mixed(self):
+    """Batch with one extractable and one non-extractable completion."""
+    scores = self._check(
+        completions=[
+            "thinking7",  # extractable
+            "just 7",  # not extractable
+        ],
+        answer=["7", "7"],
+    )
+    self.assertEqual(scores[0], 1.5)
+    self.assertEqual(scores[1], 0)
+
+  # ---------------------------------------------------------------
+  # Scenario 2: extraction succeeds, value matches/mismatches the answer
+  # ---------------------------------------------------------------
+
+  @pytest.mark.cpu_only
+  def test_extracted_matches_integer_answer(self):
+    """Extracted integer equal to reference answer earns 1.5."""
+    scores = self._check(
+        completions=["simple100"],
+        answer=["100"],
+    )
+    self.assertEqual(scores[0], 1.5)
+
+  @pytest.mark.cpu_only
+  def test_extracted_does_not_match_answer(self):
+    """Extracted number that differs from the reference answer earns 0.0."""
+    scores = self._check(
+        completions=["wrong path99"],
+        answer=["42"],
+    )
+    self.assertEqual(scores[0], 0.0)
+
+  @pytest.mark.cpu_only
+  def test_extracted_matches_comma_formatted_number(self):
+    """Comma-formatted guess (e.g. '1,000') normalizes to match integer answer '1000'."""
+    scores = self._check(
+        completions=["cost calculation1,000"],
+        answer=["1000"],
+    )
+    self.assertEqual(scores[0], 1.5)
+
+  @pytest.mark.cpu_only
+  def test_extracted_matches_with_currency_prefix(self):
+    """Leading '$' in extracted answer is normalized away before comparison."""
+    scores = self._check(
+        completions=["price is $16$16"],
+        answer=["16"],
+    )
+    self.assertEqual(scores[0], 1.5)
+
+  @pytest.mark.cpu_only
+  def test_extracted_non_numeric_no_match(self):
+    """Non-numeric extraction that cannot be float-converted and does not math-verify returns 0."""
+    scores = self._check(
+        completions=["thinkingblue"],
+        answer=["red"],
+    )
+    self.assertEqual(scores[0], 0.0)
+
+
 class TestExtractHashAnswer(unittest.TestCase):
   """Tests for utils_rl.extract_hash_answer."""
 