Skip to content

[RL] Fix shape mismatch on tail batch in GRPO training#4252

Open
susanbao wants to merge 2 commits into
mainfrom
sanbao/gpt
Open

[RL] Fix shape mismatch on tail batch in GRPO training#4252
susanbao wants to merge 2 commits into
mainfrom
sanbao/gpt

Conversation

@susanbao

Copy link
Copy Markdown
Collaborator

This PR fixes the JAX shard_map shape mismatch (ValueError) caused by lazy filtering of long prompts in grain dataset by adding drop_remainder=True and doubling the test dataset slice.

@codecov

codecov Bot commented Jun 24, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 20.00000% with 4 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/trainers/post_train/rl/train_rl.py 0.00% 3 Missing ⚠️
...maxtext/trainers/post_train/rl/math_verify_pool.py 50.00% 1 Missing ⚠️

📢 Thoughts on this report? Let us know!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant