
Optimized rocm specific multicast transpose kernel #586

Open
alextmagro wants to merge 1 commit into dev from multicasttranspose_opt

Conversation

@alextmagro
Contributor

Optimizes the multi_cast_transpose kernel for ROCm.

Benchmark Results

Qwen has 128 experts and DS has 256 experts. Benchmarked with shapes derived from MBS = {1, 2, 4}.

Balanced Experts

| Benchmark | Base µs | Base TiB/s | Opt µs | Opt TiB/s | % Peak | Speedup |
|---|---|---|---|---|---|---|
| 128exp/4096cols/MBS1 | 668 | 0.75 | 140 | 3.58 | 49.2% | 4.77x |
| 128exp/4096cols/MBS2 | 1147 | 0.86 | 248 | 3.99 | 54.8% | 4.63x |
| 128exp/4096cols/MBS4 | 2388 | 0.82 | 478 | 4.12 | 56.6% | 5.00x |
| 128exp/1536cols/MBS1 | 281 | 0.67 | 54.8 | 3.44 | 47.3% | 5.13x |
| 128exp/1536cols/MBS2 | 592 | 0.63 | 101 | 3.67 | 50.4% | 5.86x |
| 128exp/1536cols/MBS4 | 934 | 0.79 | 191 | 3.86 | 53.1% | 4.89x |
| 128exp/3072cols/MBS1 | 548 | 0.69 | 102 | 3.71 | 51.0% | 5.37x |
| 128exp/3072cols/MBS2 | 904 | 0.82 | 204 | 3.65 | 50.2% | 4.43x |
| 128exp/3072cols/MBS4 | 1763 | 0.84 | 371 | 3.98 | 54.7% | 4.75x |
| 256exp/7168cols/MBS1 | 1621 | 0.56 | 289 | 3.13 | 43.0% | 5.61x |
| 256exp/7168cols/MBS2 | 2251 | 0.78 | 519 | 3.39 | 46.6% | 4.34x |
| 256exp/7168cols/MBS4 | 4289 | 0.81 | 861 | 4.03 | 55.4% | 4.98x |
| 256exp/2048cols/MBS1 | 489 | 0.53 | 83.5 | 3.08 | 42.3% | 5.86x |
| 256exp/2048cols/MBS2 | 765 | 0.66 | 147 | 3.41 | 46.9% | 5.20x |
| 256exp/2048cols/MBS4 | 1251 | 0.79 | 267 | 3.71 | 51.0% | 4.69x |
| 256exp/4096cols/MBS1 | 949 | 0.55 | 168 | 3.07 | 42.2% | 5.65x |
| 256exp/4096cols/MBS2 | 1235 | 0.81 | 315 | 3.20 | 44.0% | 3.92x |
| 256exp/4096cols/MBS4 | 2417 | 0.82 | 533 | 3.72 | 51.1% | 4.53x |

Skewed routing

| Benchmark | Base µs | Base TiB/s | Opt µs | Opt TiB/s | % Peak | Speedup |
|---|---|---|---|---|---|---|
| 128exp/4096cols/MBS1 | 810 | 0.62 | 137 | 3.69 | 50.7% | 5.91x |
| 128exp/4096cols/MBS2 | 1335 | 0.74 | 273 | 3.63 | 49.9% | 4.89x |
| 128exp/4096cols/MBS4 | 2660 | 0.74 | 513 | 3.83 | 52.6% | 5.19x |
| 128exp/1536cols/MBS1 | 349 | 0.54 | 53.2 | 3.55 | 48.8% | 6.56x |
| 128exp/1536cols/MBS2 | 635 | 0.59 | 87.6 | 4.24 | 58.3% | 7.25x |
| 128exp/1536cols/MBS4 | 1108 | 0.67 | 202 | 3.66 | 50.3% | 5.49x |
| 128exp/3072cols/MBS1 | 644 | 0.59 | 94.1 | 4.01 | 55.1% | 6.84x |
| 128exp/3072cols/MBS2 | 1136 | 0.65 | 210 | 3.54 | 48.7% | 5.41x |
| 128exp/3072cols/MBS4 | 2049 | 0.72 | 389 | 3.79 | 52.1% | 5.27x |
| 256exp/7168cols/MBS1 | 1676 | 0.54 | 310 | 2.91 | 40.0% | 5.41x |
| 256exp/7168cols/MBS2 | 2732 | 0.64 | 496 | 3.55 | 48.8% | 5.51x |
| 256exp/7168cols/MBS4 | 4674 | 0.74 | 928 | 3.74 | 51.4% | 5.04x |
| 256exp/2048cols/MBS1 | 635 | 0.41 | 87.9 | 2.93 | 40.3% | 7.22x |
| 256exp/2048cols/MBS2 | 933 | 0.54 | 141 | 3.56 | 48.9% | 6.62x |
| 256exp/2048cols/MBS4 | 1560 | 0.64 | 272 | 3.65 | 50.2% | 5.74x |
| 256exp/4096cols/MBS1 | 1039 | 0.50 | 185 | 2.79 | 38.3% | 5.62x |
| 256exp/4096cols/MBS2 | 1644 | 0.61 | 292 | 3.44 | 47.3% | 5.63x |
| 256exp/4096cols/MBS4 | 2728 | 0.73 | 534 | 3.71 | 51.0% | 5.11x |

Performance Summary

  • Average speedup (balanced): 5.0x
  • Average speedup (skewed): 5.8x
  • Average % peak (balanced): 49.4%
  • Average % peak (skewed): 49.0%

Change Summary

  • 512 threads/block (WPT=16) vs upstream's 128 (WPT=4) — 4x more threads for latency hiding
  • Non-temporal stores for both outputs via NTVec — upstream uses regular Vec::store_to which pollutes L2 (CDNA4 L2 is write-allocate)
  • Packed FP8 intrinsics via rocm_pack_4xfloat8 — 2 v_cvt_pk_fp8_f32 per 4 values vs upstream's scalar OType(scale * x) casts
  • Fused amax into the pack loop with tree reduction — upstream has a separate serial amax pass
  • 128-bit vectorized loads (LOAD_SZ=16 for BF16, NVEC_IN=8) — upstream uses 64-bit (LOAD_SZ=8, NVEC_IN=4)
  • kMCTMaxTensors=256 — single kernel launch for up to 256 experts. Upstream limited to 64 (CUDA 4KB kernarg limit doesn't apply on AMD)
  • Edge-tile bounds checking — handles any row count with pad-16 alignment. Interior tiles run the fast path, edge tiles are predicated per-row
  • Binary search for tensor lookup — O(log N) vs upstream's O(N) linear scan
  • rocm_block_reduce_max with rocm_atomicMaxFloat — uses atomicMax on int-reinterpreted float (single instruction) vs upstream's CAS loop
  • Column-major tile ordering (tile_m = local_bid % tiles_m) for L2 input locality
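
The O(log N) tensor lookup can be sketched on the host. This is an illustrative reconstruction, not the PR's code: `row_offsets` and `find_tensor` are hypothetical names, and `row_offsets` is assumed to be the exclusive prefix sum of per-tensor row counts, so a binary search maps a global row index to the tensor that owns it:

```cpp
// Hypothetical sketch of the O(log N) tensor lookup. row_offsets holds the
// exclusive prefix sum of per-tensor row counts, with a final entry equal
// to the total row count, e.g. {0, 4, 10, 12} for tensors of 4, 6, 2 rows.
int find_tensor(const int* row_offsets, int num_tensors, int row) {
    int lo = 0, hi = num_tensors;  // invariant: answer lies in [lo, hi)
    while (lo + 1 < hi) {
        int mid = (lo + hi) / 2;
        if (row_offsets[mid] <= row)
            lo = mid;              // row is at or past tensor mid's start
        else
            hi = mid;              // row is before tensor mid's start
    }
    return lo;  // row_offsets[lo] <= row < row_offsets[lo + 1]
}
```

Each thread block would run this once to locate its tensor, replacing upstream's linear scan over up to 256 entries.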
Rejected/skipped optimizations
  • WPT=8 — -25-33%, VGPR pressure from local_t[8][4]
  • WPT=32 — -10% for MBS1, 2 blocks/CU limit
  • Wave64 — -12-24%, 2x smem + narrower stores + 8 iterations
  • Cached stores for output_c — neutral
  • IS_EDGE template split — -20-26%, 2x launch overhead
  • STORE_SZ grouping — -10% Qwen3, more launches hurt
  • Persistent kernel — moot after kMCTMaxTensors=256
  • Inline ASM for FP8 pack — -2%, compiler already optimal
  • ds_read_b64_tr_b8 — 128 LDS calls vs 4 syncthreads
  • Row/column cascade — edge-tile bounds checking sufficient
  • Precompute tensor lookup in smem — binary search already <1%
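
The int-reinterpretation trick behind rocm_atomicMaxFloat can be illustrated on the host (helper names below are ours, not from the PR): for non-negative IEEE-754 floats, the raw bit patterns order the same way as the values, so a single integer atomicMax replaces upstream's compare-and-swap loop. The amax reduction operates on absolute values, so the non-negative precondition holds.

```cpp
#include <cstdint>
#include <cstring>

// Reinterpret a float's bits as a signed 32-bit integer (and back).
static int32_t float_as_int(float f) {
    int32_t i;
    std::memcpy(&i, &f, sizeof i);
    return i;
}

static float int_as_float(int32_t i) {
    float f;
    std::memcpy(&f, &i, sizeof f);
    return f;
}

// Max of two NON-NEGATIVE floats computed entirely in integer space,
// mirroring what atomicMax on an int-reinterpreted float address does
// on the device in a single instruction.
static float max_via_int(float a, float b) {
    int32_t ia = float_as_int(a), ib = float_as_int(b);
    return int_as_float(ia > ib ? ia : ib);
}
```

Negative floats would break the ordering (their bit patterns are sign-magnitude, not two's complement), which is why this shortcut is safe only for the |x| values fed into amax.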

@alextmagro alextmagro added ci-level 1 CI test level 1 ci-level 3 CI test level 3 and removed ci-level 1 CI test level 1 labels May 14, 2026
```cpp
HIP_CHECK(hipEventCreate(&stop));

nvte_multi_cast_transpose(num_experts, nvte_in.data(), nvte_out.data(), stream);
HIP_CHECK(hipStreamSynchronize(stream));
```
Collaborator
Is synchronize needed here?


Labels

ci-level 3 CI test level 3

2 participants