Add trajectory-level deduplication for GRPO advantage normalization #462
base: main
Conversation
Pull request overview
This pull request adds trajectory-level deduplication to GRPO advantage normalization to address turn-level bias in multi-turn reinforcement learning scenarios. The implementation introduces a new compute_grpo_outcome_advantage function that tracks unique (data_id, rollout_id) pairs to ensure each trajectory is counted only once when computing baseline statistics for advantage estimation.
Changes:
- Added `compute_grpo_outcome_advantage` function with trajectory-level deduplication logic
- Integrated the new advantage computation into the training pipeline, with behavior configurable via the `compute_mean_std_cross_all_data` parameter
- Added an assertion to restrict trajectory-level normalization to the GRPO algorithm only
```python
def compute_grpo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    traj_index: np.ndarray | None = None,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    compute_mean_std_cross_all_data: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute advantage for GRPO with trajectory-level deduplication support.

    This is a minimal extension of VeRL's GRPO implementation, adding support for
    trajectory-level deduplication via `traj_index` and `compute_mean_std_cross_all_data`.

    Args:
        token_level_rewards: Shape (bs, response_length).
        response_mask: Shape (bs, response_length).
        index: Group index array (e.g., data_id).
        traj_index: Trajectory index array (e.g., rollout_id). If None, no deduplication.
        epsilon: Small value for numerical stability.
        norm_adv_by_std_in_grpo: If True, normalize by std (original GRPO). If False, Dr.GRPO style.
        compute_mean_std_cross_all_data: If True (default), compute mean/std across all data.
            If False, compute mean/std per unique (index, traj_index) trajectory.

    Returns:
        Tuple of (advantages, returns), both shape (bs, response_length).
    """
    scores = token_level_rewards.sum(dim=-1)

    id2score: dict = defaultdict(list)
    id2mean: dict = {}
    id2std: dict = {}
    seen_pairs: set = set()

    with torch.no_grad():
        bsz = scores.shape[0]
        for i in range(bsz):
            # Trajectory deduplication: skip if (index, traj_index) already seen
            if traj_index is not None and (index[i], traj_index[i]) in seen_pairs:
                continue
            id2score[index[i]].append(scores[i])
            # Mark as seen only when compute_mean_std_cross_all_data is False
            if traj_index is not None and not compute_mean_std_cross_all_data:
                seen_pairs.add((index[i], traj_index[i]))

        for idx in id2score:
            if len(id2score[idx]) == 1:
                id2mean[idx] = torch.tensor(0.0)
                id2std[idx] = torch.tensor(1.0)
            elif len(id2score[idx]) > 1:
                scores_tensor = torch.stack(id2score[idx])
                id2mean[idx] = torch.mean(scores_tensor)
                id2std[idx] = torch.std(scores_tensor)
            else:
                raise ValueError(f"no score in prompt index: {idx}")

        for i in range(bsz):
            if norm_adv_by_std_in_grpo:
                scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
            else:
                scores[i] = scores[i] - id2mean[index[i]]
        scores = scores.unsqueeze(-1) * response_mask

    return scores, scores
```
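For illustration, here is a minimal usage sketch (not part of the diff) showing how the function might be called on a toy batch; the IDs, rewards, and shapes below are invented for the example:

```python
import numpy as np
import torch

# Toy batch: 4 samples of response length 3. Samples 0 and 1 are two turns of the
# same trajectory (data_id=0, rollout_id=0), so with deduplication enabled only one
# of them contributes to the group baseline.
token_level_rewards = torch.tensor([
    [0.0, 0.0, 1.0],
    [0.0, 0.0, 1.0],
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 2.0],
])
response_mask = torch.ones_like(token_level_rewards)
index = np.array([0, 0, 0, 1])       # group index (data_id)
traj_index = np.array([0, 0, 1, 0])  # trajectory index (rollout_id)

advantages, returns = compute_grpo_outcome_advantage(
    token_level_rewards,
    response_mask,
    index,
    traj_index=traj_index,
    compute_mean_std_cross_all_data=False,  # count each trajectory once in the baseline
)
print(advantages.shape)  # torch.Size([4, 3])
```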
Copilot AI · Jan 21, 2026
The new compute_grpo_outcome_advantage function lacks test coverage. Given that this is a critical mathematical computation affecting training outcomes, unit tests should be added to verify:
- Correct behavior when `compute_mean_std_cross_all_data=True` vs `False`
- Proper handling of trajectory deduplication with different `(index, traj_index)` combinations
- Device consistency (tensors on GPU)
- Edge cases: single-sample groups, all identical scores, etc.
- Correct advantage normalization with and without std division

Consider adding tests in the tests/trainer/ directory or a new test file specifically for GRPO advantage computation.
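As a starting point, a unit test along these lines could cover the deduplication flag; the import path and the test values are illustrative assumptions, not part of the PR:

```python
import numpy as np
import torch

# Assumed import path; the actual module layout may differ.
from agentlightning.verl.trainer import compute_grpo_outcome_advantage


def test_trajectory_dedup_changes_baseline():
    # Group 0 contains trajectory (0, 0) appearing twice (two turns) and
    # trajectory (0, 1) once; the duplicate turn should not shift the group
    # baseline when compute_mean_std_cross_all_data=False.
    rewards = torch.tensor([[1.0], [1.0], [0.0]])
    mask = torch.ones_like(rewards)
    index = np.array([0, 0, 0])
    traj_index = np.array([0, 0, 1])

    adv_dedup, _ = compute_grpo_outcome_advantage(
        rewards, mask, index, traj_index=traj_index,
        compute_mean_std_cross_all_data=False,
    )
    adv_all, _ = compute_grpo_outcome_advantage(
        rewards, mask, index, traj_index=traj_index,
        compute_mean_std_cross_all_data=True,
    )
    # Deduplicated baseline mean is (1 + 0) / 2; without deduplication it is 2 / 3,
    # so the resulting advantages must differ.
    assert not torch.allclose(adv_dedup, adv_all)
```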
```python
                id2mean[idx] = torch.mean(scores_tensor)
                id2std[idx] = torch.std(scores_tensor)
            else:
                raise ValueError(f"no score in prompt index: {idx}")
```
Copilot AI · Jan 21, 2026
The error message uses f-string formatting but doesn't include the idx variable value. The message should be updated to include the actual index value that's causing the issue for better debugging:
`raise ValueError(f"no score in prompt index: {idx}")` should ensure the value is actually included in the error output.
```python
        for i in range(bsz):
            # Trajectory deduplication: skip if (index, traj_index) already seen
            if traj_index is not None and (index[i], traj_index[i]) in seen_pairs:
                continue
            id2score[index[i]].append(scores[i])
            # Mark as seen only when compute_mean_std_cross_all_data is False
            if traj_index is not None and not compute_mean_std_cross_all_data:
                seen_pairs.add((index[i], traj_index[i]))

        for idx in id2score:
            if len(id2score[idx]) == 1:
                id2mean[idx] = torch.tensor(0.0)
                id2std[idx] = torch.tensor(1.0)
            elif len(id2score[idx]) > 1:
                scores_tensor = torch.stack(id2score[idx])
                id2mean[idx] = torch.mean(scores_tensor)
                id2std[idx] = torch.std(scores_tensor)
            else:
                raise ValueError(f"no score in prompt index: {idx}")

        for i in range(bsz):
            if norm_adv_by_std_in_grpo:
                scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
            else:
                scores[i] = scores[i] - id2mean[index[i]]
```
Copilot AI · Jan 21, 2026
The function accepts index as np.ndarray but uses it directly to index into dictionaries (lines 90, 108, 110). In Python dictionaries, NumPy array elements may not hash correctly depending on their dtype. If index contains NumPy scalars (e.g., np.int64), this could cause issues.

Consider converting array elements to Python native types when using them as dictionary keys:

`idx_key = int(index[i])`
`id2score[idx_key].append(scores[i])`

Or document that index must contain hashable types that work as dictionary keys.
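For concreteness, the suggested conversion applied inside the deduplication loop might look like the sketch below; it assumes the IDs are integer-like, it is not committed code, and the later `id2mean`/`id2std` lookups would need the same converted keys:

```python
        for i in range(bsz):
            # Normalize keys to plain Python ints so dict/set membership is
            # independent of the NumPy dtype of `index` and `traj_index`.
            idx_key = int(index[i])
            traj_key = None if traj_index is None else int(traj_index[i])
            if traj_key is not None and (idx_key, traj_key) in seen_pairs:
                continue
            id2score[idx_key].append(scores[i])
            if traj_key is not None and not compute_mean_std_cross_all_data:
                seen_pairs.add((idx_key, traj_key))
```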
```python
        if not compute_mean_std_cross_all_data:
            assert self.config.algorithm.adv_estimator == AdvantageEstimator.GRPO, (
                f"compute_mean_std_cross_all_data=False is only supported for GRPO, "
                f"got {self.config.algorithm.adv_estimator}"
            )
```
Copilot AI · Jan 21, 2026
The assertion on lines 432-435 only checks when compute_mean_std_cross_all_data=False, but the new GRPO implementation is used for ALL GRPO cases (line 438 condition). This means when compute_mean_std_cross_all_data=True with a non-GRPO estimator, the assertion is never checked, but the code would still go through the else branch at line 452.
While this is not necessarily incorrect (the else branch handles non-GRPO cases properly), the control flow could be clearer. Consider restructuring to make the relationship between the flag and the GRPO check more explicit, or add a comment explaining why the assertion only needs to check the False case.
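For example, the check could be hoisted into an explicit branch like the sketch below; the variable names and the exact call-site arguments are assumptions, since the surrounding trainer code is not shown in this excerpt:

```python
        is_grpo = self.config.algorithm.adv_estimator == AdvantageEstimator.GRPO

        if not compute_mean_std_cross_all_data and not is_grpo:
            raise ValueError(
                "compute_mean_std_cross_all_data=False is only supported for GRPO, "
                f"got {self.config.algorithm.adv_estimator}"
            )

        if is_grpo:
            # Trajectory-aware GRPO path added by this PR.
            advantages, returns = compute_grpo_outcome_advantage(
                token_level_rewards=token_level_rewards,
                response_mask=response_mask,
                index=index,
                traj_index=traj_index,
                norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
                compute_mean_std_cross_all_data=compute_mean_std_cross_all_data,
            )
        else:
            # Non-GRPO estimators fall through to VeRL's stock advantage computation.
            ...
```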
Problem
Agent-lightning inherits VeRL's default advantage estimation, which assumes each batch sample is independent. In multi-turn scenarios, this causes turn-level bias: trajectories with more turns contribute more samples to the baseline statistics (mean/std), leading to biased advantage estimation and inefficient optimization. For instance, if each turn sample carries its trajectory's outcome reward, a group with trajectory A (reward 1, three turns) and trajectory B (reward 0, one turn) yields a baseline mean of 0.75 instead of the trajectory-level mean of 0.5.
Solution
Implements trajectory-level deduplication using `(data_id, rollout_id)` pairs. Set `algorithm.compute_mean_std_cross_all_data=False` to ensure each trajectory is counted only once when computing baselines.

In `agentlightning.verl.trainer`, we re-implement `compute_grpo_outcome_advantage` to integrate the new trajectory-level deduplication logic while keeping the dependency on VeRL minimal.

Example Configuration
Control the normalization behavior via the `compute_mean_std_cross_all_data` parameter (a sketch of setting it follows this list):
- `compute_mean_std_cross_all_data=True` (default): cross-all-data normalization, more stable but still counts each turn
- `compute_mean_std_cross_all_data=False`: trajectory-level normalization, where each trajectory is counted only once, eliminating the bias
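As an illustration, the flag could be set programmatically via an OmegaConf-style config; the exact config schema here is an assumption based on the `algorithm.*` path referenced above, not taken verbatim from the PR:

```python
from omegaconf import OmegaConf

# Hypothetical algorithm section of the trainer config.
config = OmegaConf.create({
    "algorithm": {
        "adv_estimator": "grpo",
        # Count each (data_id, rollout_id) trajectory only once in baseline stats.
        "compute_mean_std_cross_all_data": False,
    }
})
print(OmegaConf.to_yaml(config))
```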
Affected algorithms: currently only GRPO is supported.
Files modified:
- `agentlightning/verl/trainer.py`: add `compute_grpo_outcome_advantage`