[WIP] Support dpo/grpo simple padding_free #181
tastelikefeet wants to merge 3 commits into modelscope:main from
Conversation
Code Review
This pull request centralizes the logic for unpacking packed sequences (padding-free mode) into the InputProcessor class, moving it out of specific loss implementations like GRPO. It introduces a canonical method to detect packing and unpack tensors such as log-probabilities and labels into a per-sequence batch format. These changes are integrated into both Megatron and Transformers sequence parallel strategies. The review feedback identifies several improvement opportunities: ensuring the boundary detection logic explicitly includes the first sequence, relaxing the packing detection heuristic to support sequences of length one, and optimizing the unpacking process for better performance and consistency.
```python
pos = pos[0]  # mrope
pos_flat = pos.view(-1)

boundaries = (pos_flat == 0).nonzero(as_tuple=True)[0].unique(sorted=True)
```
The boundary detection logic assumes that the first sequence in the packed tensor always starts at index 0 and that index 0 is always marked with position_id == 0. While generally true for PackingDataset, if the tensor has been sliced or if the first token has a non-zero position ID for any reason, the tokens before the first 0 will be lost. It is safer to explicitly include index 0 as a boundary.
```diff
-boundaries = (pos_flat == 0).nonzero(as_tuple=True)[0].unique(sorted=True)
+zero_indices = (pos_flat == 0).nonzero(as_tuple=True)[0]
+boundaries = torch.cat([pos_flat.new_tensor([0]), zero_indices]).unique(sorted=True)
```
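A minimal standalone sketch of the suggested boundary computation, assuming torch tensors; the helper name `find_boundaries` is illustrative, not part of the PR's API:

```python
import torch

def find_boundaries(pos_flat: torch.Tensor) -> torch.Tensor:
    """Return the start index of every packed sequence.

    Explicitly includes index 0 so that tokens before the first
    position_id == 0 are never dropped (e.g. on a sliced tensor
    whose first token has a non-zero position id).
    """
    zero_indices = (pos_flat == 0).nonzero(as_tuple=True)[0]
    return torch.cat([pos_flat.new_tensor([0]), zero_indices]).unique(sorted=True)

# A sliced pack whose first token has position_id 3: without the
# explicit 0, the first two tokens would belong to no sequence.
pos = torch.tensor([3, 4, 0, 1, 2, 0, 1])
print(find_boundaries(pos).tolist())  # [0, 2, 5]
```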
```python
    return False
for i in range(pos.shape[0]):
    row = pos[i]
    if int((row == 0).sum()) > 1 and int((row == 1).sum()) > 1:
```
The heuristic for detecting packed sequences is too restrictive. Checking for `(row == 1).sum() > 1` means that at least two sequences in the pack must have a length of at least 2. If a pack contains multiple sequences where some have length 1 (e.g., `[0, 1, 2, 0]`), this check will return `False`, and the sequences won't be unpacked, leading to incorrect loss calculation in GRPO. Checking for multiple `0` resets is sufficient to identify packing.
```diff
-if int((row == 0).sum()) > 1 and int((row == 1).sum()) > 1:
+if (row == 0).sum().item() > 1:
```
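To see why the relaxed check matters, a small standalone sketch (the helper name `is_packed` is hypothetical, not the PR's actual method):

```python
import torch

def is_packed(pos: torch.Tensor) -> bool:
    """Detect packing by counting position-id resets (== 0) per row.

    A single row holding more than one 0 must contain multiple
    sequences, even if some of them have length 1.
    """
    for row in pos:
        if (row == 0).sum().item() > 1:
            return True
    return False

# Two sequences packed into one row; the second has length 1.
row = torch.tensor([[0, 1, 2, 0]])
print(is_packed(row))  # True
# The old check additionally required (row == 1).sum() > 1, which is
# False here, so this pack would have been silently left packed.
```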
```python
from copy import copy
unpacked_logps, unpacked_labels = self._unpack_by_position_ids(
    position_ids, logps, labels, padding_values=[0, -100])

inputs = copy(inputs)
outputs = copy(outputs)
outputs['logps'] = unpacked_logps
inputs['labels'] = unpacked_labels
return inputs, outputs
```
There are several improvements possible in this block:
- Performance: Importing `copy` inside a method called every forward pass is inefficient.
- Idiomatic Python: For dictionaries, `dict.copy()` is preferred over `copy.copy(dict)`.
- Consistency: It is better to also unpack `position_ids` and update it in the `inputs` dictionary. This ensures that all sequence-length tensors in the batch have consistent shapes `[num_seqs, max_len]`, preventing potential shape mismatches in downstream metrics or custom loss functions.
```diff
-from copy import copy
-unpacked_logps, unpacked_labels = self._unpack_by_position_ids(
-    position_ids, logps, labels, padding_values=[0, -100])
-inputs = copy(inputs)
-outputs = copy(outputs)
-outputs['logps'] = unpacked_logps
-inputs['labels'] = unpacked_labels
-return inputs, outputs
+unpacked_logps, unpacked_labels, unpacked_pos = self._unpack_by_position_ids(
+    position_ids, logps, labels, position_ids, padding_values=[0, -100, -1])
+inputs = inputs.copy()
+outputs = outputs.copy()
+outputs['logps'] = unpacked_logps
+inputs['labels'] = unpacked_labels
+inputs['position_ids'] = unpacked_pos
+return inputs, outputs
```
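As a rough illustration of what unpacking into a consistent `[num_seqs, max_len]` shape could look like, here is a self-contained sketch; the function name and signature are assumptions and do not reflect the PR's actual `_unpack_by_position_ids` implementation:

```python
import torch

def unpack_by_position_ids(pos_flat, tensors, padding_values):
    """Split flat packed tensors at position-id resets and right-pad
    each resulting sequence, yielding one [num_seqs, max_len] tensor
    per input tensor (hypothetical sketch, not the PR's API).
    """
    # Sequence starts: every position_id == 0, plus an explicit index 0.
    zero_idx = (pos_flat == 0).nonzero(as_tuple=True)[0]
    starts = torch.cat([pos_flat.new_tensor([0]), zero_idx]).unique(sorted=True)
    ends = torch.cat([starts[1:], starts.new_tensor([pos_flat.numel()])])
    out = []
    for t, pad in zip(tensors, padding_values):
        chunks = [t[s:e] for s, e in zip(starts.tolist(), ends.tolist())]
        out.append(torch.nn.utils.rnn.pad_sequence(
            chunks, batch_first=True, padding_value=pad))
    return out

pos = torch.tensor([0, 1, 2, 0, 1])
labels = torch.tensor([10, 11, 12, 20, 21])
unpacked_labels, unpacked_pos = unpack_by_position_ids(
    pos, [labels, pos], padding_values=[-100, -1])
print(unpacked_labels.tolist())  # [[10, 11, 12], [20, 21, -100]]
print(unpacked_pos.tolist())     # [[0, 1, 2], [0, 1, -1]]
```

Passing `position_ids` through the same helper (padded with -1) is what keeps every per-sequence tensor at the same `[num_seqs, max_len]` shape downstream.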
PR type
PR information
Write the detailed information belonging to this PR.
Experiment results
Paste your experiment results here (if needed).