diff --git a/benchmarks/attention/benchmark_rope_thd_full_layer.py b/benchmarks/attention/benchmark_rope_thd_full_layer.py new file mode 100644 index 0000000000..2533dad3ed --- /dev/null +++ b/benchmarks/attention/benchmark_rope_thd_full_layer.py @@ -0,0 +1,361 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Full TransformerLayer benchmark for token-linear THD fused RoPE. + +This benchmark keeps the local packed-token count and RoPE table length fixed +while varying the number of packed THD spans. It compares the old fused RoPE +launch, the new token-linear launch, and the heuristic path on a TE +TransformerLayer using THD input and rotary embeddings. It also measures a +paired RoPE-only operation with the same tensor shape, so the output table can +report both end-to-end layer speedup and the fraction of layer time attributable +to fused RoPE. +""" + +from __future__ import annotations + +import argparse +import csv +import os +from contextlib import contextmanager +from pathlib import Path +from typing import Callable, Iterable + +import torch + +import transformer_engine.pytorch as te +from transformer_engine.pytorch.attention.rope import ( + RotaryPositionEmbedding, + apply_rotary_pos_emb, +) + + +@contextmanager +def env(name: str, value: str | None): + prev = os.environ.get(name) + if value is None: + os.environ.pop(name, None) + else: + os.environ[name] = value + try: + yield + finally: + if prev is None: + os.environ.pop(name, None) + else: + os.environ[name] = prev + + +def build_cu_seqlens(total_tokens: int, n_seqs: int) -> tuple[torch.Tensor, int]: + """Build balanced packed THD cu_seqlens with an exact total token count.""" + per = total_tokens // n_seqs + if per <= 0: + raise ValueError(f"n_seqs={n_seqs} is too large for total_tokens={total_tokens}") + rem = total_tokens - per * n_seqs + lengths = [per + (1 if i < rem else 0) for i in range(n_seqs)] + cu = [0] + 
max_seqlen = 0 + for length in lengths: + cu.append(cu[-1] + length) + max_seqlen = max(max_seqlen, length) + return torch.tensor(cu, dtype=torch.int32), max_seqlen + + +def zero_grads(params: Iterable[torch.Tensor], x: torch.Tensor) -> None: + if x.grad is not None: + x.grad = None + for p in params: + if p.grad is not None: + p.grad = None + + +def time_fwd_bwd(fn: Callable[[], torch.Tensor], warmup: int, iters: int) -> tuple[float, float]: + torch.cuda.synchronize() + for _ in range(warmup): + out = fn() + out.sum().backward() + torch.cuda.synchronize() + + start = torch.cuda.Event(enable_timing=True) + fwd_end = torch.cuda.Event(enable_timing=True) + bwd_end = torch.cuda.Event(enable_timing=True) + fwd_total = 0.0 + full_total = 0.0 + for _ in range(iters): + start.record() + out = fn() + fwd_end.record() + out.sum().backward() + bwd_end.record() + torch.cuda.synchronize() + fwd_total += start.elapsed_time(fwd_end) + full_total += start.elapsed_time(bwd_end) + return fwd_total / iters, full_total / iters + + +def make_layer(args: argparse.Namespace, dtype: torch.dtype) -> te.TransformerLayer: + sigma = 0.02 + + def init_method(tensor: torch.Tensor) -> torch.Tensor: + return torch.nn.init.normal_(tensor, mean=0.0, std=sigma) + + return te.TransformerLayer( + args.hidden_size, + args.ffn_hidden_size, + args.num_heads, + layernorm_epsilon=1e-5, + hidden_dropout=0.0, + attention_dropout=0.0, + init_method=init_method, + output_layer_init_method=init_method, + layer_number=1, + kv_channels=args.head_dim, + self_attn_mask_type="padding_causal", + tp_group=None, + tp_size=1, + params_dtype=dtype, + get_rng_state_tracker=None, + fuse_wgrad_accumulation=False, + seq_length=args.freqs_len, + micro_batch_size=1, + sequence_parallel=False, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + layer_type="encoder", + set_parallel_mode=True, + fuse_qkv_params=True, + zero_centered_gamma=False, + qkv_weight_interleaved=True, + bias=True, + 
attn_input_format="thd", + rotary_pos_interleaved=args.interleaved, + device="cuda", + ).to(dtype=dtype, device="cuda") + + +def main(argv: Iterable[str] | None = None) -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--total-tokens", type=int, default=65536) + parser.add_argument("--freqs-len", type=int, default=65536) + parser.add_argument("--hidden-size", type=int, default=1536) + parser.add_argument("--ffn-hidden-size", type=int, default=6144) + parser.add_argument("--num-heads", type=int, default=12) + parser.add_argument("--dtype", choices=["bf16", "fp16"], default="bf16") + parser.add_argument("--interleaved", action="store_true") + parser.add_argument("--warmup", type=int, default=2) + parser.add_argument("--iters", type=int, default=5) + # n_seqs=50 is intentionally omitted from the default sweep because the + # balanced-span shape has max_seqlen~=1311 and can hit a cuDNN fused-attn + # execution failure unrelated to RoPE on the tested H100 stack. The high-span + # cases below are the issue-relevant regime where RoPE launch waste dominates. 
+ parser.add_argument("--n-seqs", type=int, nargs="+", default=[128, 512, 1024, 2401]) + parser.add_argument("--out-dir", type=Path, default=Path("rope_thd_full_layer_bench")) + args = parser.parse_args(argv) + + if not torch.cuda.is_available(): + raise SystemExit("CUDA is required") + if args.hidden_size % args.num_heads != 0: + raise SystemExit("--hidden-size must be divisible by --num-heads") + args.head_dim = args.hidden_size // args.num_heads + if args.freqs_len < args.total_tokens: + raise SystemExit("--freqs-len should be >= --total-tokens for this long-context benchmark") + + torch.manual_seed(1234) + dtype = {"bf16": torch.bfloat16, "fp16": torch.float16}[args.dtype] + device = torch.device("cuda") + + rotary = RotaryPositionEmbedding(args.head_dim, interleaved=args.interleaved) + freqs = rotary(args.freqs_len).to(device=device) + + args.out_dir.mkdir(parents=True, exist_ok=True) + csv_path = args.out_dir / "rope_thd_full_layer_bench.csv" + fields = [ + "n_seqs", + "regime", + "max_seqlen", + "layer_fwd_ms", + "layer_fwd_bwd_ms", + "layer_bwd_ms", + "rope_pair_fwd_ms", + "rope_pair_fwd_bwd_ms", + "rope_pair_bwd_ms", + "rope_pair_pct_layer", + "layer_speedup_vs_old", + "rope_pair_speedup_vs_old", + ] + rows: list[dict[str, str | int]] = [] + + print( + "# full-layer THD RoPE benchmark: " + f"T={args.total_tokens} freqs_len={args.freqs_len} hidden={args.hidden_size} " + f"ffn={args.ffn_hidden_size} heads={args.num_heads} dtype={args.dtype}", + flush=True, + ) + print( + "# n_seqs regime max_seqlen layer_fwd layer_fwd_bwd rope_pair_fwd_bwd " + "rope_pct layer_speedup", + flush=True, + ) + + for n_seqs in args.n_seqs: + cu_cpu, max_seqlen = build_cu_seqlens(args.total_tokens, n_seqs) + cu = cu_cpu.to(device=device) + x = torch.randn( + args.total_tokens, + args.hidden_size, + dtype=dtype, + device=device, + requires_grad=True, + ) + q = torch.randn( + args.total_tokens, + args.num_heads, + args.head_dim, + dtype=dtype, + device=device, + requires_grad=True, + 
) + k = torch.randn_like(q, requires_grad=True) + layer = make_layer(args, dtype) + layer.train() + params = tuple(layer.parameters()) + + layer_old = None + rope_old = None + + for regime, override in (("old", "0"), ("new", "1"), ("heuristic", None)): + with env("NVTE_FUSED_ROPE_THD_TOKEN_LINEAR", override): + + def layer_fn() -> torch.Tensor: + zero_grads(params, x) + return layer( + x, + rotary_pos_emb=freqs, + cu_seqlens_q=cu, + cu_seqlens_kv=cu, + max_seqlen_q=max_seqlen, + max_seqlen_kv=max_seqlen, + ) + + def rope_pair_fn() -> torch.Tensor: + if q.grad is not None: + q.grad = None + if k.grad is not None: + k.grad = None + q_out = apply_rotary_pos_emb( + q, + freqs, + tensor_format="thd", + fused=True, + cu_seqlens=cu, + interleaved=args.interleaved, + ) + k_out = apply_rotary_pos_emb( + k, + freqs, + tensor_format="thd", + fused=True, + cu_seqlens=cu, + interleaved=args.interleaved, + ) + return q_out + k_out + + layer_fwd, layer_full = time_fwd_bwd(layer_fn, args.warmup, args.iters) + rope_fwd, rope_full = time_fwd_bwd(rope_pair_fn, args.warmup, args.iters) + + if regime == "old": + layer_old = layer_full + rope_old = rope_full + assert layer_old is not None and rope_old is not None + layer_speedup = layer_old / layer_full + rope_speedup = rope_old / rope_full + rope_pct = 100.0 * rope_full / layer_full + rows.append( + { + "n_seqs": n_seqs, + "regime": regime, + "max_seqlen": max_seqlen, + "layer_fwd_ms": f"{layer_fwd:.4f}", + "layer_fwd_bwd_ms": f"{layer_full:.4f}", + "layer_bwd_ms": f"{layer_full - layer_fwd:.4f}", + "rope_pair_fwd_ms": f"{rope_fwd:.4f}", + "rope_pair_fwd_bwd_ms": f"{rope_full:.4f}", + "rope_pair_bwd_ms": f"{rope_full - rope_fwd:.4f}", + "rope_pair_pct_layer": f"{rope_pct:.2f}", + "layer_speedup_vs_old": f"{layer_speedup:.3f}", + "rope_pair_speedup_vs_old": f"{rope_speedup:.3f}", + } + ) + print( + f"{n_seqs:>6} {regime:>10} {max_seqlen:>10} " + f"layer_fwd={layer_fwd:8.3f} layer_fwd_bwd={layer_full:8.3f} " + 
f"rope_pair_fwd_bwd={rope_full:8.3f} rope_pct={rope_pct:6.2f}% " + f"layer_speedup={layer_speedup:6.3f}x", + flush=True, + ) + + with csv_path.open("w", newline="") as fh: + writer = csv.DictWriter(fh, fieldnames=fields) + writer.writeheader() + writer.writerows(rows) + print(f"\nWrote {csv_path}") + + try: + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + except ImportError: + print("matplotlib not installed; skipping plot") + return + + nseqs = sorted({int(r["n_seqs"]) for r in rows}) + by_regime = {regime: [] for regime in ("old", "new", "heuristic")} + pct_by_regime = {regime: [] for regime in ("old", "new", "heuristic")} + for n in nseqs: + for regime in by_regime: + row = next(r for r in rows if int(r["n_seqs"]) == n and r["regime"] == regime) + by_regime[regime].append(float(row["layer_fwd_bwd_ms"])) + pct_by_regime[regime].append(float(row["rope_pair_pct_layer"])) + + fig, axes = plt.subplots(1, 3, figsize=(17, 5)) + ax = axes[0] + for regime, values in by_regime.items(): + ax.plot(nseqs, values, marker="o", label=regime) + ax.set_xscale("log") + ax.set_yscale("log") + ax.set_xlabel("n_seqs") + ax.set_ylabel("TransformerLayer fwd+bwd (ms)") + ax.set_title("Full THD TransformerLayer") + ax.grid(True, which="both", alpha=0.3) + ax.legend() + + ax = axes[1] + speedups = [by_regime["old"][i] / by_regime["new"][i] for i in range(len(nseqs))] + ax.plot(nseqs, speedups, marker="o", color="tab:green") + ax.axhline(1.0, color="gray", linestyle="--", alpha=0.5) + ax.set_xscale("log") + ax.set_xlabel("n_seqs") + ax.set_ylabel("Layer speedup (old / new)") + ax.set_title("End-to-end layer speedup") + ax.grid(True, which="both", alpha=0.3) + + ax = axes[2] + for regime, values in pct_by_regime.items(): + ax.plot(nseqs, values, marker="o", label=regime) + ax.set_xscale("log") + ax.set_xlabel("n_seqs") + ax.set_ylabel("paired RoPE fwd+bwd / layer fwd+bwd (%)") + ax.set_title("RoPE share estimate") + ax.grid(True, which="both", alpha=0.3) + 
ax.legend() + + fig.tight_layout() + png_path = args.out_dir / "rope_thd_full_layer_bench.png" + fig.savefig(png_path, dpi=120) + print(f"Wrote {png_path}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/attention/benchmark_rope_thd_token_linear.py b/benchmarks/attention/benchmark_rope_thd_token_linear.py new file mode 100644 index 0000000000..734c7289be --- /dev/null +++ b/benchmarks/attention/benchmark_rope_thd_token_linear.py @@ -0,0 +1,264 @@ +"""Microbenchmark for the token-linear THD fused RoPE path. + +Holds the local packed-token count fixed and sweeps the number of packed +sequences. For each point, measures forward and backward latency of the fused +RoPE kernel under three regimes: + + * forced-old: ``NVTE_FUSED_ROPE_THD_TOKEN_LINEAR=0`` + * forced-new: ``NVTE_FUSED_ROPE_THD_TOKEN_LINEAR=1`` + * heuristic: variable unset + +Outputs a CSV and a PNG. Intended to be run on a single GPU; not distributed. +""" + +from __future__ import annotations + +import argparse +import csv +import os +from contextlib import contextmanager +from pathlib import Path +from typing import Iterable + +import torch + +from transformer_engine.pytorch.attention.rope import ( + RotaryPositionEmbedding, + apply_rotary_pos_emb, +) + + +@contextmanager +def env(name: str, value: str | None): + prev = os.environ.get(name) + if value is None: + os.environ.pop(name, None) + else: + os.environ[name] = value + try: + yield + finally: + if prev is None: + os.environ.pop(name, None) + else: + os.environ[name] = prev + + +def build_cu_seqlens(total_tokens: int, n_seqs: int, cp_size: int = 1) -> torch.Tensor: + """Build a cu_seqlens whose local packed length equals ``total_tokens``. + + Per-sequence lengths are equal to ``total_tokens / n_seqs`` rounded down to + a multiple of ``2 * cp_size``; any leftover tokens are tacked onto the last + span so that the local total is exact. 
+ """ + pad = 2 * cp_size + per = (total_tokens // n_seqs // pad) * pad + if per <= 0: + raise ValueError( + f"n_seqs={n_seqs} is too large for total_tokens={total_tokens} with cp_size={cp_size}" + ) + lengths = [per] * n_seqs + deficit = total_tokens - per * n_seqs + lengths[-1] += (deficit // pad) * pad + cu = [0] + for length in lengths: + cu.append(cu[-1] + length) + return torch.tensor(cu, dtype=torch.int32) + + +def time_fwd_bwd( + fn, + iters: int, + warmup: int, +) -> tuple[float, float]: + """Return (fwd_ms, fwd_plus_bwd_ms) averaged across ``iters`` iterations.""" + torch.cuda.synchronize() + for _ in range(warmup): + out = fn() + out.sum().backward() + torch.cuda.synchronize() + + start_fwd = torch.cuda.Event(enable_timing=True) + end_fwd = torch.cuda.Event(enable_timing=True) + end_bwd = torch.cuda.Event(enable_timing=True) + + fwd_total = 0.0 + full_total = 0.0 + for _ in range(iters): + start_fwd.record() + out = fn() + end_fwd.record() + out.sum().backward() + end_bwd.record() + torch.cuda.synchronize() + fwd_total += start_fwd.elapsed_time(end_fwd) + full_total += start_fwd.elapsed_time(end_bwd) + return fwd_total / iters, full_total / iters + + +def main(argv: Iterable[str] | None = None) -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--total-tokens", type=int, default=65536) + parser.add_argument("--freqs-len", type=int, default=65536) + parser.add_argument("--head-num", type=int, default=32) + parser.add_argument("--hidden", type=int, default=128) + parser.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16") + parser.add_argument("--rotary-percent", type=float, default=1.0) + parser.add_argument("--interleaved", action="store_true") + parser.add_argument("--cp-size", type=int, default=1) + parser.add_argument( + "--n-seqs", + type=int, + nargs="+", + default=[1, 8, 32, 128, 512, 1024, 2401], + ) + parser.add_argument("--warmup", type=int, default=5) + parser.add_argument("--iters", type=int, default=20) 
+ parser.add_argument("--out-dir", type=Path, default=Path("rope_thd_bench")) + args = parser.parse_args(argv) + + if not torch.cuda.is_available(): + raise SystemExit("CUDA is required to run this benchmark") + + dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[args.dtype] + device = torch.device("cuda:0") + + rotary = RotaryPositionEmbedding(args.hidden, args.rotary_percent, interleaved=args.interleaved) + freqs = rotary(args.freqs_len).to(device) + + args.out_dir.mkdir(parents=True, exist_ok=True) + csv_path = args.out_dir / "rope_thd_bench.csv" + fields = [ + "n_seqs", + "regime", + "fwd_ms", + "fwd_bwd_ms", + "bwd_ms", + "blocks_old", + "blocks_new", + "speedup_fwd_bwd_vs_old", + ] + rows = [] + + print( + f"# total_tokens={args.total_tokens} freqs_len={args.freqs_len} h={args.head_num} " + f"d={args.hidden} dtype={args.dtype} cp={args.cp_size}" + ) + print("# n_seqs regime fwd_ms fwd_bwd_ms bwd_ms blocks_old blocks_new speedup") + + by_nseq_old: dict[int, float] = {} + + for n_seqs in args.n_seqs: + cu = build_cu_seqlens(args.total_tokens, n_seqs, cp_size=args.cp_size).to(device) + actual_total = int(cu[-1].item()) + t = torch.rand( + (actual_total // args.cp_size, args.head_num, args.hidden), + dtype=dtype, + device=device, + requires_grad=True, + ) + + def runner() -> torch.Tensor: + # Reset grad in-place to keep the autograd graph fresh. 
+ if t.grad is not None: + t.grad = None + return apply_rotary_pos_emb( + t, + freqs, + tensor_format="thd", + fused=True, + cu_seqlens=cu, + interleaved=args.interleaved, + cp_size=args.cp_size, + cp_rank=0, + ) + + blocks_old = args.freqs_len * n_seqs + blocks_new = actual_total // args.cp_size + + for regime, value in [("old", "0"), ("new", "1"), ("heuristic", None)]: + with env("NVTE_FUSED_ROPE_THD_TOKEN_LINEAR", value): + fwd_ms, full_ms = time_fwd_bwd(runner, iters=args.iters, warmup=args.warmup) + bwd_ms = full_ms - fwd_ms + if regime == "old": + by_nseq_old[n_seqs] = full_ms + speedup = (by_nseq_old[n_seqs] / full_ms) if regime != "old" else 1.0 + rows.append( + { + "n_seqs": n_seqs, + "regime": regime, + "fwd_ms": f"{fwd_ms:.4f}", + "fwd_bwd_ms": f"{full_ms:.4f}", + "bwd_ms": f"{bwd_ms:.4f}", + "blocks_old": blocks_old, + "blocks_new": blocks_new, + "speedup_fwd_bwd_vs_old": f"{speedup:.3f}", + } + ) + print( + f"{n_seqs:>6} {regime:>10} fwd={fwd_ms:7.3f} fwd_bwd={full_ms:7.3f} " + f"bwd={bwd_ms:7.3f} blocks_old={blocks_old} blocks_new={blocks_new} " + f"speedup={speedup:.2f}x" + ) + + with csv_path.open("w", newline="") as fh: + writer = csv.DictWriter(fh, fieldnames=fields) + writer.writeheader() + writer.writerows(rows) + print(f"\nWrote {csv_path}") + + try: + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + except ImportError: + print("matplotlib not installed; skipping plot") + return + + # Aggregate per regime. 
+ nseqs = sorted({int(r["n_seqs"]) for r in rows}) + by_regime = {regime: [] for regime in ("old", "new", "heuristic")} + for n in nseqs: + for regime in by_regime: + for r in rows: + if int(r["n_seqs"]) == n and r["regime"] == regime: + by_regime[regime].append(float(r["fwd_bwd_ms"])) + break + + fig, axes = plt.subplots(1, 2, figsize=(13, 5)) + + ax = axes[0] + for regime in ("old", "new", "heuristic"): + ax.plot(nseqs, by_regime[regime], marker="o", label=regime) + ax.set_xscale("log") + ax.set_yscale("log") + ax.set_xlabel("n_seqs (packed spans, log)") + ax.set_ylabel("fwd + bwd latency (ms, log)") + ax.set_title( + f"Fused THD RoPE latency vs n_seqs\nT_local={args.total_tokens}, " + f"freqs_len={args.freqs_len}, h={args.head_num}, d={args.hidden}, " + f"{args.dtype}, cp={args.cp_size}" + ) + ax.grid(True, which="both", alpha=0.3) + ax.legend() + + ax = axes[1] + speedup_new = [by_regime["old"][i] / by_regime["new"][i] for i in range(len(nseqs))] + ax.plot(nseqs, speedup_new, marker="o", color="tab:green", label="new vs old") + ax.axhline(1.0, color="gray", linestyle="--", alpha=0.5) + ax.set_xscale("log") + ax.set_xlabel("n_seqs (log)") + ax.set_ylabel("speedup (old / new)") + ax.set_title("Token-linear speedup over (s × b) launch") + ax.grid(True, which="both", alpha=0.3) + ax.legend() + + fig.tight_layout() + png_path = args.out_dir / "rope_thd_bench.png" + fig.savefig(png_path, dpi=120) + print(f"Wrote {png_path}") + + +if __name__ == "__main__": + main() diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index 50624df9e0..332e8ff6ec 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -345,7 +345,8 @@ def test_unfused_rope_thd_vs_bshd( grad_unfused_bshd.reshape(*grad_unfused_thd.shape), grad_unfused_thd ) torch.testing.assert_close( - grad_unfused_sbhd.transpose(1, 0).reshape(*grad_unfused_thd.shape), grad_unfused_thd + grad_unfused_sbhd.transpose(1, 0).reshape(*grad_unfused_thd.shape), + 
grad_unfused_thd, ) assert output_unfused_thd.is_contiguous() @@ -495,3 +496,134 @@ def test_rotary_position_embedding_forward_with_autocast_gives_same_result_as_wi atol=1e-8, rtol=1e-8, ) + + +def _make_packed_thd_cu_seqlens( + n_seqs: int, + mean_len: int, + cp_size: int, + rng: torch.Generator, + include_zero_length: bool = False, +) -> torch.Tensor: + """Build a cu_seqlens tensor for a packed THD batch. + + Each per-sequence length is padded to a multiple of ``2 * cp_size`` so the + integer divisions inside the kernel are exact (matching how Megatron-style + callers pad cu_seqlens for context parallel). Optionally injects zero-length + spans to exercise the upper-bound search. + """ + lengths = torch.randint( + low=1, + high=max(2, 2 * mean_len), + size=(n_seqs,), + generator=rng, + dtype=torch.int64, + ) + if include_zero_length and n_seqs >= 4: + # Sprinkle a handful of zero-length spans, including back-to-back ones + # and one at the front, to exercise boundary cases. + zero_idx = [0, n_seqs // 3, n_seqs // 3 + 1, n_seqs - 2] + for idx in zero_idx: + if 0 <= idx < n_seqs: + lengths[idx] = 0 + pad = 2 * cp_size + lengths = ((lengths + pad - 1) // pad) * pad + # Restore zero-length spans after padding (pad rounds 0 to 0 already). 
+ cu = torch.zeros(n_seqs + 1, dtype=torch.int32) + cu[1:] = torch.cumsum(lengths, dim=0).to(torch.int32) + return cu + + +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("hidden_size", [128]) +@pytest.mark.parametrize("rotary_percent", [0.5, 1.0]) +@pytest.mark.parametrize("interleaved", [False, True]) +@pytest.mark.parametrize("cp_size", [1, 2]) +@pytest.mark.parametrize( + "n_seqs,mean_len,include_zero_length", + [ + (1, 2048, False), + (8, 256, False), + (64, 64, False), + (513, 16, False), + (2401, 8, False), + (128, 32, True), + ], +) +@pytest.mark.parametrize("start_positions", [False, True]) +def test_fused_rope_thd_token_linear_parity( + monkeypatch: pytest.MonkeyPatch, + dtype: torch.dtype, + hidden_size: int, + rotary_percent: float, + interleaved: bool, + cp_size: int, + n_seqs: int, + mean_len: int, + include_zero_length: bool, + start_positions: bool, +) -> None: + """Forces the old and the new THD fused kernel back-to-back and asserts + bitwise equality on both the forward output and the input gradient. The new + kernel must enumerate exactly the same useful blocks as the old one, with + identical per-token math, so equality must hold without tolerance. 
+ """ + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + device = torch.device("cuda:0") + head_num = 16 + + rng = torch.Generator(device="cpu") + rng.manual_seed(0xC0FFEE + n_seqs * 13 + (1 if include_zero_length else 0)) + + cu_seqlens = _make_packed_thd_cu_seqlens( + n_seqs, mean_len, cp_size, rng, include_zero_length=include_zero_length + ).to(device) + total_local = int(cu_seqlens[-1].item()) // cp_size + if total_local == 0: + pytest.skip("empty packed batch after padding") + + start_positions_t = ( + torch.randint(0, 4, (n_seqs,), dtype=torch.int32, device=device) + if start_positions + else None + ) + + t = torch.rand((total_local, head_num, hidden_size), dtype=dtype, device=device, generator=None) + t.requires_grad = True + + rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) + # `freqs` must cover (max span length per CP rank + start_positions offset + + # CP dual-chunk offset). Use the global cu_seqlens[-1] length as an upper + # bound, matching how callers size the freqs tensor in practice. + emb = rotary_pos_emb(int(cu_seqlens[-1].item()) + 32) + + def run(force_path: str) -> Tuple[torch.Tensor, torch.Tensor]: + monkeypatch.setenv("NVTE_FUSED_ROPE_THD_TOKEN_LINEAR", force_path) + cp_rank = 0 + out = apply_rotary_pos_emb( + t, + emb, + start_positions=start_positions_t, + interleaved=interleaved, + fused=True, + tensor_format="thd", + cu_seqlens=cu_seqlens, + cp_size=cp_size, + cp_rank=cp_rank, + ) + loss = _overlapping_grad(out) + loss.backward() + grad = t.grad.detach().clone() + t.grad = None + return out.detach().clone(), grad + + out_old, grad_old = run("0") + out_new, grad_new = run("1") + + # Both paths call the same per-token device function with the same + # arguments and write disjoint output rows. Bitwise equality is the right + # bar. 
+ torch.testing.assert_close(out_new, out_old, rtol=0.0, atol=0.0) + torch.testing.assert_close(grad_new, grad_old, rtol=0.0, atol=0.0) diff --git a/transformer_engine/common/fused_rope/fused_rope.cu b/transformer_engine/common/fused_rope/fused_rope.cu index 27dc11ab43..3bcfe57888 100644 --- a/transformer_engine/common/fused_rope/fused_rope.cu +++ b/transformer_engine/common/fused_rope/fused_rope.cu @@ -10,10 +10,32 @@ #include "../common.h" #include "../util/logging.h" +#include "../util/system.h" #include "../utils.cuh" namespace transformer_engine { +// Returns the largest sequence index `b` such that +// `cu_seqlens[b] / cp_size <= t_id`. Used by the token-linear THD kernels to +// locate the sequence span that owns local packed token `t_id`. Uses the same +// integer-division semantics as the existing THD kernels so that the +// per-sequence boundaries agree exactly. +__device__ __forceinline__ int fused_rope_thd_find_seq_id(const int *cu_seqlens, const int nseq, + const int t_id, const int cp_size) { + int lo = 0; + int hi = nseq; + while (lo + 1 < hi) { + int mid = (lo + hi) >> 1; + int mid_start = cu_seqlens[mid] / cp_size; + if (mid_start <= t_id) { + lo = mid; + } else { + hi = mid; + } + } + return lo; +} + template __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *dst, const bool interleaved, const int s_id, @@ -215,6 +237,98 @@ __global__ void fused_rope_backward_kernel( offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d); } +// Token-linear THD forward kernel. Each block handles exactly one packed local +// token row. The block locates its owning sequence via binary search over the +// divided cumulative sequence boundaries, then defers to the same +// `fused_rope_block_forward` device function as the original kernel. 
+template <typename scalar_t>
+__global__ void fused_rope_thd_token_forward_kernel(
+    const scalar_t *src, const int *cu_seqlens, const float *freqs, const int *start_positions,
+    scalar_t *dst, const bool interleaved, const int cp_size, const int cp_rank, const int nseq,
+    const int h, const int d, const int d2, const int stride_t, const int stride_h,
+    const int stride_d, const int o_stride_t, const int o_stride_h, const int o_stride_d) {
+  // One block per packed local token row: blockIdx.x indexes the token
+  // dimension of both `src` (via stride_t) and `dst` (via o_stride_t).
+  int t_id = blockIdx.x;
+
+  // Owning sequence and its [start, end) row range, using the same
+  // `cu_seqlens[.] / cp_size` division as the boundary search.
+  int b_id = fused_rope_thd_find_seq_id(cu_seqlens, nseq, t_id, cp_size);
+  int start = cu_seqlens[b_id] / cp_size;
+  int end = cu_seqlens[b_id + 1] / cp_size;
+
+  // The launcher sizes the grid from the tensor's leading dimension, not from
+  // cu_seqlens, so rows at or past the final boundary can exist (e.g. tail
+  // padding). The search maps such rows to the last sequence with
+  // s_id >= cur_seqlens; rotating them would use a position outside that span
+  // and could index past the end of `freqs`, whose length this kernel never
+  // sees. t_id is per-block, so every thread of the block takes the same branch.
+  if (t_id >= end) return;
+
+  int s_id = t_id - start;  // position of this token inside its sequence
+  int cur_seqlens = end - start;
+
+  int offset_block = t_id * stride_t;
+  int offset_block_dst = t_id * o_stride_t;
+
+  // Optional per-sequence shift of the rotary position.
+  int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id];
+  int s_id_for_freqs = s_id + begin_offset;
+
+  // With cp_size > 1 the local span is treated as two halves: the first half
+  // is offset by cp_rank half-spans, the second half is offset from the far
+  // end of the full (cur_seqlens * cp_size) sequence.
+  if (cp_size > 1) {
+    assert(cur_seqlens % 2 == 0);
+    if (s_id < cur_seqlens / 2) {
+      s_id_for_freqs += cp_rank * cur_seqlens / 2;
+    } else {
+      s_id_for_freqs += cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 - cur_seqlens / 2;
+    }
+  }
+
+  fused_rope_block_forward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block,
+                           offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
+}
+
+// Token-linear THD backward kernel. Mirrors the forward variant and dispatches
+// to `fused_rope_block_backward`.
+template <typename scalar_t>
+__global__ void fused_rope_thd_token_backward_kernel(
+    const scalar_t *src, const int *cu_seqlens, const float *freqs, const int *start_positions,
+    scalar_t *dst, const bool interleaved, const int cp_size, const int cp_rank, const int nseq,
+    const int h, const int d, const int d2, const int stride_t, const int stride_h,
+    const int stride_d, const int o_stride_t, const int o_stride_h, const int o_stride_d) {
+  // One block per packed local token row: blockIdx.x indexes the token
+  // dimension of both `src` (via stride_t) and `dst` (via o_stride_t).
+  int t_id = blockIdx.x;
+
+  // Owning sequence and its [start, end) row range, using the same
+  // `cu_seqlens[.] / cp_size` division as the boundary search.
+  int b_id = fused_rope_thd_find_seq_id(cu_seqlens, nseq, t_id, cp_size);
+  int start = cu_seqlens[b_id] / cp_size;
+  int end = cu_seqlens[b_id + 1] / cp_size;
+
+  // The launcher sizes the grid from the tensor's leading dimension, not from
+  // cu_seqlens, so rows at or past the final boundary can exist (e.g. tail
+  // padding). The search maps such rows to the last sequence with
+  // s_id >= cur_seqlens; processing them would use a position outside that span
+  // and could index past the end of `freqs`, whose length this kernel never
+  // sees. t_id is per-block, so every thread of the block takes the same branch.
+  if (t_id >= end) return;
+
+  int s_id = t_id - start;  // position of this token inside its sequence
+  int cur_seqlens = end - start;
+
+  int offset_block = t_id * stride_t;
+  int offset_block_dst = t_id * o_stride_t;
+
+  // Optional per-sequence shift of the rotary position.
+  int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id];
+  int s_id_for_freqs = s_id + begin_offset;
+
+  // With cp_size > 1 the local span is treated as two halves: the first half
+  // is offset by cp_rank half-spans, the second half is offset from the far
+  // end of the full (cur_seqlens * cp_size) sequence.
+  if (cp_size > 1) {
+    assert(cur_seqlens % 2 == 0);
+    if (s_id < cur_seqlens / 2) {
+      s_id_for_freqs += cp_rank * cur_seqlens / 2;
+    } else {
+      s_id_for_freqs += cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 - cur_seqlens / 2;
+    }
+  }
+
+  fused_rope_block_backward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block,
+                            offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
+}
+
+// Host-side dispatcher. Selects the token-linear THD path when it would
+// eliminate a meaningful number of dead blocks. The environment variable
+// NVTE_FUSED_ROPE_THD_TOKEN_LINEAR overrides the heuristic for testing and
+// benchmarking: "0" forces the old kernel, "1" forces the new one. Read on
+// every call so tests can toggle it inside a single process.
+inline bool fused_rope_thd_use_token_linear(const NVTE_QKV_Format qkv_format, const int b, + const int s, const int64_t total_tokens) { + if (qkv_format != NVTE_QKV_Format::NVTE_THD) return false; + if (total_tokens <= 0) return false; + + const int env_override = transformer_engine::getenv("NVTE_FUSED_ROPE_THD_TOKEN_LINEAR", -1); + if (env_override == 0) return false; + if (env_override == 1) return true; + + // Heuristic: only worth it when the old launch would issue at least 8x as + // many blocks as there are useful tokens, and when there are enough + // sequences for binary-search overhead to be amortized. + if (b < 64) return false; + if (static_cast(s) * static_cast(b) < 8 * total_tokens) return false; + return true; +} + template __device__ void fused_qkv_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *out, const bool interleaved, const int s_id, @@ -467,9 +581,8 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c const int cp_size, const int cp_rank, const int s, const int b, const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h, const int stride_d, - cudaStream_t stream) { + const int64_t total_tokens, cudaStream_t stream) { int warps_per_block = h < 16 ? 
4 : 8; - dim3 blocks(s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); const int shared_mem_size = 2 * d2 * sizeof(float); // cos, sin int o_stride_s_or_t, o_stride_b; @@ -487,6 +600,16 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c const int o_stride_h = d; const int o_stride_d = 1; + if (fused_rope_thd_use_token_linear(qkv_format, b, s, total_tokens)) { + dim3 blocks(static_cast(total_tokens)); + fused_rope_thd_token_forward_kernel<<>>( + input, cu_seqlens, freqs, start_positions, output, interleaved, cp_size, cp_rank, b, h, d, + d2, stride_s_or_t, stride_h, stride_d, o_stride_s_or_t, o_stride_h, o_stride_d); + NVTE_CHECK_CUDA(cudaGetLastError()); + return; + } + + dim3 blocks(s, b); fused_rope_forward_kernel<<>>( input, cu_seqlens, freqs, start_positions, output, interleaved, cp_size, cp_rank, s, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, @@ -501,9 +624,9 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se const bool interleaved, const int cp_size, const int cp_rank, const int s, const int b, const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h, - const int stride_d, cudaStream_t stream) { + const int stride_d, const int64_t total_tokens, + cudaStream_t stream) { int warps_per_block = h < 16 ? 
4 : 8; - dim3 blocks(s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); const int shared_mem_size = 2 * d2 * sizeof(float); // cos, sin int o_stride_s_or_t, o_stride_b; @@ -521,6 +644,17 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se const int o_stride_h = d; const int o_stride_d = 1; + if (fused_rope_thd_use_token_linear(qkv_format, b, s, total_tokens)) { + dim3 blocks(static_cast(total_tokens)); + fused_rope_thd_token_backward_kernel<<>>( + output_grads, cu_seqlens, freqs, start_positions, input_grads, interleaved, cp_size, + cp_rank, b, h, d, d2, stride_s_or_t, stride_h, stride_d, o_stride_s_or_t, o_stride_h, + o_stride_d); + NVTE_CHECK_CUDA(cudaGetLastError()); + return; + } + + dim3 blocks(s, b); fused_rope_backward_kernel<<>>( output_grads, cu_seqlens, freqs, start_positions, input_grads, interleaved, cp_size, cp_rank, s, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, @@ -579,6 +713,12 @@ void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Ten const int cp_rank, const int s, const int b, const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h, const int stride_d, cudaStream_t stream) { + // For THD the packed local token count is the first dimension of the input + // tensor. SBHD/BSHD ignore this value. + const int64_t total_tokens = + (qkv_format == NVTE_QKV_Format::NVTE_THD && !input.data.shape.empty()) + ? 
static_cast(input.data.shape[0]) + : 0; TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.data.dtype, scalar_t, fused_rope_forward_launcher(reinterpret_cast(input.data.dptr), @@ -587,7 +727,7 @@ void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Ten reinterpret_cast(start_positions.data.dptr), reinterpret_cast(output->data.dptr), qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t, - stride_b, stride_h, stride_d, stream);); + stride_b, stride_h, stride_d, total_tokens, stream);); } void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, const Tensor &freqs, @@ -597,6 +737,10 @@ void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, c const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h, const int stride_d, cudaStream_t stream) { + const int64_t total_tokens = + (qkv_format == NVTE_QKV_Format::NVTE_THD && !output_grads.data.shape.empty()) + ? static_cast(output_grads.data.shape[0]) + : 0; TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( output_grads.data.dtype, scalar_t, fused_rope_backward_launcher(reinterpret_cast(output_grads.data.dptr), @@ -605,7 +749,7 @@ void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, c reinterpret_cast(start_positions.data.dptr), reinterpret_cast(input_grads->data.dptr), qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t, - stride_b, stride_h, stride_d, stream);); + stride_b, stride_h, stride_d, total_tokens, stream);); } void fused_qkv_rope_forward(const Tensor &qkv_input, const Tensor &q_freqs, const Tensor &k_freqs,