
[JAX] Fix get_seqlens_and_offsets() to accept vmapped seg ids and non vmapped seg offsets #2692

Open
KshitijLakhani wants to merge 12 commits into NVIDIA:main from KshitijLakhani:klakhani/fix/vmap-get-seg-ids-pos

Conversation


@KshitijLakhani KshitijLakhani commented Feb 19, 2026

Description

What is the bug ?

TE provides a convenience function from_segment_ids_and_pos() which lets users pass only segment ids; the function returns a SequenceDescriptor containing the passed segment ids and internally generated segment pos.

As mentioned in issue #2685, if a user vmaps a function forward() that i) accepts q, k, v, and segment ids, ii) calls from_segment_ids_and_pos(), and iii) then calls DPA(), JAX sees the segment ids as vmapped and adds an extra leading dimension (e.g. (1, 2, 128)), whereas the segment offsets are not given a leading dimension (e.g. (2, 128)). The resulting shape mismatch between seg ids and seg pos triggers the assert in the FusedAttn primitive impl(), as described in issue #2685.
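The reported shapes can be illustrated with a minimal sketch (the shapes are taken from the report; this only mirrors the mismatch, not the TE internals):

```python
import jax.numpy as jnp

# Under vmap, the user-supplied segment ids gain a leading batch dimension,
# while the segment pos generated inside from_segment_ids_and_pos() do not.
segment_ids = jnp.zeros((1, 2, 128), dtype=jnp.int32)  # batched by vmap
segment_pos = jnp.zeros((2, 128), dtype=jnp.int32)     # generated internally
assert segment_ids.shape != segment_pos.shape          # what trips the impl() assert
assert segment_ids.shape[1:] == segment_pos.shape      # only a leading dim differs
```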

What is the root cause for the bug ?

On debugging, it can be seen that the shapes start to differ when the batcher is traced for the FusedAttn primitive:

  • segment_ids in the primitive: treated as a vmapped input, hence batched → (1, 2, 128).
  • segment_pos in the primitive: treated as derived within the function, hence not batched → (2, 128).

Fixes #2685

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

There are two possible approaches to solve this:

  1. Ensure that the issue is resolved at the source, i.e., ensure that segment_pos has the same leading batching dims as segment_ids, adding any extra dims in the batcher so that impl() sees matching shapes. Pros: Issue resolved in a "JAX" way and at the source. Cons: Increased memory from expanding the segment pos dims.
  2. Resolve the issue when impl() is called, i.e., accommodate mismatched seg id and seg pos dims when generating the seqlens and offsets. Pros: No extra memory needed as no dims are expanded. Cons: Not "truly" solved at the source.
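A hedged sketch of the second approach (helper names here are illustrative, not the actual TE code): flatten the extra leading batch dims of segment_ids, vmap the per-batch computation with segment_pos broadcast via in_axes=None, then restore the batch shape on the outputs.

```python
import jax
import jax.numpy as jnp

def _seqlens_one_batch(seg_ids, seg_pos):
    # Toy stand-in for the per-batch seqlens computation:
    # count non-padding tokens per row.
    return jnp.sum(seg_ids != 0, axis=-1)

def get_seqlens(segment_ids, segment_pos):
    # Number of extra leading batch dims that vmap added to segment_ids only.
    extra = segment_ids.ndim - segment_pos.ndim
    if extra == 0:
        return _seqlens_one_batch(segment_ids, segment_pos)
    batch_shape = segment_ids.shape[:extra]
    # Flatten the extra batch dims into one mapped axis.
    flat = segment_ids.reshape((-1,) + segment_pos.shape)
    # in_axes=(0, None): map over flattened batches, broadcast segment_pos.
    out = jax.vmap(_seqlens_one_batch, in_axes=(0, None))(flat, segment_pos)
    # Restore the original leading batch dims.
    return out.reshape(batch_shape + out.shape[1:])

seg_ids = jnp.ones((1, 2, 128), dtype=jnp.int32)  # vmapped: extra leading dim
seg_pos = jnp.zeros((2, 128), dtype=jnp.int32)    # derived: no leading dim
print(get_seqlens(seg_ids, seg_pos).shape)  # (1, 2)
```

No broadcast copy of segment_pos is materialized, which is the memory advantage over approach 1.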

The second approach is chosen here as it is more memory-efficient. After this PR merges, the end user can wrap TE API calls in vmap without worrying about batching inside TE.
Accommodate for

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

… the TE constructed segment pos are not thereby causing mismatches in impl()

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani self-assigned this Feb 19, 2026
@KshitijLakhani
Collaborator Author

/te-ci jax L0 L1 L2

@KshitijLakhani KshitijLakhani marked this pull request as ready for review February 20, 2026 06:54
@greptile-apps
Contributor

greptile-apps bot commented Feb 20, 2026

Greptile Summary

Fixed shape mismatch bug when using from_segment_ids_and_pos() inside vmapped functions. The issue occurred because JAX added leading batch dimensions to user-provided segment_ids but not to internally-generated segment_pos, causing FusedAttn primitive assertions to fail.

Key changes:

  • Modified get_seqlens_and_offsets() in SequenceDescriptor to detect and handle extra leading batch dims on segment_ids
  • When segment_ids has more dimensions than segment_pos, the code now flattens extra batch dims, vmaps the seqlens/offsets computation with segment_pos broadcast, then reshapes outputs back
  • Replaced strict shape equality assertions with more flexible validation that allows segment_ids to have additional leading dims
  • Updated comments in FusedAttn primitive batchers to document that segment_ids/segment_pos may have different batch dimensions

Implementation notes:

  • The vmap approach with in_axes=(0, 0, None, None) correctly broadcasts segment_pos across the batch dimension
  • JAX will raise clear errors if q and kv segment_ids have incompatible batch sizes, which is appropriate for catching user errors
  • Changes only affect THD layout; BSHD path remains unchanged
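The broadcast and error behavior noted above can be checked with a toy function (placeholders, not TE internals): in_axes=(0, 0, None, None) maps the two segment-id operands and broadcasts the positions, and JAX rejects mismatched sizes on the mapped axis with a ValueError.

```python
import jax
import jax.numpy as jnp

def f(seg_ids_q, seg_ids_kv, pos_q, pos_kv):
    # Toy reduction standing in for the seqlens/offsets computation.
    return jnp.sum(seg_ids_q * pos_q) + jnp.sum(seg_ids_kv * pos_kv)

ids_q = jnp.ones((4, 2, 128), dtype=jnp.int32)
ids_kv = jnp.ones((4, 2, 128), dtype=jnp.int32)
pos = jnp.broadcast_to(jnp.arange(128), (2, 128))

# Positions are broadcast across the mapped axis via in_axes=None.
out = jax.vmap(f, in_axes=(0, 0, None, None))(ids_q, ids_kv, pos, pos)
print(out.shape)  # (4,)

# Mismatched batch sizes on the mapped axis are caught by vmap itself.
err = None
try:
    bad = jnp.ones((3, 2, 128), dtype=jnp.int32)
    jax.vmap(f, in_axes=(0, 0, None, None))(ids_q, bad, pos, pos)
except ValueError as e:
    err = e
assert err is not None
```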

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The implementation correctly solves the vmap shape mismatch issue with a clean approach: detecting extra batch dims on segment_ids, flattening them, vmapping the computation with broadcasted segment_pos, and reshaping outputs. The validation logic appropriately checks shape compatibility. Edge cases with mismatched batch sizes are caught by JAX's vmap error handling. Changes are scoped to THD layout only, minimizing risk to other code paths.
  • No files require special attention

Important Files Changed

Filename Overview
transformer_engine/jax/attention.py Modified get_seqlens_and_offsets() to handle vmapped segment_ids with broadcasted segment_pos by flattening extra batch dims, vmapping the computation, and reshaping outputs
transformer_engine/jax/cpp_extensions/attention.py Updated batcher comments to clarify that segment_ids/segment_pos may have different batch dims and conversion is handled in attention.py

Last reviewed commit: 40e4d28

Contributor

@greptile-apps greptile-apps bot left a comment

2 files reviewed, 3 comments

Collaborator

@jberchtold-nvidia jberchtold-nvidia left a comment


LGTM, thanks!

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

1 file reviewed, 1 comment

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

1 file reviewed, 2 comments

for _ in range(leading_bdim):
    expanded = lax.expand_dims(expanded, (0,))
batched_args_list[seg_pos_idx] = jnp.broadcast_to(expanded, target_shape)
updated_batch_dims[seg_pos_idx] = 0
Contributor


consider using seg_id_bdim instead of hardcoding 0 for consistency, even though check_valid_batch_dims ensures it's always 0

Suggested change
updated_batch_dims[seg_pos_idx] = 0
updated_batch_dims[seg_pos_idx] = seg_id_bdim

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

for _ in range(leading_bdim):
    expanded = lax.expand_dims(expanded, (0,))
batched_args_list[seg_pos_idx] = jnp.broadcast_to(expanded, target_shape)
updated_batch_dims[seg_pos_idx] = 0
Contributor


consider using seg_id_bdim instead of hardcoding 0 for consistency, even though check_valid_batch_dims ensures it's always 0

Suggested change
updated_batch_dims[seg_pos_idx] = 0
updated_batch_dims[seg_pos_idx] = seg_id_bdim


…rts.

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
… the seqlens and offsets for fused attn

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Comment on lines 739 to 741
# assert flat_batch_q == flat_batch_kv, (
#     f"segment_ids batch size mismatch: {batch_shape_q} vs {batch_shape_kv}"
# )
Contributor


commented assertion could lead to unclear error if q and kv have mismatched batch sizes. vmap would fail but with a generic JAX error. consider uncommenting or adding a comment explaining why validation isn't needed


…ed to get_seqlens_and_offsets()

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/fix/vmap-get-seg-ids-pos branch from a7c398c to 395ac54 on February 27, 2026 18:36
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/fix/vmap-get-seg-ids-pos branch from 386a633 to 693ba65 on February 27, 2026 19:19
@KshitijLakhani KshitijLakhani changed the title [JAX] Fix batcher in FusedAttn primitive for when seg ids bdims != seg pos bdims [JAX] Fix get_seqlens_and_offsets() to accept vmapped seg ids and non vmapped seg offsets Feb 27, 2026
@KshitijLakhani
Collaborator Author

/te-ci jax L0 L1 L2

@KshitijLakhani
Collaborator Author

CI passes. The only failure was due to HF requests in the A100 L2 test; rerunning passed it.



Successfully merging this pull request may close these issues.

JAX vmap issue with TE Attention
