diff --git a/tests/jax/test_grouped_gemm_partitioning.py b/tests/jax/test_grouped_gemm_partitioning.py new file mode 100644 index 0000000000..831e245c0e --- /dev/null +++ b/tests/jax/test_grouped_gemm_partitioning.py @@ -0,0 +1,474 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Partitioning tests for grouped quantize and grouped GEMM.""" + +from types import SimpleNamespace + +import jax +import jax.numpy as jnp +import numpy as np +import pytest +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from transformer_engine.jax.cpp_extensions.gemm import GroupedGemmPrimitive +from transformer_engine.jax.cpp_extensions.quantization import GroupedQuantizePrimitive +from transformer_engine.jax.dense import grouped_dense +from transformer_engine.jax.quantize import QuantizeLayout, QuantizerFactory, ScalingMode +from transformer_engine.jax.sharding import MeshResource, global_shard_guard + + +def _mesh(): + devices = jax.devices() + if len(devices) < 4: + pytest.skip("Grouped GEMM partitioning tests require at least 4 visible GPUs.") + return Mesh(np.asarray(devices[:4]).reshape(2, 2), ("expert", "fsdp")) + + +def _mesh_with_dp_tp(): + devices = jax.devices() + if len(devices) < 4: + pytest.skip("Grouped GEMM partitioning tests require at least 4 visible GPUs.") + return Mesh(np.asarray(devices[:4]).reshape(2, 1, 2, 1), ("expert", "dp", "fsdp", "tp")) + + +def _mesh_with_arbitrary_axis(): + devices = jax.devices() + if len(devices) < 4: + pytest.skip("Grouped GEMM partitioning tests require at least 4 visible GPUs.") + return Mesh( + np.asarray(devices[:4]).reshape(2, 1, 2, 1), + ("expert", "dp", "fsdp", "myaxis123"), + ) + + +def _arg_info(mesh, shape, spec): + return SimpleNamespace( + shape=shape, + ndim=len(shape), + size=int(np.prod(shape)), + sharding=NamedSharding(mesh, PartitionSpec(*spec)), + ) + + +def _normalize_spec(spec): + if isinstance(spec, PartitionSpec): + return 
tuple(spec) + return spec + + +def _spec_contains_axis(spec, axis): + for axis_spec in spec: + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + if axis in axis_tuple: + return True + return False + + +def _mxfp8_grouped_quantizer_set(n_groups): + return QuantizerFactory.create_set( + scaling_mode=ScalingMode.MXFP8_1D_SCALING, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e4m3fn, + is_2x2x=True, + n_groups=n_groups, + ) + + +def test_grouped_quantize_gathers_hidden_axis_for_block_scales(): + mesh = _mesh() + with global_shard_guard(MeshResource(fsdp_resource="fsdp", ep_resource="expert")): + _, _, out_shardings, arg_shardings = GroupedQuantizePrimitive.partition( + jnp.float8_e4m3fn, + ScalingMode.MXFP8_1D_SCALING.value, + QuantizeLayout.ROWWISE, + -1, + jnp.float8_e8m0fnu, + mesh, + ( + _arg_info(mesh, (8, 128, 64), ("expert", None, "fsdp")), + _arg_info(mesh, (8,), ("expert",)), + _arg_info(mesh, (8,), ("expert",)), + ), + (), + ) + + assert tuple(arg_shardings[0].spec) == ("expert", None, None) + specs = tuple(tuple(sharding.spec) for sharding in out_shardings) + assert _normalize_spec(specs[0]) == ("expert",) + assert _normalize_spec(specs[2]) == ("expert",) + assert _normalize_spec(specs[4]) == ("expert",) + + +def test_grouped_quantize_mxfp8_colwise_specs_gather_hidden_axis(): + mesh = _mesh() + with global_shard_guard(MeshResource(fsdp_resource="fsdp", ep_resource="expert")): + _, _, out_shardings, arg_shardings = GroupedQuantizePrimitive.partition( + jnp.float8_e4m3fn, + ScalingMode.MXFP8_1D_SCALING.value, + QuantizeLayout.ROWWISE_COLWISE, + -1, + jnp.float8_e8m0fnu, + mesh, + ( + _arg_info(mesh, (8, 128, 128), ("expert", None, "fsdp")), + _arg_info(mesh, (8,), ("expert",)), + _arg_info(mesh, (8,), ("expert",)), + ), + (), + ) + + assert tuple(arg_shardings[0].spec) == ("expert", None, None) + specs = tuple(tuple(sharding.spec) for sharding in out_shardings) + assert _normalize_spec(specs[0]) == ("expert",) + assert 
_normalize_spec(specs[1]) == ("expert",) + assert _normalize_spec(specs[2]) == ("expert",) + assert _normalize_spec(specs[3]) == ("expert",) + assert _normalize_spec(specs[4]) == ("expert",) + + +def test_grouped_quantize_strips_unsupported_axes_and_gathers_hidden_axes(): + mesh = _mesh_with_dp_tp() + with jax.set_mesh(mesh), global_shard_guard( + MeshResource(dp_resource="dp", tp_resource="tp", fsdp_resource="fsdp", ep_resource="expert") + ): + _, _, out_shardings, arg_shardings = GroupedQuantizePrimitive.partition( + jnp.float8_e4m3fn, + ScalingMode.MXFP8_1D_SCALING.value, + QuantizeLayout.ROWWISE, + -1, + jnp.float8_e8m0fnu, + mesh, + ( + _arg_info(mesh, (8, 128, 128), ("expert", "dp", ("fsdp", "tp"))), + _arg_info(mesh, (8,), (("expert", "tp"),)), + _arg_info(mesh, (8,), (("expert", "tp"),)), + ), + (), + ) + + assert tuple(arg_shardings[0].spec) == ("expert", None, None) + assert tuple(arg_shardings[1].spec) == ("expert",) + assert tuple(arg_shardings[2].spec) == ("expert",) + + out_specs = tuple(tuple(sharding.spec) for sharding in out_shardings) + assert _normalize_spec(out_specs[0]) == ("expert",) + assert _normalize_spec(out_specs[2]) == ("expert",) + assert _normalize_spec(out_specs[4]) == ("expert",) + for spec in (*out_specs, *(tuple(sharding.spec) for sharding in arg_shardings)): + assert not _spec_contains_axis(spec, "tp") + + +def test_grouped_gemm_rhs_weight_specs_gather_fsdp_but_preserve_ep(): + mesh = _mesh() + arg_infos = ( + _arg_info(mesh, (8192,), (None,)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (65536,), (("expert", "fsdp"),)), + _arg_info(mesh, (2048,), (("expert", "fsdp"),)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (8,), ("expert",)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (8,), ("expert",)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (1,), (None,)), + _arg_info(mesh, (0,), (None,)), + ) + with 
global_shard_guard(MeshResource(fsdp_resource="fsdp", ep_resource="expert")): + _, _, out_sharding, arg_shardings = GroupedGemmPrimitive.partition( + False, + False, + ScalingMode.NO_SCALING.value, + jnp.bfloat16, + False, + False, + False, + 1, + 1, + (1, 128, 64), + 128, + 64, + 128, + 64, + mesh, + arg_infos, + (), + ) + + assert tuple(arg_shardings[2].spec) == ("expert",) + assert tuple(arg_shardings[3].spec) == ("expert",) + assert tuple(out_sharding[0].spec) == (None, None, None) + + +def test_grouped_gemm_strips_unsupported_axes_preserves_dp_and_gathers_rhs_fsdp(): + mesh = _mesh_with_dp_tp() + arg_infos = ( + _arg_info(mesh, (8192,), (("dp", "tp"),)), + _arg_info(mesh, (0,), (("tp",),)), + _arg_info(mesh, (65536,), (("expert", "fsdp", "tp"),)), + _arg_info(mesh, (2048,), (("expert", "fsdp", "tp"),)), + _arg_info(mesh, (0,), (("fsdp", "tp"),)), + _arg_info(mesh, (8,), (("expert", "tp"),)), + _arg_info(mesh, (0,), (("tp",),)), + _arg_info(mesh, (0,), (("tp",),)), + _arg_info(mesh, (0,), (("tp",),)), + _arg_info(mesh, (8,), (("expert", "tp"),)), + _arg_info(mesh, (0,), (("tp",),)), + _arg_info(mesh, (1,), (("tp",),)), + _arg_info(mesh, (0,), (("tp",),)), + ) + result_infos = (_arg_info(mesh, (1, 128, 64), ("expert", "tp", None)),) + with jax.set_mesh(mesh), global_shard_guard( + MeshResource(dp_resource="dp", tp_resource="tp", fsdp_resource="fsdp", ep_resource="expert") + ): + _, _, out_sharding, arg_shardings = GroupedGemmPrimitive.partition( + False, + False, + ScalingMode.NO_SCALING.value, + jnp.bfloat16, + False, + False, + False, + 1, + 1, + (1, 128, 64), + 128, + 64, + 128, + 64, + mesh, + arg_infos, + result_infos, + ) + + assert tuple(arg_shardings[0].spec) == ("dp",) + assert tuple(arg_shardings[2].spec) == ("expert",) + assert tuple(arg_shardings[3].spec) == ("expert",) + assert tuple(arg_shardings[5].spec) == ("expert",) + assert tuple(out_sharding[0].spec) == ("expert", None, None) + for spec in ( + *(tuple(sharding.spec) for sharding in 
arg_shardings), + tuple(out_sharding[0].spec), + ): + assert not _spec_contains_axis(spec, "tp") + + +def test_grouped_gemm_reduce_axis_skips_ep_and_uses_dp(): + mesh = _mesh_with_dp_tp() + arg_infos = ( + _arg_info(mesh, (8192,), (("expert", "dp"),)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (8192,), (("expert", "dp"),)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (8,), ("expert",)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (8,), ("expert",)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (8,), ("expert",)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (1,), (None,)), + _arg_info(mesh, (0,), (None,)), + ) + + with jax.set_mesh(mesh), global_shard_guard( + MeshResource(dp_resource="dp", fsdp_resource="fsdp", ep_resource="expert") + ): + _, _, reduce_axis = GroupedGemmPrimitive._parse_partition_specs( + mesh, + arg_infos, + (), + out_shape=(1, 128, 64), + lhs_is_trans=False, + lhs_axis_boundary=1, + ) + + assert reduce_axis == "dp" + + +def test_grouped_partitioning_strips_arbitrary_unsupported_axis(): + mesh = _mesh_with_arbitrary_axis() + mesh_resource = MeshResource(dp_resource="dp", fsdp_resource="fsdp", ep_resource="expert") + + with jax.set_mesh(mesh), global_shard_guard(mesh_resource): + _, _, quantize_out_shardings, quantize_arg_shardings = GroupedQuantizePrimitive.partition( + jnp.float8_e4m3fn, + ScalingMode.MXFP8_1D_SCALING.value, + QuantizeLayout.ROWWISE, + -1, + jnp.float8_e8m0fnu, + mesh, + ( + _arg_info(mesh, (8, 128, 128), ("expert", "myaxis123", ("dp", "fsdp"))), + _arg_info(mesh, (8,), (("expert", "myaxis123"),)), + _arg_info(mesh, (8,), (("expert", "myaxis123"),)), + ), + (), + ) + + gemm_arg_infos = ( + _arg_info(mesh, (8192,), (("dp", "myaxis123"),)), + _arg_info(mesh, (0,), (("myaxis123",),)), + _arg_info(mesh, (65536,), (("expert", "fsdp", "myaxis123"),)), + _arg_info(mesh, (2048,), (("expert", "fsdp", "myaxis123"),)), + _arg_info(mesh, (0,), (("fsdp", 
"myaxis123"),)), + _arg_info(mesh, (8,), (("expert", "myaxis123"),)), + _arg_info(mesh, (0,), (("myaxis123",),)), + _arg_info(mesh, (0,), (("myaxis123",),)), + _arg_info(mesh, (0,), (("myaxis123",),)), + _arg_info(mesh, (8,), (("expert", "myaxis123"),)), + _arg_info(mesh, (0,), (("myaxis123",),)), + _arg_info(mesh, (1,), (("myaxis123",),)), + _arg_info(mesh, (0,), (("myaxis123",),)), + ) + gemm_result_infos = (_arg_info(mesh, (1, 128, 64), ("expert", "myaxis123", None)),) + _, _, gemm_out_sharding, gemm_arg_shardings = GroupedGemmPrimitive.partition( + False, + False, + ScalingMode.NO_SCALING.value, + jnp.bfloat16, + False, + False, + False, + 1, + 1, + (1, 128, 64), + 128, + 64, + 128, + 64, + mesh, + gemm_arg_infos, + gemm_result_infos, + ) + + assert tuple(quantize_arg_shardings[0].spec) == ("expert", None, None) + assert tuple(quantize_arg_shardings[1].spec) == ("expert",) + quantize_out_specs = tuple(tuple(sharding.spec) for sharding in quantize_out_shardings) + assert _normalize_spec(quantize_out_specs[0]) == ("expert",) + assert _normalize_spec(quantize_out_specs[2]) == ("expert",) + + assert tuple(gemm_arg_shardings[0].spec) == ("dp",) + assert tuple(gemm_arg_shardings[2].spec) == ("expert",) + assert tuple(gemm_arg_shardings[3].spec) == ("expert",) + assert tuple(gemm_out_sharding[0].spec) == ("expert", None, None) + + all_specs = ( + *quantize_out_specs, + *(tuple(sharding.spec) for sharding in quantize_arg_shardings), + *(tuple(sharding.spec) for sharding in gemm_arg_shardings), + tuple(gemm_out_sharding[0].spec), + ) + for spec in all_specs: + assert not _spec_contains_axis(spec, "myaxis123") + + +def test_grouped_partitioning_shardy_rules_smoke(): + mesh = _mesh() + quantize_rule = GroupedQuantizePrimitive.shardy_sharding_rule( + jnp.float8_e4m3fn, + ScalingMode.MXFP8_1D_SCALING.value, + QuantizeLayout.ROWWISE, + -1, + jnp.float8_e8m0fnu, + mesh, + ( + SimpleNamespace(shape=(8, 128, 64)), + SimpleNamespace(shape=(8,)), + SimpleNamespace(shape=(8,)), + 
), + ( + SimpleNamespace(shape=(8 * 128 * 64,)), + SimpleNamespace(shape=(1,)), + SimpleNamespace(shape=(8 * 128 * 64,)), + SimpleNamespace(shape=(1,)), + SimpleNamespace(shape=(8,)), + ), + ) + gemm_rule = GroupedGemmPrimitive.shardy_sharding_rule( + False, + False, + ScalingMode.NO_SCALING.value, + jnp.bfloat16, + False, + False, + False, + 1, + 2, + (128, 64), + 128, + 64, + 128, + 64, + mesh, + tuple(SimpleNamespace(shape=(1,)) for _ in range(13)), + (SimpleNamespace(shape=(128, 64)),), + ) + + assert quantize_rule is not None + assert gemm_rule is not None + + +def test_grouped_dense_mxfp8_ep_fsdp_outside_shard_map_single_process(): + mesh = _mesh() + n_groups = 4 + group_tokens = 128 + hidden = 256 + out_hidden = 128 + x_shape = (n_groups * group_tokens, hidden) + w_shape = (n_groups, hidden, out_hidden) + + x_sharding = NamedSharding(mesh, PartitionSpec("expert", None)) + w_sharding = NamedSharding(mesh, PartitionSpec("expert", "fsdp", None)) + group_sharding = NamedSharding(mesh, PartitionSpec("expert")) + out_sharding = NamedSharding(mesh, PartitionSpec("expert", None)) + + quantizer_set = _mxfp8_grouped_quantizer_set(n_groups) + + with mesh, global_shard_guard(MeshResource(fsdp_resource="fsdp", ep_resource="expert")): + x = jax.device_put( + jax.random.normal(jax.random.PRNGKey(20), x_shape, dtype=jnp.bfloat16) + * jnp.asarray(0.01, dtype=jnp.bfloat16), + x_sharding, + ) + w = jax.device_put( + jax.random.normal(jax.random.PRNGKey(21), w_shape, dtype=jnp.bfloat16) + * jnp.asarray(0.01, dtype=jnp.bfloat16), + w_sharding, + ) + group_sizes = jax.device_put( + jnp.full((n_groups,), group_tokens, dtype=jnp.int32), + group_sharding, + ) + + def apply_with_vjp(x, w, group_sizes): + def apply(x, w): + return grouped_dense( + x, + w, + group_sizes, + contracting_dims=((1,), (1,)), + quantizer_set=quantizer_set, + kernel_fsdp_info=("fsdp", 1), + ) + + out, vjp_fn = jax.vjp(apply, x, w) + dx, dw = vjp_fn(out) + return out, dx, dw + + out, dx, dw = jax.jit( + 
apply_with_vjp, + in_shardings=(x_sharding, w_sharding, group_sharding), + out_shardings=(out_sharding, x_sharding, w_sharding), + )(x, w, group_sizes) + out, dx, dw = jax.block_until_ready((out, dx, dw)) + + assert tuple(out.sharding.spec) == ("expert", None) + assert tuple(dx.sharding.spec) == ("expert", None) + assert tuple(dw.sharding.spec) == ("expert", "fsdp", None) + for value in (out, dx, dw): + local_value = np.asarray(jax.device_get(value.addressable_data(0))) + assert np.all(np.isfinite(local_value)) + assert np.any(local_value != 0.0) diff --git a/tests/jax/test_multi_process_distributed_grouped_gemm.py b/tests/jax/test_multi_process_distributed_grouped_gemm.py index 94fed0859f..aa73ba089b 100644 --- a/tests/jax/test_multi_process_distributed_grouped_gemm.py +++ b/tests/jax/test_multi_process_distributed_grouped_gemm.py @@ -7,18 +7,34 @@ import jax import jax.numpy as jnp import jax.experimental.multihost_utils as jem +import numpy as np +from jax.experimental import shard_map +from jax.sharding import NamedSharding, PartitionSpec from transformer_engine.jax.dense import grouped_dense as te_grouped_dense from transformer_engine.jax.quantize import ( QuantizerFactory, ScalingMode, ) +from transformer_engine.jax.sharding import MeshResource, global_shard_guard from utils import assert_allclose, dtype_tols N_GROUP = 8 -MESH_AXIS_NAME = "fsdp" +EP_AXIS_NAME = "ep" +FSDP_AXIS_NAME = "fsdp" +MESH_AXIS_NAME = FSDP_AXIS_NAME + + +def _mxfp8_grouped_quantizer_set(n_groups): + return QuantizerFactory.create_set( + scaling_mode=ScalingMode.MXFP8_1D_SCALING, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e4m3fn, + is_2x2x=True, + n_groups=n_groups, + ) def test_grouped_gemm_fp8_allgather(data_shapes, kernel_fsdp_axis): @@ -31,18 +47,31 @@ def test_grouped_gemm_fp8_allgather(data_shapes, kernel_fsdp_axis): if kernel_fsdp_axis == 2 else NamedSharding(mesh, PartitionSpec(None, MESH_AXIS_NAME, None)) ) + b_sharding = ( + NamedSharding(mesh, PartitionSpec(None, 
MESH_AXIS_NAME)) + if kernel_fsdp_axis == 2 + else NamedSharding(mesh, PartitionSpec(None, None)) + ) w_no_sharding = NamedSharding(mesh, PartitionSpec(None, None, None)) + b_no_sharding = NamedSharding(mesh, PartitionSpec(None, None)) def init_data(): x_key = jax.random.PRNGKey(0) w_key = jax.random.PRNGKey(1) - x = jax.random.normal(x_key, shape=(N_GROUP, *x_shape), dtype=jnp.bfloat16) - w = jax.random.normal(w_key, shape=(N_GROUP, *w_shape), dtype=jnp.bfloat16) - w_amax = jnp.max(jnp.abs(w), axis=range(1, w.ndim)) - return x, w, w, w_amax + b_key = jax.random.PRNGKey(2) + x = jax.random.normal(x_key, shape=(N_GROUP, *x_shape), dtype=jnp.bfloat16) * jnp.asarray( + 0.01, dtype=jnp.bfloat16 + ) + w = jax.random.normal(w_key, shape=(N_GROUP, *w_shape), dtype=jnp.bfloat16) * jnp.asarray( + 0.01, dtype=jnp.bfloat16 + ) + b = jax.random.normal( + b_key, shape=(N_GROUP, w_shape[-1]), dtype=jnp.bfloat16 + ) * jnp.asarray(0.01, dtype=jnp.bfloat16) + return x, w, w, b, b - def test_func(outter_x, outter_w, outter_w_amax): - in_specs = (x_sharding.spec, w_sharding.spec, None) + def test_func(outter_x, outter_w, outter_b): + in_specs = (x_sharding.spec, w_sharding.spec, b_sharding.spec) out_specs = x_sharding.spec @partial( @@ -52,41 +81,35 @@ def test_func(outter_x, outter_w, outter_w_amax): out_specs=out_specs, check_rep=False, ) - def sharded_group_gemm(x, w, w_amax): + def sharded_group_gemm(x, w, b): group_size = x.shape[0] x_reshaped = x.reshape(-1, x.shape[-1]) n_groups = jnp.full(group_size, x_reshaped.shape[0] // group_size) - quantizer_set = QuantizerFactory.create_set( - scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING, - fwd_dtype=jnp.float8_e4m3fn, - bwd_dtype=jnp.float8_e5m2, - is_2x2x=True, - n_groups=group_size, - ) + quantizer_set = _mxfp8_grouped_quantizer_set(group_size) output = te_grouped_dense( x_reshaped, w, n_groups, - kernel_amax=w_amax, + bias=b, quantizer_set=quantizer_set, kernel_fsdp_info=(MESH_AXIS_NAME, kernel_fsdp_axis), ) output = 
output.reshape(*x.shape[:-1], -1) return output - def run(x, w, w_amax): - output = sharded_group_gemm(x, w, w_amax) + def run(x, w, b): + output = sharded_group_gemm(x, w, b) return output - output, vjp_fn = jax.vjp(run, outter_x, outter_w, outter_w_amax) - dx, dw, _ = vjp_fn(output) - return output, dx, dw + output, vjp_fn = jax.vjp(run, outter_x, outter_w, outter_b) + dx, dw, db = vjp_fn(output) + return output, dx, dw, db - def ref_func(outter_x, outter_w): + def ref_func(outter_x, outter_w, outter_b): - in_specs = (x_sharding.spec, w_no_sharding.spec) + in_specs = (x_sharding.spec, w_no_sharding.spec, b_no_sharding.spec) out_specs = x_sharding.spec @partial( @@ -96,63 +119,129 @@ def ref_func(outter_x, outter_w): out_specs=out_specs, check_rep=False, ) - def sharded_group_gemm(x, w): + def sharded_group_gemm(x, w, b): group_size = x.shape[0] x_reshaped = x.reshape(-1, x.shape[-1]) n_groups = jnp.full(group_size, x_reshaped.shape[0] // group_size) - quantizer_set = QuantizerFactory.create_set( - scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING, - fwd_dtype=jnp.float8_e4m3fn, - bwd_dtype=jnp.float8_e5m2, - is_2x2x=True, - n_groups=group_size, + quantizer_set = _mxfp8_grouped_quantizer_set(group_size) + output = te_grouped_dense( + x_reshaped, + w, + n_groups, + bias=b, + quantizer_set=quantizer_set, ) - output = te_grouped_dense(x_reshaped, w, n_groups, quantizer_set=quantizer_set) output = output.reshape(*x.shape[:-1], -1) return output - def run(x, w): - output = sharded_group_gemm(x, w) + def run(x, w, b): + output = sharded_group_gemm(x, w, b) return output - output, vjp_fn = jax.vjp(run, outter_x, outter_w) - dx, dw = vjp_fn(output) - return output, dx, dw + output, vjp_fn = jax.vjp(run, outter_x, outter_w, outter_b) + dx, dw, db = vjp_fn(output) + return output, dx, dw, db - init_func = jax.jit(init_data, out_shardings=(x_sharding, w_sharding, w_no_sharding, None)) - x, w, w_global, w_amax = init_func() + init_func = jax.jit( + init_data, + 
out_shardings=(x_sharding, w_sharding, w_no_sharding, b_sharding, b_no_sharding), + ) + x, w, w_global, b, b_global = init_func() o_sharding = x_sharding test_func_jitted = jax.jit( test_func, - in_shardings=(x_sharding, w_sharding, None), - out_shardings=(o_sharding, x_sharding, w_sharding), + in_shardings=(x_sharding, w_sharding, b_sharding), + out_shardings=(o_sharding, x_sharding, w_sharding, b_sharding), ) ref_func_jitted = jax.jit( ref_func, - in_shardings=(x_sharding, w_no_sharding), - out_shardings=(o_sharding, x_sharding, w_no_sharding), + in_shardings=(x_sharding, w_no_sharding, b_no_sharding), + out_shardings=(o_sharding, x_sharding, w_no_sharding, b_no_sharding), ) - out, dx, dw = test_func_jitted(x, w, w_amax) - ref_out, ref_dx, ref_dw = ref_func_jitted(x, w_global) - - e4m3_tols = dtype_tols(jnp.float8_e4m3fn) - e5m2_tols = dtype_tols(jnp.float8_e5m2) - - out, ref_out = jem.process_allgather((out, ref_out)) - dx, ref_dx = jem.process_allgather((dx, ref_dx)) - dw, ref_dw = jem.process_allgather((dw, ref_dw)) + out, dx, dw, db = test_func_jitted(x, w, b) + ref_out, ref_dx, ref_dw, ref_db = ref_func_jitted(x, w_global, b_global) + + # Avoid creating a host scalar JAX array under the multi-process mesh in dtype_tols. 
+ e4m3_tols = dtype_tols(jnp.float8_e4m3fn, rtol=0.25, atol=0.25) + + out, ref_out = jem.process_allgather((out, ref_out), tiled=True) + dx, ref_dx = jem.process_allgather((dx, ref_dx), tiled=True) + dw, ref_dw = jem.process_allgather((dw, ref_dw), tiled=True) + db, ref_db = jem.process_allgather((db, ref_db), tiled=True) + + assert_allclose(out, ref_out, **e4m3_tols) + assert_allclose(dx, ref_dx, **e4m3_tols) + assert_allclose(dw, ref_dw, **e4m3_tols) + assert_allclose(db, ref_db, **e4m3_tols) + + +def run_grouped_dense_mxfp8_ep_fsdp_outside_shard_map(): + n_groups = 4 + group_tokens = 128 + hidden = 256 + out_hidden = 128 + x_shape = (n_groups * group_tokens, hidden) + w_shape = (n_groups, hidden, out_hidden) + quantizer_set = _mxfp8_grouped_quantizer_set(n_groups) + + x_sharding = NamedSharding(mesh, PartitionSpec(EP_AXIS_NAME, None)) + w_sharding = NamedSharding(mesh, PartitionSpec(EP_AXIS_NAME, FSDP_AXIS_NAME, None)) + group_sharding = NamedSharding(mesh, PartitionSpec(EP_AXIS_NAME)) + out_sharding = NamedSharding(mesh, PartitionSpec(EP_AXIS_NAME, None)) + + with mesh, global_shard_guard( + MeshResource(ep_resource=EP_AXIS_NAME, fsdp_resource=FSDP_AXIS_NAME) + ): + x = jax.device_put( + jax.random.normal(jax.random.PRNGKey(20), x_shape, dtype=jnp.bfloat16) + * jnp.asarray(0.01, dtype=jnp.bfloat16), + x_sharding, + ) + w = jax.device_put( + jax.random.normal(jax.random.PRNGKey(21), w_shape, dtype=jnp.bfloat16) + * jnp.asarray(0.01, dtype=jnp.bfloat16), + w_sharding, + ) + group_sizes = jax.device_put( + jnp.full((n_groups,), group_tokens, dtype=jnp.int32), + group_sharding, + ) - jnp.allclose(out, ref_out, **e4m3_tols) - jnp.allclose(dx, ref_dx, **e5m2_tols) - jnp.allclose(dw, ref_dw, **e5m2_tols) + def apply_with_vjp(x, w, group_sizes): + def apply(x, w): + return te_grouped_dense( + x, + w, + group_sizes, + contracting_dims=((1,), (1,)), + quantizer_set=quantizer_set, + kernel_fsdp_info=(FSDP_AXIS_NAME, 1), + ) + + out, vjp_fn = jax.vjp(apply, x, w) + dx, dw 
= vjp_fn(out) + return out, dx, dw + + out, dx, dw = jax.jit( + apply_with_vjp, + in_shardings=(x_sharding, w_sharding, group_sharding), + out_shardings=(out_sharding, x_sharding, w_sharding), + )(x, w, group_sizes) + out, dx, dw = jax.block_until_ready((out, dx, dw)) + + assert tuple(out.sharding.spec) == (EP_AXIS_NAME, None) + assert tuple(dx.sharding.spec) == (EP_AXIS_NAME, None) + assert tuple(dw.sharding.spec) == (EP_AXIS_NAME, FSDP_AXIS_NAME, None) + for value in (out, dx, dw): + local_value = np.asarray(jax.device_get(value.addressable_data(0))) + assert np.all(np.isfinite(local_value)) + assert np.any(local_value != 0.0) if __name__ == "__main__": - from jax.sharding import NamedSharding, PartitionSpec - from jax.experimental import shard_map import sys coord_addr = sys.argv[1] @@ -163,10 +252,14 @@ def run(x, w): coordinator_address=coord_addr, num_processes=num_procs, process_id=proc_id ) - mesh = jax.make_mesh((num_procs,), (MESH_AXIS_NAME,)) + mesh = jax.make_mesh((num_procs,), (FSDP_AXIS_NAME,)) with mesh: data_shapes = [((4, 16, 128, 7168), (7168, 2048))] for data_shape in data_shapes: for kernel_fsdp_axis in [1, 2]: test_grouped_gemm_fp8_allgather(data_shape, kernel_fsdp_axis) + + if num_procs == 4: + mesh = jax.make_mesh((2, 2), (EP_AXIS_NAME, FSDP_AXIS_NAME)) + run_grouped_dense_mxfp8_ep_fsdp_outside_shard_map() diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 4ff6d07986..14a77d4349 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -211,6 +211,140 @@ def _get_nvfp4_tensor_scale_inv(amax): return amax / (DATA_DTYPE_MAX * SCALE_DTYPE_MAX) +def _axis_spec_contains(axis_spec, axis): + if axis is None or axis_spec is None: + return False + if isinstance(axis_spec, tuple): + return axis in axis_spec + return axis_spec == axis + + +def _spec_contains_axis(spec, axis): + return any(_axis_spec_contains(axis_spec, axis) for 
axis_spec in spec) + + +def _strip_axis_from_axis_spec(axis_spec, axis): + if axis is None or axis_spec is None: + return axis_spec + if isinstance(axis_spec, tuple): + stripped = tuple(a for a in axis_spec if a != axis) + if len(stripped) == 0: + return None + return stripped[0] if len(stripped) == 1 else stripped + return None if axis_spec == axis else axis_spec + + +def _strip_axis_from_spec(spec, axis): + return tuple(_strip_axis_from_axis_spec(axis_spec, axis) for axis_spec in spec) + + +def _filter_axis_spec(axis_spec, allowed_axes): + if axis_spec is None: + return None + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + axes = tuple(axis for axis in axis_tuple if axis in allowed_axes) + if len(axes) == 0: + return None + return axes[0] if len(axes) == 1 else axes + + +def _filter_spec_axes(spec, allowed_axes): + return tuple(_filter_axis_spec(axis_spec, allowed_axes) for axis_spec in spec) + + +def _supported_grouped_gemm_axes(mesh): + gsr = global_mesh_resource(validate=False) + return { + axis + for axis in (gsr.ep_resource, gsr.dp_resource, gsr.fsdp_resource) + if axis is not None and axis in mesh.axis_names + } + + +def _common_axis(spec_a, spec_b, allowed_axes=None): + axes = [] + for spec in (spec_a, spec_b): + for axis_spec in spec: + if axis_spec is None: + continue + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + for axis in axis_tuple: + if axis is not None and axis not in axes: + axes.append(axis) + for axis in axes: + if allowed_axes is not None and axis not in allowed_axes: + continue + if _spec_contains_axis(spec_a, axis) and _spec_contains_axis(spec_b, axis): + return axis + return None + + +def _merge_axis_spec(axis_spec_a, axis_spec_b): + axes = [] + for axis_spec in (axis_spec_a, axis_spec_b): + if axis_spec is None: + continue + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + for axis in axis_tuple: + if axis is not None and axis not in axes: + 
axes.append(axis) + if len(axes) == 0: + return None + return axes[0] if len(axes) == 1 else tuple(axes) + + +def _partition_spec_from_result(mesh, result_info, fallback_spec): + if result_info is not None and result_info.sharding is not None: + return result_info.sharding + return NamedSharding(mesh, PartitionSpec(*fallback_spec)) + + +def _local_shape_from_spec(global_shape, spec, mesh): + local_shape = [] + for dim, axis_spec in zip(global_shape, spec): + axis_size = _axis_spec_size(axis_spec, mesh) + local_shape.append(dim // axis_size) + return tuple(local_shape) + + +def _axis_spec_size(axis_spec, mesh): + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + axis_size = 1 + for axis in axis_tuple: + if axis is not None: + axis_size *= mesh.shape[axis] + return axis_size + + +def _spec_size(spec, mesh): + axis_size = 1 + for axis_spec in spec: + axis_size *= _axis_spec_size(axis_spec, mesh) + return axis_size + + +def _local_2d_sizes_from_spec(shape, spec, axis_boundary, left_size, right_size, mesh): + if len(shape) == len(spec) and len(shape) > 1: + local_shape = _local_shape_from_spec(shape, spec, mesh) + return ( + math.prod(local_shape[:axis_boundary]), + math.prod(local_shape[axis_boundary:]), + ) + + spec_size = _spec_size(spec, mesh) + if spec_size == 1: + return left_size, right_size + if left_size % spec_size == 0: + return left_size // spec_size, right_size + if right_size % spec_size == 0: + return left_size, right_size // spec_size + raise ValueError( + "Cannot derive local grouped GEMM 2D sizes from sharding spec. 
" + f"shape={shape}, spec={spec}, axis_boundary={axis_boundary}, " + f"left_size={left_size}, right_size={right_size}, spec_size={spec_size}" + ) + + def collective_gemm_bootstrap( num_total_devices, num_devices_per_process, @@ -1738,6 +1872,293 @@ def impl( ) return (out,) + @staticmethod + def _parse_partition_specs( + mesh, + arg_infos, + result_infos, + out_shape=None, + lhs_is_trans=None, + lhs_axis_boundary=None, + ): + gsr = global_mesh_resource(validate=False) + fsdp_axis = gsr.fsdp_resource + allowed_axes = _supported_grouped_gemm_axes(mesh) + + lhs_data_spec = _filter_spec_axes(get_padded_spec(arg_infos[0]), allowed_axes) + lhs_scale_spec = _filter_spec_axes(get_padded_spec(arg_infos[1]), allowed_axes) + rhs_data_spec = _filter_spec_axes(get_padded_spec(arg_infos[2]), allowed_axes) + rhs_scale_spec = _filter_spec_axes(get_padded_spec(arg_infos[3]), allowed_axes) + bias_spec = _filter_spec_axes(get_padded_spec(arg_infos[4]), allowed_axes) + + lhs_first_dims_spec = _filter_spec_axes(get_padded_spec(arg_infos[5]), allowed_axes) + lhs_last_dims_spec = _filter_spec_axes(get_padded_spec(arg_infos[6]), allowed_axes) + rhs_first_dims_spec = _filter_spec_axes(get_padded_spec(arg_infos[7]), allowed_axes) + rhs_last_dims_spec = _filter_spec_axes(get_padded_spec(arg_infos[8]), allowed_axes) + out_first_dims_spec = _filter_spec_axes(get_padded_spec(arg_infos[9]), allowed_axes) + out_last_dims_spec = _filter_spec_axes(get_padded_spec(arg_infos[10]), allowed_axes) + additional_arg_0_spec = _filter_spec_axes(get_padded_spec(arg_infos[11]), allowed_axes) + additional_arg_1_spec = _filter_spec_axes(get_padded_spec(arg_infos[12]), allowed_axes) + + grouped_dim_specs = ( + lhs_first_dims_spec, + lhs_last_dims_spec, + rhs_first_dims_spec, + rhs_last_dims_spec, + out_first_dims_spec, + out_last_dims_spec, + ) + grouped_dim_infos = arg_infos[5:11] + active_group_spec = next( + (spec for spec, info in zip(grouped_dim_specs, grouped_dim_infos) if info.size > 0), + (None,), + ) + 
if arg_infos[11].size > 1: + additional_arg_0_spec = active_group_spec + if arg_infos[12].size > 1: + additional_arg_1_spec = active_group_spec + + rhs_is_ragged = arg_infos[7].size > 0 or arg_infos[8].size > 0 + ep_axis = gsr.ep_resource + if ( + ep_axis is not None + and not rhs_is_ragged + and _spec_contains_axis(active_group_spec, ep_axis) + ): + if len(rhs_data_spec) > 0 and not _spec_contains_axis(rhs_data_spec, ep_axis): + rhs_data_spec = ( + _merge_axis_spec(rhs_data_spec[0], ep_axis), + *rhs_data_spec[1:], + ) + if len(rhs_scale_spec) > 0 and not _spec_contains_axis(rhs_scale_spec, ep_axis): + rhs_scale_spec = ( + _merge_axis_spec(rhs_scale_spec[0], ep_axis), + *rhs_scale_spec[1:], + ) + if len(bias_spec) > 0 and not _spec_contains_axis(bias_spec, ep_axis): + bias_spec = (_merge_axis_spec(bias_spec[0], ep_axis), *bias_spec[1:]) + + gather_rhs_fsdp = ( + fsdp_axis is not None + and not rhs_is_ragged + and ( + _spec_contains_axis(rhs_data_spec, fsdp_axis) + or _spec_contains_axis(rhs_scale_spec, fsdp_axis) + or _spec_contains_axis(bias_spec, fsdp_axis) + ) + ) + + if gather_rhs_fsdp: + rhs_data_spec = _strip_axis_from_spec(rhs_data_spec, fsdp_axis) + rhs_scale_spec = _strip_axis_from_spec(rhs_scale_spec, fsdp_axis) + bias_spec = _strip_axis_from_spec(bias_spec, fsdp_axis) + + reducible_axes = tuple( + axis for axis in (gsr.dp_resource, gsr.fsdp_resource) if axis is not None + ) + reduce_axis = _common_axis(lhs_data_spec, rhs_data_spec, reducible_axes) + if reduce_axis is not None and gather_rhs_fsdp: + reduce_axis = None + + if result_infos: + out_spec = _filter_spec_axes(get_padded_spec(result_infos[0]), allowed_axes) + else: + out_spec = (None,) * (len(out_shape) if out_shape is not None else 1) + + if rhs_is_ragged and lhs_is_trans is not None and lhs_axis_boundary is not None: + lhs_non_contracting_dims = ( + range(lhs_axis_boundary, len(lhs_data_spec)) + if lhs_is_trans + else range(0, lhs_axis_boundary) + ) + lhs_data_spec = list(lhs_data_spec) + for 
out_idx, lhs_dim in enumerate(lhs_non_contracting_dims, start=1): + if out_idx < len(out_spec): + lhs_data_spec[lhs_dim] = _merge_axis_spec( + lhs_data_spec[lhs_dim], out_spec[out_idx] + ) + lhs_data_spec = tuple(lhs_data_spec) + + return ( + ( + lhs_data_spec, + lhs_scale_spec, + rhs_data_spec, + rhs_scale_spec, + bias_spec, + lhs_first_dims_spec, + lhs_last_dims_spec, + rhs_first_dims_spec, + rhs_last_dims_spec, + out_first_dims_spec, + out_last_dims_spec, + additional_arg_0_spec, + additional_arg_1_spec, + ), + out_spec, + reduce_axis, + ) + + @staticmethod + def partition( + lhs_is_trans, + rhs_is_trans, + scaling_mode, + out_dtype, + has_bias, + use_async_d2h_group_sizes, + use_v2_ffi, + lhs_axis_boundary, + rhs_axis_boundary, + out_shape, + lhs_left_size, + lhs_right_size, + rhs_left_size, + rhs_right_size, + mesh, + arg_infos, + result_infos, + ): + arg_specs, out_spec, reduce_axis = GroupedGemmPrimitive._parse_partition_specs( + mesh, + arg_infos, + result_infos, + out_shape, + lhs_is_trans=lhs_is_trans, + lhs_axis_boundary=lhs_axis_boundary, + ) + arg_shardings = tuple(NamedSharding(mesh, PartitionSpec(*spec)) for spec in arg_specs) + out_sharding = (NamedSharding(mesh, PartitionSpec(*out_spec)),) + local_out_shape = _local_shape_from_spec(out_shape, out_spec, mesh) + local_lhs_left_size, local_lhs_right_size = _local_2d_sizes_from_spec( + arg_infos[0].shape, + arg_specs[0], + lhs_axis_boundary, + lhs_left_size, + lhs_right_size, + mesh, + ) + local_rhs_left_size, local_rhs_right_size = _local_2d_sizes_from_spec( + arg_infos[2].shape, + arg_specs[2], + rhs_axis_boundary, + rhs_left_size, + rhs_right_size, + mesh, + ) + + def sharded_impl( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + lhs_first_dims, + lhs_last_dims, + rhs_first_dims, + rhs_last_dims, + out_first_dims, + out_last_dims, + additional_arg_0, + additional_arg_1, + ): + (out,) = GroupedGemmPrimitive.impl( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + 
lhs_first_dims, + lhs_last_dims, + rhs_first_dims, + rhs_last_dims, + out_first_dims, + out_last_dims, + additional_arg_0, + additional_arg_1, + lhs_is_trans=lhs_is_trans, + rhs_is_trans=rhs_is_trans, + scaling_mode=scaling_mode, + out_dtype=out_dtype, + has_bias=has_bias, + use_async_d2h_group_sizes=use_async_d2h_group_sizes, + use_v2_ffi=use_v2_ffi, + lhs_axis_boundary=lhs_axis_boundary, + rhs_axis_boundary=rhs_axis_boundary, + out_shape=local_out_shape, + lhs_left_size=local_lhs_left_size, + lhs_right_size=local_lhs_right_size, + rhs_left_size=local_rhs_left_size, + rhs_right_size=local_rhs_right_size, + ) + + if reduce_axis is not None: + if is_all_reduce_in_float32(): + out = jax.lax.psum(out.astype(jnp.float32), reduce_axis).astype(out_dtype) + else: + out = jax.lax.psum(out, reduce_axis) + return (out,) + + return mesh, sharded_impl, out_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule( + lhs_is_trans, + rhs_is_trans, + scaling_mode, + out_dtype, + has_bias, + use_async_d2h_group_sizes, + use_v2_ffi, + lhs_axis_boundary, + rhs_axis_boundary, + out_shape, + lhs_left_size, + lhs_right_size, + rhs_left_size, + rhs_right_size, + mesh, + operand_types, + result_types, + ): + del ( + lhs_is_trans, + rhs_is_trans, + scaling_mode, + out_dtype, + has_bias, + use_async_d2h_group_sizes, + use_v2_ffi, + lhs_axis_boundary, + rhs_axis_boundary, + out_shape, + lhs_left_size, + lhs_right_size, + rhs_left_size, + rhs_right_size, + mesh, + ) + + prefix = "GroupedGemm" + + def spec_for(name, rank): + if rank == 0: + return () + return tuple(f"{prefix}_{name}_{i}" for i in range(rank)) + + operand_mappings = tuple( + spec_for(f"arg{i}", len(operand_type.shape)) + for i, operand_type in enumerate(operand_types) + ) + result_mappings = tuple( + spec_for(f"out{i}", len(result_type.shape)) + for i, result_type in enumerate(result_types) + ) + return SdyShardingRule( + operand_mappings=operand_mappings, + result_mappings=result_mappings, + ) + 
register_primitive(GroupedGemmPrimitive) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 7138cfcf40..5a245b5e7d 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -32,6 +32,8 @@ all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp, get_num_devices_in_mesh, + global_mesh_resource, + lax_paral_op, ) from ..quantize import ( ScaledTensor2x, @@ -52,6 +54,93 @@ __all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"] +def _merge_axis_specs(axis_specs): + axes = [] + for axis_spec in axis_specs: + if axis_spec is None: + continue + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + for axis in axis_tuple: + if axis is not None and axis not in axes: + axes.append(axis) + if len(axes) == 0: + return None + return axes[0] if len(axes) == 1 else tuple(axes) + + +def _flat_data_spec(input_spec): + return (_merge_axis_specs(input_spec),) + + +def _normalize_flatten_axis(flatten_axis, ndim): + return flatten_axis + ndim if flatten_axis < 0 else flatten_axis + + +def _contiguous_flat_input_spec(input_spec, flatten_axis): + flatten_axis = _normalize_flatten_axis(flatten_axis, len(input_spec)) + if flatten_axis <= 0 or len(input_spec) == 0: + return (None,) * len(input_spec) + return (input_spec[0], *((None,) * (len(input_spec) - 1))) + + +def _filter_axis_spec(axis_spec, allowed_axes): + if axis_spec is None: + return None + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + axes = tuple(axis for axis in axis_tuple if axis in allowed_axes) + if len(axes) == 0: + return None + return axes[0] if len(axes) == 1 else axes + + +def _filter_spec_axes(spec, allowed_axes): + return tuple(_filter_axis_spec(axis_spec, allowed_axes) for axis_spec in spec) + + +def _supported_grouped_quantize_axes(mesh): + gsr = global_mesh_resource(validate=False) + 
return { + axis + for axis in (gsr.ep_resource, gsr.dp_resource, gsr.fsdp_resource) + if axis is not None and axis in mesh.axis_names + } + + +def _axis_spec_size(axis_spec, mesh): + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + axis_size = 1 + for axis in axis_tuple: + if axis is not None: + axis_size *= mesh.shape[axis] + return axis_size + + +def _local_shape_from_spec(global_shape, spec, mesh): + local_shape = [] + for dim, axis_spec in zip(global_shape, spec): + local_shape.append(dim // _axis_spec_size(axis_spec, mesh)) + return tuple(local_shape) + + +def _pad_or_slice_to_shape(x, target_shape): + if target_shape is None or x.shape == target_shape: + return x + target_size = math.prod(target_shape) + current_size = math.prod(x.shape) + x = x.reshape(-1) + if current_size > target_size: + return x[:target_size].reshape(target_shape) + return jnp.pad(x, (0, target_size - current_size)).reshape(target_shape) + + +def _all_reduce_grouped_amax_along_dp_fsdp(amax, mesh): + gsr = global_mesh_resource() + for axis in (gsr.dp_resource, gsr.fsdp_resource): + if axis is not None and axis in mesh.axis_names: + amax = lax_paral_op(amax, jax.lax.pmax, axis, mesh) + return amax + + class BaseDBiasQuantizePrimitive(BasePrimitive): """ Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias @@ -1236,6 +1325,147 @@ def impl( ) return rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax + @staticmethod + def _parse_partition_specs(scaling_mode, q_layout, flatten_axis, mesh, arg_infos): + allowed_axes = _supported_grouped_quantize_axes(mesh) + x_spec = _filter_spec_axes(get_padded_spec(arg_infos[0]), allowed_axes) + x_spec = _contiguous_flat_input_spec(x_spec, flatten_axis) + group_spec = _filter_spec_axes(get_padded_spec(arg_infos[2]), allowed_axes) + if group_spec == (None,) and len(x_spec) > 0: + group_spec = (x_spec[0],) + flat_spec = _flat_data_spec(x_spec) + replicated_spec = (None,) + + rowwise_out_spec = 
@staticmethod
def partition(
    out_dtype,
    scaling_mode,
    q_layout,
    flatten_axis,
    scale_dtype,
    mesh,
    arg_infos,
    result_infos,
):
    """Build the shardings and the per-shard implementation of grouped quantize.

    Static method of ``GroupedQuantizePrimitive``.

    The three operands are sharded as ``(x_spec, group_spec, group_spec)`` and
    the five results follow the specs from ``_parse_partition_specs``. When
    ``result_infos`` is provided, the per-shard shape of each result is derived
    from it; for block scaling the scale-inverse outputs are padded or sliced
    to those shapes. For tensor scaling the updated amax is max-reduced over
    the dp/fsdp axes.

    Returns:
        ``(mesh, sharded_impl, out_shardings, arg_shardings)``.
    """
    x_spec, group_spec, out_specs = GroupedQuantizePrimitive._parse_partition_specs(
        scaling_mode, q_layout, flatten_axis, mesh, arg_infos
    )

    if result_infos:
        local_out_shapes = tuple(
            _local_shape_from_spec(info.shape, spec, mesh)
            for info, spec in zip(result_infos, out_specs)
        )
    else:
        # Without result infos there is no target shape to pad/slice to.
        local_out_shapes = (None,) * len(out_specs)

    def _sharding_for(spec):
        return NamedSharding(mesh, PartitionSpec(*spec))

    arg_shardings = (
        _sharding_for(x_spec),
        _sharding_for(group_spec),
        _sharding_for(group_spec),
    )
    out_shardings = tuple(_sharding_for(spec) for spec in out_specs)

    def sharded_impl(x, scale, group_sizes):
        local_results = GroupedQuantizePrimitive.impl(
            x,
            scale,
            group_sizes,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            scale_dtype=scale_dtype,
        )
        (
            rowwise_out,
            colwise_out,
            rowwise_scale_inv,
            colwise_scale_inv,
            updated_amax,
        ) = local_results

        mode = ScalingMode(scaling_mode)
        if mode.is_block_scaling:
            rowwise_scale_inv = _pad_or_slice_to_shape(rowwise_scale_inv, local_out_shapes[2])
            colwise_scale_inv = _pad_or_slice_to_shape(colwise_scale_inv, local_out_shapes[3])
        if mode.is_tensor_scaling():
            updated_amax = _all_reduce_grouped_amax_along_dp_fsdp(updated_amax, mesh)

        return (
            rowwise_out,
            colwise_out,
            rowwise_scale_inv,
            colwise_scale_inv,
            updated_amax,
        )

    return mesh, sharded_impl, out_shardings, arg_shardings


@staticmethod
def shardy_sharding_rule(
    out_dtype,
    scaling_mode,
    q_layout,
    flatten_axis,
    scale_dtype,
    mesh,
    value_types,
    result_types,
):
    """Return the Shardy rule for grouped quantize.

    Static method of ``GroupedQuantizePrimitive``.

    The input gets one factor per dimension. Outputs enabled by ``q_layout``
    use the shared flat factor; disabled outputs use the scalar factor. The
    scale-inverse outputs use the flat factor for block scaling, the group
    factor for tensor scaling and the scalar factor otherwise. The last result
    always uses the group factor.
    """
    del out_dtype, scale_dtype, mesh, result_types, flatten_axis

    prefix = "GroupedQuantize"
    input_spec = tuple(f"{prefix}_x_{i}" for i in range(len(value_types[0].shape)))
    flat_spec = (f"{prefix}_flat",)
    group_spec = (BATCHING + f"{prefix}_group",)
    scalar_spec = (BATCHING + f"{prefix}_scalar",)

    def _if_enabled(enabled, spec):
        # Outputs for a layout that is not produced fall back to the scalar factor.
        return spec if enabled else scalar_spec

    rowwise_out_spec = _if_enabled(q_layout.has_rowwise, flat_spec)
    colwise_out_spec = _if_enabled(q_layout.has_colwise, flat_spec)

    mode = ScalingMode(scaling_mode)
    if mode.is_block_scaling:
        enabled_scale_spec = flat_spec
    elif mode.is_tensor_scaling():
        enabled_scale_spec = group_spec
    else:
        enabled_scale_spec = None

    if enabled_scale_spec is None:
        rowwise_scale_spec = scalar_spec
        colwise_scale_spec = scalar_spec
    else:
        rowwise_scale_spec = _if_enabled(q_layout.has_rowwise, enabled_scale_spec)
        colwise_scale_spec = _if_enabled(q_layout.has_colwise, enabled_scale_spec)

    return SdyShardingRule(
        operand_mappings=(input_spec, group_spec, group_spec),
        result_mappings=(
            rowwise_out_spec,
            colwise_out_spec,
            rowwise_scale_spec,
            colwise_scale_spec,
            group_spec,
        ),
    )
@@ import warnings import jax import jax.numpy as jnp +from jax.sharding import PartitionSpec from . import cpp_extensions as tex from .cpp_extensions.amax import AmaxScope +from .sharding import global_mesh_resource, get_mesh_axis_size, with_sharding_constraint from .quantize import ( ScaledTensor, QuantizerSet, @@ -54,6 +56,20 @@ def _psum_scatter_kernel(kernel, scattered_kernel_shape, mesh_axis, axis_idx): return kernel +def _is_manual_mesh_axis(mesh_axis): + return mesh_axis is not None and mesh_axis in jax.sharding.get_abstract_mesh().manual_axes + + +def _kernel_non_contracting_axis_to_bias_axis(kernel_axis_idx, kernel_contracting_dims): + if kernel_axis_idx in kernel_contracting_dims: + return None + bias_axis_idx = 1 + for dim in range(1, kernel_axis_idx): + if dim not in kernel_contracting_dims: + bias_axis_idx += 1 + return bias_axis_idx + + def dense( x: jnp.ndarray, kernel: jnp.ndarray, @@ -349,6 +365,18 @@ def grouped_dense( Returns: A jnp.ndarray containing the result of the grouped linear operation """ + x_contracting_dims, kernel_contracting_dims = contracting_dims + x_contracting_dims = tex.sanitize_dims(x.ndim, x_contracting_dims) + kernel_contracting_dims = tex.sanitize_dims(kernel.ndim, kernel_contracting_dims) + contracting_dims = (x_contracting_dims, kernel_contracting_dims) + + restore_leading_ep_axis = False + if x.ndim == 3 and x.shape[0] == 1: + if x_contracting_dims == (x.ndim - 1,): + restore_leading_ep_axis = True + x = x.reshape(*x.shape[1:]) + contracting_dims = ((x.ndim - 1,), kernel_contracting_dims) + output = _grouped_dense( x, kernel, @@ -361,6 +389,8 @@ def grouped_dense( quantizer_set, kernel_fsdp_info, ) + if restore_leading_ep_axis: + output = output.reshape(1, *output.shape) return output @@ -406,12 +436,33 @@ def _grouped_dense_fwd_rule( ): use_bias = bias is not None + x_contracting_dims, k_contracting_dims = contracting_dims + local_kernel_shape = kernel.shape + kernel_was_gathered = False + bias_shape = bias.shape if 
use_bias else None + bias_fsdp_axis_idx = -1 + bias_was_gathered = False + kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info - kernel_fsdp_enabled = kernel_fsdp_mesh_axis is not None - assert not kernel_fsdp_enabled, "FSDP sharding for grouped_dense is not supported yet." - del kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx, kernel_fsdp_info, kernel_fsdp_enabled + if _is_manual_mesh_axis(kernel_fsdp_mesh_axis) and 0 < kernel_fsdp_axis_idx < kernel.ndim: + kernel = _all_gather_kernel(kernel, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx) + kernel_was_gathered = True + + if use_bias and kernel_fsdp_axis_idx not in k_contracting_dims: + bias_fsdp_axis_idx = _kernel_non_contracting_axis_to_bias_axis( + kernel_fsdp_axis_idx, k_contracting_dims + ) + mesh_axis_size = get_mesh_axis_size(kernel_fsdp_mesh_axis) + if ( + bias_fsdp_axis_idx is not None + and 0 < bias_fsdp_axis_idx < bias.ndim + and mesh_axis_size > 1 + and bias.shape[bias_fsdp_axis_idx] * mesh_axis_size + == kernel.shape[kernel_fsdp_axis_idx] + ): + bias = _all_gather_kernel(bias, kernel_fsdp_mesh_axis, bias_fsdp_axis_idx) + bias_was_gathered = True - x_contracting_dims, k_contracting_dims = contracting_dims flatten_axis_x = -len(x_contracting_dims) flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1 # +1 for G axis @@ -467,10 +518,14 @@ def _grouped_dense_fwd_rule( else ctx_kernel ), x.shape, - kernel.shape, + local_kernel_shape, use_bias, quantizer_set, flatten_axis_k, + kernel_was_gathered, + bias_shape, + bias_fsdp_axis_idx, + bias_was_gathered, ) return output, ctx @@ -478,9 +533,7 @@ def _grouped_dense_fwd_rule( def _grouped_dense_bwd_rule( contracting_dims, precision, preferred_element_type, group_offset, kernel_fsdp_info, ctx, grad ): - kernel_fsdp_mesh_axis, _ = kernel_fsdp_info - kernel_fsdp_enabled = kernel_fsdp_mesh_axis is not None - assert not kernel_fsdp_enabled, "FSDP sharding for grouped_dense is not supported yet." 
+ kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims @@ -493,6 +546,10 @@ def _grouped_dense_bwd_rule( use_bias, quantizer_set, flatten_axis_k, + kernel_was_gathered, + bias_shape, + bias_fsdp_axis_idx, + bias_was_gathered, ) = ctx # The 1 in range is for excluding the group dimension (shall we use the hardcoded results below?) @@ -530,6 +587,14 @@ def _grouped_dense_bwd_rule( preferred_element_type=preferred_element_type, group_offset=group_offset, ) + if _is_manual_mesh_axis(kernel_fsdp_mesh_axis) and not kernel_was_gathered: + if kernel_fsdp_axis_idx in fwd_k_contracting_dims: + dgrad_axis_idx = fwd_x_contracting_dims[ + fwd_k_contracting_dims.index(kernel_fsdp_axis_idx) + ] + dgrad = _all_gather_kernel(dgrad, kernel_fsdp_mesh_axis, dgrad_axis_idx) + else: + dgrad = jax.lax.psum(dgrad, kernel_fsdp_mesh_axis) wgrad = tex.grouped_gemm( wgrad_x_T, @@ -539,9 +604,34 @@ def _grouped_dense_bwd_rule( preferred_element_type=preferred_element_type, group_offset=group_offset, ) + if _is_manual_mesh_axis(kernel_fsdp_mesh_axis): + if kernel_was_gathered: + wgrad = _psum_scatter_kernel( + wgrad, kernel_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx + ) + elif kernel_fsdp_axis_idx in fwd_k_contracting_dims: + wgrad = _psum_scatter_kernel( + wgrad, kernel_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx + ) + else: + wgrad = jax.lax.psum(wgrad, kernel_fsdp_mesh_axis) + if kernel_fsdp_mesh_axis is not None: + wgrad_spec = [None] * len(kernel_shape) + ep_resource = None + try: + ep_resource = global_mesh_resource().ep_resource + except AssertionError: + pass + if len(wgrad_spec) > 0: + wgrad_spec[0] = ep_resource + if 0 <= kernel_fsdp_axis_idx < len(wgrad_spec): + wgrad_spec[kernel_fsdp_axis_idx] = kernel_fsdp_mesh_axis + wgrad = with_sharding_constraint(wgrad, PartitionSpec(*wgrad_spec)) group_sizes_grad = None dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None + if dbias is not 
None and _is_manual_mesh_axis(kernel_fsdp_mesh_axis) and bias_was_gathered: + dbias = _psum_scatter_kernel(dbias, bias_shape, kernel_fsdp_mesh_axis, bias_fsdp_axis_idx) return dgrad, wgrad, group_sizes_grad, dbias, quantizer_set diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 17c9a242f0..14783ecbe2 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -1471,7 +1471,7 @@ def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwa x, kernel, group_sizes=group_sizes, - contracting_dims=((1,), (1,)), + contracting_dims=((-1,), (1,)), quantizer_set=quantizer_set, ) return out diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 9b13412c14..8dffb71196 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -133,8 +133,8 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec): """ A wrapper function to jax.lax.with_sharding_constraint 1. Does nothing if mesh is empty. - 2. If all mesh axes are manual axes, replaces pspec with all Nones. - 3. Otherwise, strips only the manual axes. + 2. Keeps only auto axes in pspec. + 3. Returns x unchanged if no auto axes remain. """ if pspec is None: return x @@ -143,22 +143,21 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec): if mesh.empty: return x - # We want to exclude the axes that already used by shard_map and shard_map - # only sets those in the abstract_mesh, not the physical one - manual_axis_names = get_abstract_mesh().manual_axes + # with_sharding_constraint can only refer to auto axes. Explicit axes are + # already fixed by the active mesh, and manual axes are managed by shard_map. 
+ abstract_mesh = get_abstract_mesh() + auto_axis_names = set(abstract_mesh.auto_axes) # Multiple mesh axes can be mapped to a single shape axis, so we need to unpack and process tuples here too - def filter_manual_axes(name_or_tuple): + def filter_non_auto_axes(name_or_tuple): if isinstance(name_or_tuple, tuple): - out = tuple(n for n in name_or_tuple if n not in manual_axis_names) + out = tuple(n for n in name_or_tuple if n in auto_axis_names) if len(out) == 0: return None return out - if name_or_tuple in manual_axis_names: - return None - return name_or_tuple + return name_or_tuple if name_or_tuple in auto_axis_names else None - cleaned_axis_names = tuple(filter_manual_axes(name_or_tuple) for name_or_tuple in pspec) + cleaned_axis_names = tuple(filter_non_auto_axes(name_or_tuple) for name_or_tuple in pspec) if cleaned_axis_names == (None,) * len(cleaned_axis_names): return x @@ -330,6 +329,7 @@ class MeshResource: tp_resource: Axis name for tensor parallelism (hidden dimension sharding), default is None tpsp_resource: Axis name for tensor sequence parallelism (hidden and sequence sharding), default is None fsdp_resource: Axis name for full-sharded data parallelism, default is None + ep_resource: Axis name for expert parallelism (expert sharding), default is None pp_resource: Axis name for pipeline parallelism (layer sharding), default is None cp_resource: Axis name for context parallelism (sequence sharding), default is None """ @@ -338,6 +338,7 @@ class MeshResource: tp_resource: str = None tpsp_resource: str = None fsdp_resource: str = None + ep_resource: str = None pp_resource: str = None cp_resource: str = None @@ -364,7 +365,7 @@ def global_shard_guard(resource: MeshResource): _GLOBAL_MESH_RESOURCE = old_resources -def global_mesh_resource() -> MeshResource: +def global_mesh_resource(validate: bool = True) -> MeshResource: """Get the current global mesh resource configuration. 
Returns: @@ -375,7 +376,8 @@ def global_mesh_resource() -> MeshResource: " context. If you are not using multiple GPUs, you can use an empty MeshResource by" " wrapping your program in 'with global_shard_guard(MeshResource()):'" ) - _validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE) + if validate: + _validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE) return _GLOBAL_MESH_RESOURCE