[JAX] Fix get_seqlens_and_offsets() to accept vmapped seg ids and non vmapped seg offsets#2692
KshitijLakhani wants to merge 12 commits into NVIDIA:main from
Conversation
… the TE constructed segment pos are not thereby causing mismatches in impl() Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
/te-ci jax L0 L1 L2
Greptile Summary
Fixed shape mismatch bug when using vmapped segment ids with non-vmapped segment offsets.
Key changes:
Implementation notes:
Confidence Score: 5/5
Important Files Changed
Last reviewed commit: 40e4d28
jberchtold-nvidia left a comment:
LGTM, thanks!
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Force-pushed da19f26 to 35d6d0f
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
```python
for _ in range(leading_bdim):
    expanded = lax.expand_dims(expanded, (0,))
batched_args_list[seg_pos_idx] = jnp.broadcast_to(expanded, target_shape)
updated_batch_dims[seg_pos_idx] = 0
```
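The expand-and-broadcast step above can be sketched in isolation (shapes taken from the issue; the variables here are stand-ins for the real batcher state, not TE internals):

```python
import jax.numpy as jnp
from jax import lax

# Stand-ins for the batcher state: segment ids already carry the vmap-added
# leading batch dim, segment pos does not.
seg_ids = jnp.zeros((1, 2, 128), dtype=jnp.int32)             # batched: (1, 2, 128)
seg_pos = jnp.tile(jnp.arange(128, dtype=jnp.int32), (2, 1))  # unbatched: (2, 128)

# Add one size-1 axis per missing leading dim, then broadcast to the batched
# target shape so the primitive impl() sees matching shapes.
leading_bdim = seg_ids.ndim - seg_pos.ndim
expanded = seg_pos
for _ in range(leading_bdim):
    expanded = lax.expand_dims(expanded, (0,))
seg_pos_batched = jnp.broadcast_to(expanded, seg_ids.shape)

print(seg_pos_batched.shape)  # (1, 2, 128)
```

The broadcast itself allocates no new data until the array is materialized, which is why the memory cost noted in the Changes section only appears when the expanded seg pos is actually used.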
consider using seg_id_bdim instead of hardcoding 0 for consistency, even though check_valid_batch_dims ensures it's always 0
```diff
- updated_batch_dims[seg_pos_idx] = 0
+ updated_batch_dims[seg_pos_idx] = seg_id_bdim
```
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
…rts. Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
… the seqlens and offsets for fused attn Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
transformer_engine/jax/attention.py
Outdated
```python
# assert flat_batch_q == flat_batch_kv, (
#     f"segment_ids batch size mismatch: {batch_shape_q} vs {batch_shape_kv}"
# )
```
The commented-out assertion could lead to an unclear error if q and kv have mismatched batch sizes: vmap would still fail, but with a generic JAX error. Consider uncommenting it, or adding a comment explaining why validation isn't needed.
…ed to get_seqlens_and_offsets() Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Force-pushed a7c398c to 395ac54
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Force-pushed 386a633 to 693ba65
for more information, see https://pre-commit.ci
/te-ci jax L0 L1 L2
CI passes. The only failure is due to HF requests for the A100 L2 test.
Description
What is the bug?
TE provides a convenience function `from_segment_ids_and_pos()`, which allows users to pass only segment ids; it returns a `SequenceDescriptor` with internally generated segment pos and the passed segment ids.

As mentioned in Issue #2685, if a user vmaps a function forward() which i) accepts q, k, v, and segment ids, ii) calls `from_segment_ids_and_pos()`, and iii) then calls `DPA()`, JAX sees the segment ids as vmapped, so an extra leading dimension is added (e.g. (1, 2, 128)), whereas the segment offsets are not given a leading dimension (e.g. (2, 128)). This triggers the assert in the FusedAttn primitive impl() due to the shape mismatch between seg ids and seg pos, as mentioned in issue #2685.

What is the root cause for the bug?
On debugging, it can be seen that the shapes start differing when the batcher is traced for the FusedAttn primitive:

- `segment_ids` in the primitive is treated as a vmapped input and hence batched → (1, 2, 128).
- `segment_pos` in the primitive is treated as derived within the function and hence not batched → (2, 128).

Fixes #2685
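The shape discrepancy can be mimicked with plain `jax.vmap` (shapes from the issue; this only illustrates how vmap stacks mapped inputs, not the FusedAttn machinery itself):

```python
import jax
import jax.numpy as jnp

# Derived inside the function in the real code, so it never gains a batch dim.
seg_pos = jnp.tile(jnp.arange(128, dtype=jnp.int32), (2, 1))  # (2, 128)

def f(seg_ids):
    # Per-example shape inside the vmapped function is (2, 128); the stacked
    # output regains the leading batch dim, like the batched primitive input.
    return seg_ids

seg_ids_batched = jax.vmap(f)(jnp.ones((1, 2, 128), jnp.int32))
print(seg_ids_batched.shape, seg_pos.shape)  # (1, 2, 128) (2, 128)
```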
Type of change
Changes
There are two possible approaches to solve this:

- Ensure `segment_pos` has the same leading batching dims as `segment_ids`: add any additional dims in the batcher so that when impl() sees the shapes, they are the same. Pros: the issue is resolved in a "JAX" way and at the source. Cons: increased memory from expanding the seg pos dims.

The second approach is chosen here as it is more optimized. After this PR merges, the end user can vmap-wrap the TE API calls without worrying about the batching in TE.
Accommodate for
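As a rough illustration of the chosen direction (accepting vmapped seg ids without requiring batched offsets), here is a hypothetical sketch; the real `get_seqlens_and_offsets()` in TE differs in signature and details:

```python
import jax.numpy as jnp

# Hypothetical sketch: derive per-row sequence lengths from segment ids that
# may carry extra vmap-added leading dims, without touching segment offsets.
def get_seqlens_sketch(segment_ids, base_ndim=2):
    # Collapse any vmap-added leading dims into the batch dim.
    if segment_ids.ndim > base_ndim:
        segment_ids = segment_ids.reshape((-1, segment_ids.shape[-1]))
    # Treat non-zero ids as valid (non-padding) tokens.
    return jnp.sum(segment_ids > 0, axis=-1)

ids = jnp.ones((1, 2, 128), dtype=jnp.int32)  # vmapped: extra leading dim
print(get_seqlens_sketch(ids))  # [128 128]
```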
Checklist: