diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index 51f72b1e56..439bae84d7 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -97,8 +97,8 @@ def __fx_repr__(self): def _make_qfactory(tag: str): """Return a qfactory that produces ToyQuantizer instances tagged with *tag*.""" - def qfactory(role: str): - return ToyQuantizer(tag=f"{tag}:{role}") + def qfactory(role): + return ToyQuantizer(tag=f"{tag}:{role.tensor_type}") return qfactory @@ -324,3 +324,139 @@ def fn(inp): out = compiled(inp) out.sum().backward() + + +@pytest.mark.parametrize( + "fp8_recipe", + [None, *_all_recipes], + ids=lambda r: "bf16" if r is None else type(r).__name__, +) +def test_te_linear_compiles(fp8_recipe): + """torch.compile(fullgraph=True) of ``te.Linear`` under every built-in + recipe (and the bf16-only baseline with no autocast). + + Exercises the custom-op path in + :mod:`transformer_engine.pytorch.dynamo`: forward goes through + ``_linear_compiled_op``, backward through the registered + ``transformer_engine::linear_backward`` op, and the dataclass + arg-objects are packed/unpacked via the bucket dispatch in + :mod:`transformer_engine.pytorch.dynamo`. + """ + if fp8_recipe is not None and not fp8_available: + pytest.skip(reason_for_no_fp8) + + dtype = torch.bfloat16 + device = "cuda" + + # FP8 GEMMs require leading dimensions divisible by 16; pick + # in/out features and batch comfortably above that minimum. + model = te.Linear(64, 32, params_dtype=dtype, device=device) + inp = torch.randn(32, 64, dtype=dtype, device=device, requires_grad=True) + + def fn(inp): + if fp8_recipe is None: + return model(inp) + with te.autocast(recipe=fp8_recipe): + return model(inp) + + torch._dynamo.reset() + compiled = torch.compile(fn, fullgraph=True) + + out = compiled(inp) + out.sum().backward() + assert out.shape == (32, 32) + assert inp.grad is not None + assert model.weight.grad is not None, "weight.grad missing" + assert model.weight.grad.shape == model.weight.shape, ( + f"weight.grad shape {tuple(model.weight.grad.shape)} != " + f"weight shape {tuple(model.weight.shape)}" + ) + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +def test_te_linear_compile_with_quantized_fp8_weight(): + """torch.compile should handle Linear weights initialized as FP8 tensors.""" + dtype = torch.bfloat16 + device = "cuda" + fp8_recipe = recipe.Float8CurrentScaling() + + with te.quantized_model_init(enabled=True, recipe=fp8_recipe): + model = te.Linear(64, 32, params_dtype=dtype, device=device) + + assert isinstance(model.weight, te.Float8Tensor) + inp = torch.randn(32, 64, dtype=dtype, device=device, requires_grad=True) + + def fn(inp): + with te.autocast(recipe=fp8_recipe): + return model(inp) + + torch._dynamo.reset() + compiled = torch.compile(fn, fullgraph=True) + + out = compiled(inp) + out.sum().backward() + assert out.shape == (32, 32) + assert inp.grad is not None + assert model.weight.grad is not None, "Float8Tensor weight.grad missing" + assert model.weight.grad.shape == model.weight.shape, ( + f"Float8Tensor weight.grad shape {tuple(model.weight.grad.shape)} != " + f"weight shape {tuple(model.weight.shape)}" + ) + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +def test_te_linear_compile_with_fp8_output(): + """torch.compile of ``te.Linear(..., fp8_output=True)``: forward returns + a :class:`Float8Tensor`. + + Exercises the output-rewrap path in + :mod:`transformer_engine.pytorch.dynamo`: the first user output is + declared ``Union[torch.Tensor, Float8Tensor]`` in ``output_annotations``, + and when an output quantizer is active the eager + fake paths must + rewrap the inner data tensors back into a ``Float8Tensor`` for the + user-facing slot. + + Backward through a subclass return value is a known PyTorch + ``torch.compile`` limitation (Dynamo / AOT autograd drop the + ``grad_fn`` on wrapper-subclass outputs of custom ops, so + ``out.sum().backward()`` errors with "element 0 of tensors does + not require grad and does not have a grad_fn"). The forward shape + + type assertions below are sufficient to exercise the rewrap; + grad-routing on FP8 outputs under compile is left as future work. + """ + dtype = torch.bfloat16 + device = "cuda" + fp8_recipe = recipe.Float8CurrentScaling() + + model = te.Linear(64, 32, params_dtype=dtype, device=device) + inp = torch.randn(32, 64, dtype=dtype, device=device, requires_grad=True) + + def fn(inp): + with te.autocast(recipe=fp8_recipe): + return model(inp, fp8_output=True) + + torch._dynamo.reset() + compiled = torch.compile(fn, fullgraph=True) + + out = compiled(inp) + assert isinstance(out, te.Float8Tensor), ( + f"expected Float8Tensor output, got {type(out).__name__}" + ) + assert out.shape == (32, 32) + # The compile-path reassembly rebuilds the wrapper via + # ``__tensor_unflatten__``, whose snapshot-free ``meta`` forces + # ``quantizer=None`` (a live ``ProcessGroup`` / amax-reduction group + # can't survive Dynamo guards). ``make_fake_empty`` stashes the live + # quantizer on the fake template and the reassembly helper restores it, + # so the output must keep a (non-``None``) quantizer rather than losing + # its amax-reduction group. + assert out._quantizer is not None, ( + "FP8 output lost its quantizer (and thus its amax-reduction group) " + "on the torch.compile path" + ) + # Dequantising outside the compiled region exercises the + # ``Float8Tensor`` machinery (scale + data + dtype all wired up + # by the rewrap) on the value returned from the compiled fn. + deq = out.dequantize() + assert deq.shape == (32, 32) + assert deq.dtype == dtype diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index b773a81d1b..97bea190ea 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -654,3 +654,5 @@ def _make_repr(self) -> str: f"qfactory={self.qfactory}, " f"backward_override={self.backward_override}" ) + + diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index 2aff4fd8e8..694de2d94e 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -41,6 +41,34 @@ tex.DType.kBFloat16: torch.bfloat16, } +# Map: TE DType *id* (Python int) -> TE DType enum. Used by +# :func:`canonicalize_te_dtype` to recover the pybind enum from its +# integer id without going through ``tex.DType(int)``, which Dynamo +# cannot trace (pybind11 enum constructor is opaque). +TE_DType_ID_To_TE = { + int(tex.DType.kByte): tex.DType.kByte, + int(tex.DType.kFloat8E4M3): tex.DType.kFloat8E4M3, + int(tex.DType.kFloat8E5M2): tex.DType.kFloat8E5M2, + int(tex.DType.kFloat4E2M1): tex.DType.kFloat4E2M1, + int(tex.DType.kInt32): tex.DType.kInt32, + int(tex.DType.kFloat32): tex.DType.kFloat32, + int(tex.DType.kFloat16): tex.DType.kFloat16, + int(tex.DType.kBFloat16): tex.DType.kBFloat16, +} + + +def canonicalize_te_dtype(dtype): + """Accept either a TE ``DType`` enum or its Python ``int`` id. + + Recipe state keeps dtype ids as Python ``int`` values for cheap, + trace-friendly comparisons. Quantizer objects, however, are passed to + TE's C++ bindings, which expect the pybind ``tex.DType`` enum. + """ + if isinstance(dtype, int): + return TE_DType_ID_To_TE[dtype] + return dtype + + # Cache enum -> int conversions to avoid repeated PyObject lookups. FP8FwdTensorIdx = SimpleNamespace( GEMM1_INPUT=int(tex.FP8FwdTensors.GEMM1_INPUT), diff --git a/transformer_engine/pytorch/dynamo.py b/transformer_engine/pytorch/dynamo.py new file mode 100644 index 0000000000..0c8ec0f378 --- /dev/null +++ b/transformer_engine/pytorch/dynamo.py @@ -0,0 +1,2045 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""torch.compile (Dynamo) integration for TransformerEngine modules.""" +from __future__ import annotations + +import dataclasses +from enum import Enum +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, + Union, + get_args, + get_origin, + get_type_hints, +) + +import torch + + +__all__ = [ + "OpaqueSimpleMetadata", + "_te_register_custom_op", +] + + +_TE_OP_NAMESPACE = "transformer_engine_compile" +_TE_LIB = torch.library.Library(_TE_OP_NAMESPACE, "FRAGMENT") + + +# Sentinel for ``None`` entries inside the op's flat ``Tensor[]`` return. +# Used by :func:`_te_register_custom_op` to support ``None`` outputs (e.g. +# an FP8 weight workspace returned only on the cache-miss path) on a +# non-nullable schema -- ``Tensor?[]`` returns are not picked up by +# ``torch.library.register_autograd``, so the registered backward never +# attaches a ``grad_fn`` to the op's outputs. +_NONE_SENTINEL_DTYPE = torch.uint8 + + +def _encode_none(t: Optional[torch.Tensor]) -> torch.Tensor: + """Replace ``None`` with a 0-element uint8 sentinel tensor.""" + if t is None: + return torch.empty(0, dtype=_NONE_SENTINEL_DTYPE) + return t + + +def _decode_none(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + """Inverse of :func:`_encode_none`.""" + if t is None: + return None + if t.numel() == 0 and t.dtype == _NONE_SENTINEL_DTYPE: + return None + return t + + +# --------------------------------------------------------------------------- # +# Output layout helpers +# --------------------------------------------------------------------------- # +# +# A user output of a TE custom op can be one of: +# * ``None`` -> 1 sentinel slot. +# * plain :class:`torch.Tensor` -> 1 slot. +# * wrapper-subclass tensor with +# ``__tensor_flatten__`` (e.g. +# :class:`Float8Tensor`) -> ``len(inner_names)`` slots. +# * pure-Python class with +# ``_torch_compile_flatten`` (e.g. +# :class:`Float8TensorStorage`) -> ``len(tensors)`` slots. +# +# At op-execution time, :func:`_format_fwd_result` splits each output via +# its flatten protocol and concatenates the inner plain tensors into the +# op's ``Tensor[]`` return. +# +# At call-site time (in :func:`forward_fn` / ``setup_context``), the layout +# for each output is read straight off the forward ``fake_impl``'s fake +# values, which double as reassembly templates (:func:`_template_slot_count` +# / :func:`_template_reassemble`). + + +def _contiguous_stride(shape: Sequence[int]) -> Tuple[int, ...]: + """Row-major contiguous stride for ``shape``. + + Used to fill in the ``stride`` field expected by + ``__tensor_unflatten__`` when rebuilding a wrapper subclass from a + fake template (:func:`_template_reassemble`). + """ + stride: List[int] = [1] * len(shape) + for i in range(len(shape) - 2, -1, -1): + stride[i] = stride[i + 1] * int(shape[i + 1]) + return tuple(stride) + + +# --------------------------------------------------------------------------- # +# Reassembly: rebuild user-facing objects from the op's flat ``Tensor[]``. +# --------------------------------------------------------------------------- # +# +# The forward ``fake_impl`` returns the op's outputs / saved tensors as fake +# values (``make_fake_empty`` wrappers / ``make_empty`` storages / +# ``torch.empty`` plains / aliased forward args / ``None``). Each fake value is +# itself a complete reassembly *template*: it says how many flat slots the real +# value occupies and how to rebuild it. :func:`_flatten_value_into` packs a +# value into slots; the two helpers below are its inverse (slot count + +# rebuild), reading straight off the fake template -- no separate descriptor +# object is materialised. + + +def _template_slot_count(template: Any, *, aliased: bool = False) -> int: + """Flat ``Tensor[]`` slots the real value for ``template`` occupies. + + ``aliased`` arg / ``None`` -> 1 (an :func:`_encode_none` sentinel); a + plain tensor -> 1; a ``make_fake_empty`` subclass -> ``len(inner_names)`` + (from its stamped plan); a storage -> ``len(_torch_compile_flatten())``. + """ + if aliased or template is None: + return 1 + if isinstance(template, torch.Tensor): + plan = getattr(template, _TE_COMPILE_UNFLATTEN_PLAN, None) + if plan is not None: + inner_names, _ = plan + return len(inner_names) + return 1 + flatten = getattr(template, "_torch_compile_flatten", None) + if flatten is not None: + _, _, tensors = flatten() + return len(tensors) + raise TypeError( + f"fake_impl produced an unsupported value of type {type(template).__name__}; " + "expected None / torch.Tensor (plain or make_fake_empty subclass) / " + "a storage exposing _torch_compile_flatten()." + ) + + +def _template_reassemble( + template: Any, + chunk: List[Any], + *, + with_autograd: bool = False, + aliased: bool = False, +) -> Any: + """Rebuild the user-facing value for ``template`` from real slots ``chunk``. + + Inverse of :func:`_flatten_value_into`, driven by the fake template: an + ``aliased`` arg / ``None`` -> ``None`` (aliases are resolved by the + caller's ``setup_context`` from the alias name); a plain tensor -> + ``chunk[0]``; a ``make_fake_empty`` subclass -> ``__tensor_unflatten__`` + (routed through :class:`_ToSubclassFn` when ``with_autograd`` so the wrap + stays on the autograd graph); a storage -> ``_torch_compile_do_unflatten``. + """ + if aliased or template is None: + return None + if isinstance(template, torch.Tensor): + plan = getattr(template, _TE_COMPILE_UNFLATTEN_PLAN, None) + if plan is not None: + inner_names, meta = plan + shape = tuple(template.shape) + stride = _contiguous_stride(shape) + if with_autograd: + result = _ToSubclassFn.apply( + type(template), inner_names, meta, shape, stride, *chunk + ) + else: + inner_dict = dict(zip(inner_names, chunk)) + result = type(template).__tensor_unflatten__( + inner_dict, meta, shape, stride + ) + # ``__tensor_unflatten__`` rebuilds with ``quantizer=None`` (the + # snapshot can't carry a live ``ProcessGroup``); restore the live + # quantizer the fake template stashed so the output keeps its + # amax-reduction group. + quantizer = getattr(template, "_te_compile_quantizer", None) + if quantizer is not None: + result._quantizer = quantizer + return result + return chunk[0] + flatten = getattr(template, "_torch_compile_flatten", None) + if flatten is not None: + meta, pg, _ = flatten() + real_tensors = [t for t in chunk if t is not None] + return type(template)._torch_compile_do_unflatten(meta, pg, real_tensors) + raise TypeError( + f"fake_impl produced an unsupported value of type {type(template).__name__}; " + "expected None / torch.Tensor (plain or make_fake_empty subclass) / " + "a storage exposing _torch_compile_flatten()." + ) + + +def _split_fwd_fake_result( + result: Tuple[Any, ...], +) -> Tuple[List[Any], List[Any], Dict[str, Any]]: + """Slice a forward ``fake_impl`` return into ``(user_fakes, saved_fakes, ctx_attrs)``. + + ``result`` has the eager-impl tuple shape ``(*user_outputs, + tensors_to_save, tensor_objects, ctx_attrs)``; the fake values double as + reassembly templates for :func:`_template_slot_count` / + :func:`_template_reassemble`. + """ + num_outputs = len(result) - _FWD_TRAILING_SLOTS + saved = result[num_outputs] + ctx_attrs = result[num_outputs + 2] + user_fakes = list(result[:num_outputs]) + saved_fakes = list(saved) if saved is not None else [] + ctx_attrs = dict(ctx_attrs) if ctx_attrs else {} + return user_fakes, saved_fakes, ctx_attrs + + +# --------------------------------------------------------------------------- # +# ``fake_impl`` consumers. +# +# A module describes its forward op outputs directly as a ``fwd_fake_impl`` +# that returns the same ``(*user_outputs, tensors_to_save, tensor_objects, +# ctx_attrs)`` tuple as the eager ``fwd_impl``, but built out of *fake* +# values: +# * ``quantizer.make_fake_empty(...)`` -- Dynamo-safe quantized wrapper. +# * ``quantizer.make_empty(...)`` -- quantized storage. +# * ``torch.empty(...)`` -- plain tensor. +# * the actual forward-arg tensor -- an aliased saved slot. +# * ``None`` -- absent output / saved slot. +# These fake values are the single source of truth for the op's layout: +# * ``forward_fn`` / ``setup_context`` reassemble the real flat ``Tensor[]`` +# using the fakes as templates (:func:`_template_slot_count` / +# :func:`_template_reassemble`), resolving aliased saved slots via +# :func:`_alias_name_for`. +# * :func:`_fwd_register_fake_from_fake_impl` wires the same callable as the +# op's ``register_fake`` (aliased saved slots nulled so the fake flat +# ``Tensor[]`` layout matches the eager impl, which writes ``None`` for +# aliases). +# The backward ``bwd_fake_impl`` is used directly as the backward +# ``register_fake`` -- backward grads never round-trip through the op +# payload, so no reassembly is needed. +# --------------------------------------------------------------------------- # + +# Attribute stamped on ``make_fake_empty`` outputs carrying the +# ``(inner_names, meta)`` plan needed to rebuild the subclass via +# ``__tensor_unflatten__``. The adapter reads it back (as a Dynamo +# constant) instead of calling ``value.__tensor_flatten__()`` in-trace: +# a tensor method returning non-tensors graph-breaks under fullgraph, +# whereas a plain attribute read is inlined. +_TE_COMPILE_UNFLATTEN_PLAN = "_te_compile_unflatten_plan" + + +def _fwd_arg_alias_pairs(fwd_obj: Any, field_names: Sequence[str]) -> List[Tuple[torch.Tensor, str]]: + """Collect ``(tensor field value, field name)`` for a fwd-arg object. + + ``field_names`` is precomputed outside the trace (reading + ``dataclasses.fields`` in-trace would graph-break on the class + ``mappingproxy``); attribute access by name is inlined. Used to + detect saved slots that alias a forward arg by identity (``is``). + """ + pairs: List[Tuple[torch.Tensor, str]] = [] + for name in field_names: + value = getattr(fwd_obj, name, None) + if isinstance(value, torch.Tensor): + pairs.append((value, name)) + return pairs + + +def _alias_name_for(value: Any, pairs: List[Tuple[torch.Tensor, str]]) -> Optional[str]: + """Return the forward-arg name ``value`` aliases (by ``is``), else ``None``.""" + for tensor, name in pairs: + if value is tensor: + return name + return None + + +def _fwd_register_fake_from_fake_impl( + fwd_fake_impl: Callable[[Any], Tuple[Any, ...]], + field_names: Sequence[str], +) -> Callable[[Any], Tuple[Any, ...]]: + """Adapt a forward ``fake_impl`` into a ``register_fake`` kernel. + + The user's ``fake_impl`` returns the *actual* forward-arg tensor for + aliased saved slots; the eager impl instead writes ``None`` there + (the value rides along as a ctx alias, not through the op payload). + Aliased saved slots are nulled here so the fake flat ``Tensor[]`` + layout stays identical to the eager impl. + """ + + def fwd_fake(fwd_obj: Any) -> Tuple[Any, ...]: + result = fwd_fake_impl(fwd_obj) + num_outputs = len(result) - _FWD_TRAILING_SLOTS + user_outputs = result[:num_outputs] + saved = result[num_outputs] + if saved is None: + tensors_to_save: Any = None + else: + pairs = _fwd_arg_alias_pairs(fwd_obj, field_names) + tensors_to_save = tuple( + None if _alias_name_for(v, pairs) is not None else v for v in saved + ) + return (*user_outputs, tensors_to_save, None, None) + + return fwd_fake + + +class _ToSubclassFn(torch.autograd.Function): + """Construct a wrapper-subclass tensor from its inner plain tensors, + preserving autograd flow through ``__tensor_unflatten__``. + + Non-tensor args (``cls``, ``inner_names``, ``meta``, ``outer_shape``, + ``outer_stride``) are static constants. Dynamo / AOT capture them as + constants on the autograd.Function node; the variadic ``inner_tensors`` + are real / fake graph tensors emitted by the underlying custom op. + """ + + @staticmethod + def forward(ctx, cls, inner_names, meta, outer_shape, outer_stride, *inner_tensors): + """Reassemble ``cls`` from ``inner_tensors`` via ``__tensor_unflatten__``.""" + ctx.inner_names = inner_names + ctx.num_inner = len(inner_tensors) + inner_dict = dict(zip(inner_names, inner_tensors)) + return cls.__tensor_unflatten__(inner_dict, meta, outer_shape, outer_stride) + + @staticmethod + def backward(ctx, grad_output): + """Route ``grad_output`` back to its per-inner-name slots.""" + # Under AOTAutograd, ``grad_output`` typically arrives flattened + # via the subclass machinery; under eager it may be the subclass + # itself. Both paths support ``__tensor_flatten__``-driven routing. + if grad_output is not None and hasattr(grad_output, "__tensor_flatten__"): + names_in_grad, _ = grad_output.__tensor_flatten__() + grad_by_name = {n: getattr(grad_output, n) for n in names_in_grad} + grads = tuple(grad_by_name.get(n) for n in ctx.inner_names) + else: + # Fallback: route the lone grad to the first inner slot; the + # remaining slots (typically derived quantities like scale) + # get ``None``. + grads = (grad_output,) + (None,) * (ctx.num_inner - 1) + # First five returns correspond to the five leading non-tensor args + # to ``forward`` (``cls``, ``inner_names``, ``meta``, ``shape``, + # ``stride``); none of them carries a gradient. + return (None, None, None, None, None) + grads + + +# --------------------------------------------------------------------------- # +# OpaqueSimpleMetadata +# --------------------------------------------------------------------------- # + +class OpaqueSimpleMetadata: + """Opaque value-type bundle of simple Python values. + + Wraps a ``{name: value}`` dict so that many small non-Tensor arguments + of a TE custom op can be passed as a single op input. Registered as a + torch.compile *value* opaque type, meaning Dynamo specializes the + traced graph on the bundle's contents: ``__eq__`` installs a guard, + and any change to a wrapped value triggers a recompile. + + Allowed value types: primitives in :attr:`PRIMITIVE_TYPES`, + :class:`enum.Enum`, :class:`torch.Size`, plus arbitrarily nested + tuples/lists thereof. + """ + + # Primitive Python types we are willing to bundle into a single op + # input. The bundle is registered as a torch.compile *value* opaque + # type, so its contents must be hashable, comparable for equality, + # and round-trippable through ``__fx_repr__``. + PRIMITIVE_TYPES: Tuple[type, ...] = ( + type(None), + bool, + int, + float, + str, + torch.dtype, + torch.device, + ) + + @classmethod + def _is_opaque_value(cls, value: Any) -> bool: + """Whether ``value``'s class is registered as a value-opaque type. + """ + return _is_opaque_value_type(type(value)) + + @classmethod + def is_simple_value(cls, value: Any) -> bool: + """Whether ``value`` is allowed inside an instance. + + Accepts simple primitives (see :attr:`PRIMITIVE_TYPES`), + :class:`enum.Enum`, :class:`torch.Size`, instances of any class + registered as a torch.compile *value*-opaque type (the latter + already supplies ``__eq__`` / ``__hash__`` / ``__fx_repr__`` as + a registration prerequisite), and arbitrarily nested + tuples / lists thereof. + """ + if isinstance(value, cls.PRIMITIVE_TYPES): + return True + if isinstance(value, Enum): + return True + if isinstance(value, torch.Size): + return True + if cls._is_opaque_value(value): + return True + if isinstance(value, (list, tuple)): + return all(cls.is_simple_value(v) for v in value) + return False + + @classmethod + def _to_hashable(cls, value: Any) -> Any: + """Convert a simple value into something hashable (lists -> tuples).""" + if isinstance(value, (list, tuple, torch.Size)): + return tuple(cls._to_hashable(v) for v in value) + # Opaque-value instances already supply ``__hash__`` (required + # by registration) so they can stay as-is. + return value + + @classmethod + def _fmt_simple(cls, value: Any) -> str: + """Repr for a simple value, evaluable in a context with ``torch`` globals.""" + if isinstance(value, torch.dtype): + return f"__import__('torch').{str(value).split('.')[-1]}" + if isinstance(value, torch.device): + return f"__import__('torch').device({str(value)!r})" + if isinstance(value, torch.Size): + return f"__import__('torch').Size({list(value)!r})" + if isinstance(value, Enum): + return f"{type(value).__name__}.{value.name}" + if isinstance(value, list): + return "[" + ", ".join(cls._fmt_simple(v) for v in value) + "]" + if isinstance(value, tuple): + body = ", ".join(cls._fmt_simple(v) for v in value) + return f"({body},)" if len(value) == 1 else f"({body})" + if cls._is_opaque_value(value): + # Opaque-value types declare their FX reconstruction via + # ``__fx_repr__``; just splice their expression in here. + return value.__fx_repr__()[0] + return repr(value) + + def __init__(self, data: Optional[Dict[str, Any]] = None) -> None: + data = dict(data) if data else {} + cls = type(self) + for k, v in data.items(): + if not cls.is_simple_value(v): + raise TypeError( + f"OpaqueSimpleMetadata field '{k}' has unsupported " + f"type {type(v).__name__}; only simple primitives " + f"({', '.join(t.__name__ for t in cls.PRIMITIVE_TYPES)}, " + f"Enum, torch.Size, registered torch.compile value-" + f"opaque types) and tuples/lists thereof are allowed." + ) + self._data: Dict[str, Any] = data + self._frozen: Tuple[Tuple[str, Any], ...] = tuple( + (k, cls._to_hashable(v)) for k, v in sorted(data.items()) + ) + + def __getitem__(self, key: str) -> Any: + return self._data[key] + + def __getattr__(self, name: str) -> Any: + # Only called when normal attribute lookup fails, so ``_data`` / + # ``_frozen`` won't recurse here once set in ``__init__``. + try: + return self._data[name] + except KeyError as e: + raise AttributeError(name) from e + + def __contains__(self, key: str) -> bool: + return key in self._data + + def keys(self) -> List[str]: + return list(self._data.keys()) + + def values(self) -> List[Any]: + return list(self._data.values()) + + def items(self) -> List[Tuple[str, Any]]: + return list(self._data.items()) + + def get(self, key: str, default: Any = None) -> Any: + return self._data.get(key, default) + + def as_dict(self) -> Dict[str, Any]: + return dict(self._data) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, OpaqueSimpleMetadata): + return NotImplemented + return self._frozen == other._frozen + + def __hash__(self) -> int: + return hash(self._frozen) + + def __fx_repr__(self) -> Tuple[str, Dict[str, Any]]: + cls = type(self) + items = ", ".join( + f"{k!r}: {cls._fmt_simple(v)}" for k, v in self._data.items() + ) + # Collect every type referenced by a nested opaque-value's + # ``__fx_repr__`` so the FX codegen can resolve those names. + globals_: Dict[str, Any] = { + "OpaqueSimpleMetadata": OpaqueSimpleMetadata, + } + + def _collect(value: Any) -> None: + if isinstance(value, (list, tuple)): + for v in value: + _collect(v) + return + # Skip plain Python / torch primitives up-front: they're + # rendered as literals by ``_fmt_simple`` and need no + # globals entry. + if isinstance(value, cls.PRIMITIVE_TYPES): + return + if isinstance(value, torch.Size): + return + if isinstance(value, Enum): + # ``_fmt_simple`` emits ``EnumName.MEMBER``; the Enum + # class must be in scope when the source string is + # later ``exec``d (e.g. by ``GraphModule.print_readable`` + # or by Inductor's runtime wrapper). + t = type(value) + globals_[t.__name__] = t + return + if cls._is_opaque_value(value): + _, extra = value.__fx_repr__() + globals_.update(extra) + + for v in self._data.values(): + _collect(v) + return (f"OpaqueSimpleMetadata({{{items}}})", globals_) + + def __repr__(self) -> str: + # ``__repr__`` is on hot diagnostic paths (Inductor error + # formatters, FX node printers, ...) and must never raise: + # treating any embedded value's ``repr`` failure as a soft + # placeholder keeps those error reporters from masking the + # actual root-cause exception with a crash inside our repr. + parts: List[str] = [] + for k, v in self._data.items(): + try: + v_repr = repr(v) + except Exception as e: # pylint: disable=broad-except + v_repr = f"<{type(v).__name__}: repr failed: {e!s}>" + parts.append(f"{k!r}: {v_repr}") + return f"OpaqueSimpleMetadata({{{', '.join(parts)}}})" + + +# Register OpaqueSimpleMetadata as a torch.compile value-opaque type, and +# resolve the schema name of ``torch.distributed.ProcessGroup`` (registered +# upstream as a *reference* opaque type via +# ``torch.distributed.device_mesh._register_distributed_opaque_types``). +# Both are done at module import so that any TE op declared via +# ``_te_register_custom_op`` can immediately reference them in its schema. +# Older PyTorch versions without these APIs are tolerated: the eager path +# keeps working, only torch.compile tracing of TE custom ops is unavailable. +try: + from torch._library.opaque_object import ( + get_opaque_type_name, + is_opaque_value_type as _is_opaque_value_type, + register_opaque_type, + ) + + register_opaque_type(OpaqueSimpleMetadata, typ="value") + _OPAQUE_SIMPLE_META_TYPE_NAME: Optional[str] = get_opaque_type_name( + OpaqueSimpleMetadata + ) + + _PROCESS_GROUP_TYPE_NAME: Optional[str] = None + try: + from torch.distributed import ProcessGroup + from torch.distributed.device_mesh import ( + _register_distributed_opaque_types, + ) + + _register_distributed_opaque_types() + _PROCESS_GROUP_TYPE_NAME = get_opaque_type_name(ProcessGroup) + except Exception: # pragma: no cover - distributed not built / disabled + _PROCESS_GROUP_TYPE_NAME = None +except Exception: # pragma: no cover - older torch without opaque_object + _is_opaque_value_type = None + _OPAQUE_SIMPLE_META_TYPE_NAME = None + _PROCESS_GROUP_TYPE_NAME = None + + +# --------------------------------------------------------------------------- # +# Field buckets +# --------------------------------------------------------------------------- # +# +# Each dataclass field is mapped to exactly one bucket that owns its +# schema slots and the pack/unpack logic between the dataclass attribute +# and the flat ``torch.library`` view. Concrete bucket types are defined +# below; the per-class docstrings describe what each one matches. + + +def _strip_optional(annot: Any) -> Tuple[Any, bool]: + """If ``annot`` is ``Optional[X]`` return ``(X, True)``; else ``(annot, False)``. + + Shared by all bucket matchers below. + """ + if get_origin(annot) is Union: + args = get_args(annot) + if type(None) in args: + non_none = [a for a in args if a is not type(None)] + if len(non_none) == 1: + return non_none[0], True + return annot, False + + +class _Bucket: + """Per-field handler for translating between a dataclass field and the + flat ``{slot_name: slot_value}`` view consumed by ``torch.library``. + + Each concrete bucket owns: + + * a :meth:`try_build` classmethod that decides whether a given + ``(name, annotation)`` pair belongs to this bucket (returning an + instance, or ``None`` to defer to the next bucket); + * the runtime :meth:`schema_slots` / :meth:`pack` / :meth:`unpack` + logic for that field. + + :class:`_SimpleBundleBucket` is the exception: it aggregates many + simple-typed fields into a single op input, so it does not implement + ``try_build``. It exposes :meth:`matches_field` for the per-field + membership test, and is constructed once at the end of dispatch + with the collected names. + """ + + @classmethod + def try_build(cls, name: str, annot: Any) -> Optional["_Bucket"]: + """Return an instance handling ``(name, annot)``, or ``None``.""" + raise NotImplementedError + + def schema_slots(self) -> List[Tuple[str, str]]: + """Return ``[(slot_name, schema_type_str), ...]`` for this field.""" + raise NotImplementedError + + def pack(self, owner: Any) -> List[Tuple[str, Any]]: + """Return ``[(slot_name, value), ...]`` extracted from ``owner``.""" + raise NotImplementedError + + def unpack(self, args: Dict[str, Any], kwargs: Dict[str, Any]) -> None: + """Read this field's slots from ``args`` and write the + reconstructed dataclass attribute(s) into ``kwargs``.""" + raise NotImplementedError + + +class _MetaPGTensorsBucket(_Bucket): + """Shared three-slot bucket emitting ``__meta`` / + ``__pg`` / ``__tensors``. + + Used by every field whose value must be carried as the triple + ``(OpaqueSimpleMetadata, ProcessGroup?, Tensor[])`` -- today this + covers ``Tensor | QuantizedTensorStorage`` unions (see + :class:`_UniversalTensorBucket`) and ``Quantizer`` instances + (see :class:`_FlattenableBucket`). Concrete subclasses + implement :meth:`_pack_value` / :meth:`_unpack_value` for their + flatten/unflatten protocol; the rest of the bucket contract is + identical and lives here. + """ + + SUFFIX_META = "__meta" + SUFFIX_PG = "__pg" + SUFFIX_TENSORS = "__tensors" + + def __init__(self, name: str) -> None: + if _OPAQUE_SIMPLE_META_TYPE_NAME is None or _PROCESS_GROUP_TYPE_NAME is None: + raise RuntimeError( + f"Field {name!r} requires both OpaqueSimpleMetadata and " + "torch.distributed.ProcessGroup to be registered as " + "torch._library opaque types; one or both are " + "unavailable in this PyTorch build." + ) + self.name = name + + def _slot_meta(self) -> str: + return self.name + self.SUFFIX_META + + def _slot_pg(self) -> str: + return self.name + self.SUFFIX_PG + + def _slot_tensors(self) -> str: + return self.name + self.SUFFIX_TENSORS + + def schema_slots(self) -> List[Tuple[str, str]]: + return [ + (self._slot_meta(), _OPAQUE_SIMPLE_META_TYPE_NAME), + (self._slot_pg(), _PROCESS_GROUP_TYPE_NAME + "?"), + (self._slot_tensors(), "Tensor[]"), + ] + + def pack(self, owner: Any) -> List[Tuple[str, Any]]: + value = getattr(owner, self.name) + meta, pg, tensors = self._pack_value(value) + return [ + (self._slot_meta(), meta), + (self._slot_pg(), pg), + (self._slot_tensors(), list(tensors)), + ] + + def unpack(self, args: Dict[str, Any], kwargs: Dict[str, Any]) -> None: + kwargs[self.name] = self._unpack_value( + args[self._slot_meta()], + args[self._slot_pg()], + args[self._slot_tensors()], + ) + + def _pack_value( + self, value: Any + ) -> Tuple[Any, Any, List[torch.Tensor]]: + """Flatten one field value into ``(meta, pg, tensors)``.""" + raise NotImplementedError + + def _unpack_value( + self, meta: Any, pg: Any, tensors: List[torch.Tensor] + ) -> Any: + """Inverse of :meth:`_pack_value`.""" + raise NotImplementedError + + +class _UniversalTensorBucket(_Bucket): + """``Tensor | QuantizedTensorStorage`` (also subclass-tensor) field. + + Emits four schema slots per field, regardless of the runtime value: + + * ```` (``Tensor?``) -- plain tensor / subclass tensor + (e.g. :class:`Float8Tensor`) + passes through here untouched. + ``None`` for the storage path. + * ``__tensors`` (``Tensor[]``) -- flat inner tensors when the + value was carried through a + flatten protocol (storage at + pack-time, or a subclass that + was dispatched into flat form + by ``register_torch_dispatch`` + on the outer op). + * ``__pg`` (``ProcessGroup?``) -- distributed handle attached + to the flatten metadata, if + any. + * ``__meta`` (``OpaqueSimpleMetadata``) -- everything else: + the storage / subclass meta + dict, plus a ``__kind__`` + marker telling the unpacker + which slot to look at: + ``"none"``, ``"tensor"``, or + ``"storage"`` (the latter + covers both storage and any + already-flattened subclass). + + Storage values are flattened at ``_pack`` time (callsite). Plain + tensors -- including subclass instances -- are passed unchanged + through ````; under ``torch.compile`` an outer-op + ``register_torch_dispatch`` rule turns each registered subclass + into the storage layout *between* outer and inner op so the + autograd graph stays attached to the user-facing wrapper. + """ + + SUFFIX_TENSORS = "__tensors" + SUFFIX_PG = "__pg" + SUFFIX_META = "__meta" + + KIND_KEY = "__kind__" + KIND_NONE = "none" + KIND_TENSOR = "tensor" + KIND_STORAGE = "storage" + + def __init__(self, name: str) -> None: + if _OPAQUE_SIMPLE_META_TYPE_NAME is None or _PROCESS_GROUP_TYPE_NAME is None: + raise RuntimeError( + f"Field {name!r} requires both OpaqueSimpleMetadata and " + "torch.distributed.ProcessGroup to be registered as " + "torch._library opaque types; one or both are " + "unavailable in this PyTorch build." + ) + self.name = name + + def slot_name(self) -> str: + return self.name + + def slot_tensors(self) -> str: + return self.name + self.SUFFIX_TENSORS + + def slot_pg(self) -> str: + return self.name + self.SUFFIX_PG + + def slot_meta(self) -> str: + return self.name + self.SUFFIX_META + + def schema_slots(self) -> List[Tuple[str, str]]: + return [ + (self.slot_name(), "Tensor?"), + (self.slot_tensors(), "Tensor[]"), + (self.slot_pg(), _PROCESS_GROUP_TYPE_NAME + "?"), + (self.slot_meta(), _OPAQUE_SIMPLE_META_TYPE_NAME), + ] + + @staticmethod + def _is_tensor_storage_union(annot: Any) -> bool: + origin = get_origin(annot) + if origin is not Union: + return False + members = [a for a in get_args(annot) if a is not type(None)] + if torch.Tensor not in members: + return False + qts = _quantized_tensor_storage_cls() + if qts is None: + return False + return any( + isinstance(member, type) and issubclass(member, qts) + for member in members + ) + + @classmethod + def try_build(cls, name: str, annot: Any) -> Optional["_UniversalTensorBucket"]: + if cls._is_tensor_storage_union(annot): + return cls(name) + return None + + def pack(self, owner: Any) -> List[Tuple[str, Any]]: + value = getattr(owner, self.name) + if value is None: + return [ + (self.slot_name(), None), + (self.slot_tensors(), []), + (self.slot_pg(), None), + (self.slot_meta(), OpaqueSimpleMetadata({self.KIND_KEY: self.KIND_NONE})), + ] + # Plain ``torch.Tensor`` *and* any subclass (e.g. ``Float8Tensor``) + # hit this branch first -- the wrapper is forwarded untouched + # through the ``Tensor?`` slot so the autograd graph stays + # attached to the user-facing tensor object. Subclass-specific + # flattening (if any) happens later inside the outer op's + # ``register_torch_dispatch`` rule. + if isinstance(value, torch.Tensor): + return [ + (self.slot_name(), value), + (self.slot_tensors(), []), + (self.slot_pg(), None), + (self.slot_meta(), OpaqueSimpleMetadata({self.KIND_KEY: self.KIND_TENSOR})), + ] + qts = _quantized_tensor_storage_cls() + if qts is not None and isinstance(value, qts): + meta, pg, tensors = value._torch_compile_flatten() + # Stamp the storage-flatten meta with our kind marker so the + # unpacker can route by ``__kind__`` alone. + meta._data[self.KIND_KEY] = self.KIND_STORAGE + return [ + (self.slot_name(), None), + (self.slot_tensors(), list(tensors)), + (self.slot_pg(), pg), + (self.slot_meta(), meta), + ] + raise TypeError( + f"field {self.name!r} expected None, torch.Tensor, or " + f"QuantizedTensorStorage, got {type(value).__name__}" + ) + + def unpack(self, args: Dict[str, Any], kwargs: Dict[str, Any]) -> None: + meta = args[self.slot_meta()] + kind = meta.get(self.KIND_KEY) + if kind == self.KIND_NONE: + kwargs[self.name] = None + return + if kind == self.KIND_TENSOR: + kwargs[self.name] = args[self.slot_name()] + return + qts = _quantized_tensor_storage_cls() + kwargs[self.name] = qts._torch_compile_unflatten( + meta, args[self.slot_pg()], args[self.slot_tensors()] + ) + + +class _TensorBucket(_Bucket): + """``Tensor`` / ``Optional[Tensor]`` -> single ``Tensor`` / ``Tensor?`` slot.""" + + def __init__(self, name: str, is_optional: bool) -> None: + self.name = name + self.type_str = "Tensor?" if is_optional else "Tensor" + + @classmethod + def try_build(cls, name: str, annot: Any) -> Optional["_TensorBucket"]: + stripped, is_optional = _strip_optional(annot) + if stripped is torch.Tensor: + return cls(name, is_optional) + return None + + def schema_slots(self) -> List[Tuple[str, str]]: + return [(self.name, self.type_str)] + + def pack(self, owner: Any) -> List[Tuple[str, Any]]: + return [(self.name, getattr(owner, self.name))] + + def unpack(self, args: Dict[str, Any], kwargs: Dict[str, Any]) -> None: + kwargs[self.name] = args[self.name] + + +class _ProcessGroupBucket(_Bucket): + """``ProcessGroup`` / ``Optional[ProcessGroup]`` -> one direct opaque-ref slot. + + PG is registered upstream (in ``torch.distributed.device_mesh``) as + a value-opaque type, so torch.library carries it without help. + """ + + def __init__(self, name: str, is_optional: bool) -> None: + if _PROCESS_GROUP_TYPE_NAME is None: + raise RuntimeError( + f"ProcessGroup field {name!r} requires torch.distributed " + "and the opaque-type registration in " + "torch.distributed.device_mesh; neither is available in " + "this PyTorch build." + ) + self.name = name + self.type_str = _PROCESS_GROUP_TYPE_NAME + ("?" if is_optional else "") + + @classmethod + def try_build(cls, name: str, annot: Any) -> Optional["_ProcessGroupBucket"]: + stripped, is_optional = _strip_optional(annot) + if not isinstance(stripped, type): + return None + try: + from torch.distributed import ProcessGroup + except Exception: # pragma: no cover - distributed not built + return None + if not issubclass(stripped, ProcessGroup): + return None + return cls(name, is_optional) + + def schema_slots(self) -> List[Tuple[str, str]]: + return [(self.name, self.type_str)] + + def pack(self, owner: Any) -> List[Tuple[str, Any]]: + return [(self.name, getattr(owner, self.name))] + + def unpack(self, args: Dict[str, Any], kwargs: Dict[str, Any]) -> None: + kwargs[self.name] = args[self.name] + + +# Cached resolutions of TE types that ``dynamo`` references lazily to +# avoid import cycles (they live in modules that themselves import this +# one). Each ``_*_cls`` getter resolves its target once and reuses the +# result on every subsequent call; the values are kept module-level +# rather than baked into bucket instances so the cache survives across +# different dataclass registrations. +_QTS_REF: Optional[type] = None +_QUANTIZER_REF: Optional[type] = None + + +def _quantized_tensor_storage_cls() -> Optional[type]: + """Lazy-resolve :class:`QuantizedTensorStorage`; ``None`` if unavailable.""" + global _QTS_REF + if _QTS_REF is None: + try: + from transformer_engine.pytorch.quantized_tensor import ( + QuantizedTensorStorage, + ) + + _QTS_REF = QuantizedTensorStorage + except Exception: # pragma: no cover - partial init + return None + return _QTS_REF + + +def _quantizer_cls() -> Optional[type]: + """Lazy-resolve :class:`Quantizer`; ``None`` if unavailable.""" + global _QUANTIZER_REF + if _QUANTIZER_REF is None: + try: + from transformer_engine.pytorch.quantized_tensor import Quantizer + + _QUANTIZER_REF = Quantizer + except Exception: # pragma: no cover - partial init + return None + return _QUANTIZER_REF + + +def _flattenable_bases() -> Tuple[type, ...]: + """Return the list of base classes whose subclasses are routed + through :class:`_FlattenableBucket`. + + A "flattenable" type implements the duck-typed pair + + * instance method ``_flatten() -> (OpaqueSimpleMetadata, ref, list[Tensor])`` + * classmethod ``_unflatten(meta, ref, tensors)`` (dispatches by an + identifier stamped into ``meta``). + """ + return tuple( + cls + for cls in (_quantizer_cls(), _quantized_tensor_storage_cls()) + if cls is not None + ) + + +class _FlattenableBucket(_MetaPGTensorsBucket): + """Field whose type implements the ``_flatten`` / ``_unflatten`` + protocol (see :func:`_flattenable_bases`). Used today for + :class:`~transformer_engine.pytorch.quantized_tensor.Quantizer` and + :class:`~transformer_engine.pytorch.quantized_tensor.QuantizedTensorStorage`. + """ + + # Stored under ``_qcls`` in the metadata bundle to encode ``None`` + # without making any of the three slots nullable. + NONE_MARKER_KEY = "_qcls" + NONE_MARKER_VAL = "" + + def __init__(self, name: str, base_cls: type) -> None: + super().__init__(name) + self.base_cls = base_cls + + @classmethod + def try_build(cls, name: str, annot: Any) -> Optional["_FlattenableBucket"]: + stripped, _ = _strip_optional(annot) + if not isinstance(stripped, type): + return None + for base in _flattenable_bases(): + if issubclass(stripped, base): + return cls(name, base) + return None + + def _pack_value(self, value: Any) -> Tuple[Any, Any, List[torch.Tensor]]: + if value is None: + return ( + OpaqueSimpleMetadata({self.NONE_MARKER_KEY: self.NONE_MARKER_VAL}), + None, + [], + ) + if hasattr(value, "_flatten"): + return value._flatten() + return value._torch_compile_flatten() + + def _unpack_value( + self, meta: Any, pg: Any, tensors: List[torch.Tensor] + ) -> Any: + if meta.get(self.NONE_MARKER_KEY) == self.NONE_MARKER_VAL: + return None + if hasattr(self.base_cls, "_unflatten"): + return self.base_cls._unflatten(meta, pg, tensors) + return self.base_cls._torch_compile_unflatten(meta, pg, tensors) + + +class _SimpleBundleBucket(_Bucket): + """Aggregator: bundles every simple-typed field of the dataclass + into a single :class:`OpaqueSimpleMetadata` slot. + + Does not implement :meth:`try_build` because membership is decided + per-field via :meth:`matches_field`, with construction deferred + until all simple field names are collected. + """ + + SLOT = "_simple_meta" + + def __init__(self, names: List[str]) -> None: + if _OPAQUE_SIMPLE_META_TYPE_NAME is None: + raise RuntimeError( + "OpaqueSimpleMetadata could not be registered with " + "torch._library.opaque_object; cannot bundle simple fields " + f"{names}. Upgrade PyTorch or drop the simple fields." + ) + self.names = list(names) + + @classmethod + def matches_field(cls, annot: Any) -> bool: + """Whether ``annot`` (Optional-aware, recursive on tuple/list) is + bundled-simple, i.e. eligible for this aggregator. + + Accepts simple primitives, :class:`enum.Enum`, :class:`torch.Size`, + any class registered as a torch.compile *value*-opaque type, and + nested tuples / lists thereof. + """ + annot, _ = _strip_optional(annot) + if annot in OpaqueSimpleMetadata.PRIMITIVE_TYPES: + return True + if isinstance(annot, type) and issubclass(annot, Enum): + return True + if annot is torch.Size: + return True + # Any registered value-opaque class is hashable / FX-reproducible + # and therefore safe to embed in the OpaqueSimpleMetadata bundle. + if isinstance(annot, type) and _is_opaque_value_type(annot): + return True + origin = get_origin(annot) + if origin in (tuple, list): + # Inner args may contain Ellipsis (e.g. ``Tuple[int, ...]``); + # only require the *concrete* inner annotations to be simple. + inner = [a for a in get_args(annot) if a is not Ellipsis] + return bool(inner) and all(cls.matches_field(a) for a in inner) + return False + + def schema_slots(self) -> List[Tuple[str, str]]: + return [(self.SLOT, _OPAQUE_SIMPLE_META_TYPE_NAME)] + + def pack(self, owner: Any) -> List[Tuple[str, Any]]: + bundle = OpaqueSimpleMetadata({n: getattr(owner, n) for n in self.names}) + return [(self.SLOT, bundle)] + + def unpack(self, args: Dict[str, Any], kwargs: Dict[str, Any]) -> None: + if self.SLOT not in args: + return + meta = args[self.SLOT] + for n in self.names: + kwargs[n] = meta[n] + + +class _UnknownBucket(_Bucket): + """Fallback for fields whose annotation no other bucket claims. + Emits no schema slot; pack rejects non-trivial values to avoid silent + data loss; unpack restores the field as ``None``. + + A "trivial" value is one that carries no observable information -- + ``None`` itself or a sequence of all-``None`` entries (e.g. a + ``tensor_objects`` payload from :func:`prepare_for_saving` over a + bag of plain bf16 tensors). Such values are dropped on the way into + the op and reconstructed from companion fields (``saved_tensors``, + quantizer metadata, ...) on the way out. + + Constructed directly by :func:`_get_buckets` (it has no + annotation-based ``try_build`` -- it's the explicit "no match" case). + """ + + def __init__(self, name: str, owner_cls_name: str) -> None: + self.name = name + self.owner_cls_name = owner_cls_name + + @staticmethod + def _is_trivial(value: Any) -> bool: + if value is None: + return True + if isinstance(value, (list, tuple)): + return all(v is None for v in value) + return False + + def schema_slots(self) -> List[Tuple[str, str]]: + return [] + + def pack(self, owner: Any) -> List[Tuple[str, Any]]: + value = getattr(owner, self.name, None) + if not self._is_trivial(value): + raise TypeError( + f"{self.owner_cls_name} field {self.name!r} has a type not " + "supported by torch.compile (not Tensor, simple, " + "ProcessGroup, or Quantizer) and carries a non-trivial " + "value; add a matching bucket in dynamo.py to handle it." + ) + return [] + + def unpack(self, args: Dict[str, Any], kwargs: Dict[str, Any]) -> None: + kwargs[self.name] = None + + +# Buckets, in priority order, that own ``try_build`` for a single field. +_FIELD_BUCKETS: Tuple[type, ...] = ( + _UniversalTensorBucket, + _TensorBucket, + _ProcessGroupBucket, + _FlattenableBucket, +) + + +# --------------------------------------------------------------------------- # +# Dataclass <-> torch.library plumbing +# --------------------------------------------------------------------------- # +# +# The helpers below translate a plain ``@dataclass`` argument container +# into the flat ``{slot_name: slot_value}`` view ``torch.library`` works +# with. Each dataclass field is dispatched (by annotation) to one +# :class:`_Bucket`; schema / pack / unpack are then loops over that list. + + +def _resolved_field_annotations(cls: type) -> List[Tuple[str, Any]]: + """Return ``[(field_name, resolved_type), ...]`` for a dataclass.""" + if not dataclasses.is_dataclass(cls): + raise TypeError( + f"{cls.__name__} must be a @dataclass to be used as a TE " + f"custom-op argument container." + ) + # ``get_type_hints`` resolves forward references and PEP 563 + # ``from __future__ import annotations`` strings. + try: + hints = get_type_hints(cls) + except Exception: + hints = {} + return [(f.name, hints.get(f.name, f.type)) for f in dataclasses.fields(cls)] + + +def _get_buckets(cls: type) -> List[_Bucket]: + """Build the bucket list for a dataclass from its field annotations. + + Dispatch order per field: try each bucket in :data:`_FIELD_BUCKETS` + (Tensor, ProcessGroup, Quantizer); if none claims the field, route + it to :class:`_SimpleBundleBucket` if its annotation is bundle-able, + else to :class:`_UnknownBucket`. + + Intentionally **not** cached on ``cls``. Caching there (e.g. by + writing ``cls.__te_buckets__``) tickles Dynamo: subsequent reads of + ``cls.__dict__`` from a compiled function trigger + "mappingproxy affected by dictionary mutation" graph breaks. Hot + paths must instead capture the bucket list once at op registration + time and pass it explicitly to :func:`_pack` / :func:`_unpack`. + """ + buckets: List[_Bucket] = [] + simple_names: List[str] = [] + for name, annot in _resolved_field_annotations(cls): + built: Optional[_Bucket] = None + for bucket_cls in _FIELD_BUCKETS: + built = bucket_cls.try_build(name, annot) + if built is not None: + break + if built is not None: + buckets.append(built) + elif _SimpleBundleBucket.matches_field(annot): + simple_names.append(name) + else: + buckets.append(_UnknownBucket(name, cls.__name__)) + if simple_names: + buckets.append(_SimpleBundleBucket(simple_names)) + return buckets + + +def _build_schema(buckets: List[_Bucket]) -> Tuple[str, List[str]]: + """Return ``(schema_str, slot_names)`` for a precomputed bucket list. + + ``schema_str`` is the parenthesised argument list (e.g. + ``"(Tensor x, Tensor? y)"``) that ``torch.library.Library.define`` + appends to the op name; ``slot_names`` is the ordered list of slot + keys produced by :func:`_pack`, used to flatten/unflatten the + keyword dict into the positional call. + """ + spec = [slot for b in buckets for slot in b.schema_slots()] + names = [name for name, _ in spec] + schema_str = "(" + ", ".join(f"{type_str} {name}" for name, type_str in spec) + ")" + return schema_str, names + + +def _pack(obj: Any, buckets: List[_Bucket]) -> Dict[str, Any]: + """Ask each bucket to extract its slot(s) from ``obj``. + + ``buckets`` is the precomputed bucket list (from :func:`_get_buckets`). + Hot paths -- e.g. the closures created by + :func:`_te_register_custom_op` -- must pass the precomputed list to + avoid recomputing and, critically, to keep Dynamo away from + ``cls.__dict__`` while tracing. + """ + out: Dict[str, Any] = {} + for bucket in buckets: + for name, value in bucket.pack(obj): + out[name] = value + return out + + +def _unpack(cls: type, args: Dict[str, Any], buckets: List[_Bucket]) -> Any: + """Ask each bucket to inject its field(s) into a fresh instance. + + The instance is built via ``cls.__new__(cls)`` (we bypass any + dataclass ``__init__`` so unknown-typed fields can stay as ``None`` + even when they have no default). ``buckets`` semantics match + :func:`_pack`. + """ + kwargs: Dict[str, Any] = {} + for bucket in buckets: + bucket.unpack(args, kwargs) + obj = cls.__new__(cls) + for k, v in kwargs.items(): + object.__setattr__(obj, k, v) + return obj + + +# --------------------------------------------------------------------------- # +# Op registration helpers +# --------------------------------------------------------------------------- # +# +# Per-step building blocks (schema, kernel wrapping, autograd bridge, +# dispatcher) used by :func:`_te_register_custom_op` to turn user-supplied +# eager kernels + dataclass arg types into a ``torch.library`` custom op. + + +def _prepare_for_saving(tensors: Any) -> Tuple[List[Optional[torch.Tensor]], Any]: + """Lazy wrapper around :func:`quantized_tensor.prepare_for_saving`. + + Used only to flatten the user's setup-context return into a + ``(flat_tensors, tensor_objects)`` pair stashed on ``ctx`` for the + backward; the forward output and saved-tensor restoration on the + compile-path go through :func:`_template_reassemble` instead. Lazy-imports + avoid the dynamo<->quantized_tensor circular import that + ``transformer_engine.pytorch`` would otherwise trigger at module + import time. + """ + from transformer_engine.pytorch.quantized_tensor import prepare_for_saving + + return prepare_for_saving(*(tensors or ())) + + +# --------------------------------------------------------------------------- # +# Forward-result packing +# --------------------------------------------------------------------------- # +# +# The op schema is fixed at ``-> Tensor[]``. To return non-tensor +# values (subclass wrappers, ``QuantizedTensorStorage``, ``None``...), +# :func:`_format_fwd_result` runs each user output through its +# flatten protocol and concatenates the inner tensors into the flat +# return; saved-for-backward tensors follow in declaration order. + + +def _flatten_value_into(flat: List[torch.Tensor], value: Any) -> None: + """Append the ``Tensor[]`` slots produced by ``value`` to ``flat``. + + The inverse of :func:`_template_reassemble`; the slot counts match + :func:`_template_slot_count`: + + * ``None`` -> 1 sentinel slot (via :func:`_encode_none`). + * plain Tensor -> 1 slot. + * tensor subclass with ``__tensor_flatten__`` -> ``len(inner_names)`` + slots, in the order declared by the class. + * storage with ``_torch_compile_flatten`` -> ``len(tensors)`` slots. + """ + if value is None: + flat.append(_encode_none(None)) + return + if isinstance(value, torch.Tensor): + if type(value) is not torch.Tensor and hasattr(value, "__tensor_flatten__"): + inner_names, _ = value.__tensor_flatten__() + flat.extend(_encode_none(getattr(value, n)) for n in inner_names) + else: + flat.append(_encode_none(value)) + return + if hasattr(value, "_torch_compile_flatten"): + _, _, tensors = value._torch_compile_flatten() + flat.extend(_encode_none(t) for t in tensors) + return + raise TypeError( + f"unsupported value type {type(value).__name__}; expected " + "None / torch.Tensor / tensor subclass with __tensor_flatten__ / " + "class with _torch_compile_flatten." + ) + + +# Trailing slots in every ``fwd_impl`` return tuple: +# ``tensors_to_save, tensor_objects, ctx_attrs``. User-output count +# is ``len(result) - _FWD_TRAILING_SLOTS``. +_FWD_TRAILING_SLOTS = 3 + + +def _format_fwd_result(result: Any) -> List[torch.Tensor]: + """Pack a fwd-impl return tuple into the op's ``Tensor[]`` payload. + + User outputs come first, then the saved-for-backward tensors in + declaration order. Both groups go through the same per-value + :func:`_flatten_value_into` dispatch -- the slot layout produced + here must match exactly what :func:`_template_slot_count` reports + for the corresponding fake template, since the call-site reassembly in + :func:`forward_fn` / :func:`_setup_context` slices this flat list + back into user-facing objects using those per-template counts. + + ``None`` entries on either side are smuggled through + :func:`_encode_none` so the schema stays non-nullable and + ``register_autograd`` still attaches a ``grad_fn`` to the op's + outputs. + + The split point between user outputs and saved tensors is + inferred from the impl's return shape: + ``(*user_outputs, tensors_to_save, tensor_objects, ctx_attrs)`` + -- the last three slots are the standard ``fwd_impl`` tail. + """ + num_outputs = len(result) - _FWD_TRAILING_SLOTS + flat: List[torch.Tensor] = [] + for value in result[:num_outputs]: + _flatten_value_into(flat, value) + saved = result[num_outputs] or () + for value in saved: + _flatten_value_into(flat, value) + return flat + + +def _format_bwd_result( + grads: Any, num_grad_inputs: int, op_qualname: str +) -> List[torch.Tensor]: + """Pack a backward-impl return tuple into the op's ``Tensor[]`` payload. + + Validates that the user kernel returned exactly one grad per + ``input_tensors_for_grad`` entry; raises with the op's qualified + name on mismatch. + """ + grads = list(grads) + if len(grads) != num_grad_inputs: + raise RuntimeError( + f"{op_qualname} expected backward_impl to return " + f"{num_grad_inputs} grads (one per input_tensors_for_grad " + f"entry), got {len(grads)}" + ) + return [_encode_none(g) for g in grads] + + +def _resolve_grad_targets( + fwd_buckets: List[_Bucket], + fwd_arg_type: type, + input_tensors_for_grad: List[str], +) -> Tuple[List[Any], List[Tuple[int, bool]]]: + """Validate ``input_tensors_for_grad`` and resolve grad-output layout. + + Returns ``(fwd_slot_defaults, grad_targets)`` where: + + * ``fwd_slot_defaults`` is the per-slot "no-grad" template the + autograd return tuple starts from -- ``[]`` for ``Tensor[]`` + slots, ``None`` otherwise. ``register_autograd`` requires one + grad slot per forward input with matching tree structure (a + ``Tensor[]`` slot must get back a list, not bare ``None``). + * ``grad_targets`` is the ``[(slot_index, as_list), ...]`` mapping + for each name in ``input_tensors_for_grad``, in the same order; + ``as_list`` is ``True`` for ``Tensor[]``-shaped slots so the + caller wraps the single grad into a length-matched list. + """ + fwd_slot_defaults: List[Any] = [] + for bucket in fwd_buckets: + for _, type_str in bucket.schema_slots(): + fwd_slot_defaults.append([] if type_str.endswith("[]") else None) + + fwd_grad_targets: Dict[str, Tuple[int, bool]] = {} + slot_offset = 0 + for bucket in fwd_buckets: + slots = bucket.schema_slots() + if isinstance(bucket, _TensorBucket): + fwd_grad_targets[bucket.name] = (slot_offset, False) + elif isinstance(bucket, _UniversalTensorBucket): + # Grad routes to the ``Tensor?`` slot -- the wrapper / + # plain-tensor passthrough -- so the gradient flows back + # to the user-facing object (e.g. an ``nn.Parameter`` + # wrapped as ``Float8Tensor``). In the storage path the + # ``Tensor?`` slot is ``None`` and the kernel does not + # request a grad for it. + for i, (slot_name, _) in enumerate(slots): + if slot_name == bucket.slot_name(): + fwd_grad_targets[bucket.name] = (slot_offset + i, False) + break + slot_offset += len(slots) + + unknown = [n for n in input_tensors_for_grad if n not in fwd_grad_targets] + if unknown: + raise ValueError( + f"input_tensors_for_grad contains names not present in " + f"{fwd_arg_type.__name__} schema: {unknown}" + ) + grad_targets = [fwd_grad_targets[n] for n in input_tensors_for_grad] + return fwd_slot_defaults, grad_targets + + +def _register_kernel( + *, + op_name: str, + op_qualname: str, + arg_type: type, + arg_names: List[str], + buckets: List[_Bucket], + impl: Callable[[Any], Any], + fake_impl: Callable[[Any], Any], + format_result: Callable[[Any], List[torch.Tensor]], +) -> None: + """Wire ``impl`` + ``fake_impl`` into :data:`_TE_LIB` under ``op_name``. + + The wrapper unpacks the flat positional args using + ``arg_names`` / ``buckets``, calls the user kernel with the rebuilt + dataclass instance, and packs the result through ``format_result`` + (which encodes ``None``s into the op's ``Tensor[]`` return slot). + """ + + def _eager(*flat: Any) -> List[torch.Tensor]: + kwargs = dict(zip(arg_names, flat)) + obj = _unpack(arg_type, kwargs, buckets) + return format_result(impl(obj)) + + def _fake(*flat: Any) -> List[torch.Tensor]: + kwargs = dict(zip(arg_names, flat)) + obj = _unpack(arg_type, kwargs, buckets) + return format_result(fake_impl(obj)) + + _TE_LIB.impl(op_name, _eager, "CompositeExplicitAutograd") + torch.library.register_fake(op_qualname, _fake, lib=_TE_LIB) + + +def _collect_universal_slot_offsets(buckets: List[_Bucket]) -> List[int]: + """Return the start index of each :class:`_UniversalTensorBucket` + group inside the flat positional arg list of a registered op. + + The four schema slots emitted by a universal bucket are always + contiguous (``name``, ``__tensors``, ``__pg``, ``__meta``); knowing + the offset of the first slot lets a subclass dispatch rule rewrite + all four slots in place at trace / eager time without re-deriving + the bucket list. + """ + offsets: List[int] = [] + pos = 0 + for bucket in buckets: + if isinstance(bucket, _UniversalTensorBucket): + offsets.append(pos) + pos += len(bucket.schema_slots()) + return offsets + + +def _flatten_subclass_into_slots( + new_args: List[Any], slot_offsets: List[int], subclass: type +) -> None: + """Rewrite each ``_UniversalTensorBucket`` group whose ``Tensor?`` + slot holds an instance of ``subclass`` into the storage layout. + + Used as the body of a ``register_torch_dispatch`` rule on the outer + fwd / bwd op: a subclass passed through the user-facing op is + flattened in place (via ``_torch_compile_flatten``) so that the + inner op only ever sees plain tensors plus the storage-flatten + metadata. The wrapper's autograd identity remains attached to the + inner tensors via the wrapper-subclass machinery, so gradients + still flow back to the user-facing tensor. + """ + for offset in slot_offsets: + val = new_args[offset] + if val is None or not isinstance(val, subclass): + continue + meta, pg, tensors = val._torch_compile_flatten() + meta._data[_UniversalTensorBucket.KIND_KEY] = _UniversalTensorBucket.KIND_STORAGE + new_args[offset] = None + new_args[offset + 1] = list(tensors) + new_args[offset + 2] = pg + new_args[offset + 3] = meta + + +def _register_autograd_for_op( + *, + fwd_op_name: str, + bwd_op_name: str, + fwd_arg_type: type, + fwd_arg_names: List[str], + fwd_buckets: List[_Bucket], + bwd_arg_names: List[str], + bwd_buckets: List[_Bucket], + fwd_slot_defaults: List[Any], + grad_targets: List[Tuple[int, bool]], + setup_context_user: Callable[..., None], + backward_obj_type: type, + fwd_fake_impl: Callable[[Any], Tuple[Any, ...]], + fwd_field_names: Sequence[str], +) -> None: + """Wire ``register_autograd`` on a forward op so its backward calls + ``bwd_op_name``. + + Both the inner and outer tiers of a two-tier op share an identical + autograd bridge (the wrapper logic only cares about op *names*), so + this helper is called once per tier; the actual kernel + registration is handled separately (by :func:`_register_kernel` + for the inner tier and :func:`_register_outer_forwarder` for the + outer tier). + + The op's ``Tensor[]`` return holds the flat layout produced by + :func:`_format_fwd_result` -- one chunk per user output / saved + tensor. ``setup_context`` re-runs ``fwd_fake_impl`` to recover the + fake output / saved templates, then reassembles each chunk via + :func:`_template_reassemble`. Saved slots that alias a forward arg + (the fake returns the actual arg) are detected by identity and + surfaced to the user's ``setup_context`` via + ``ctx_attrs["saved_tensor_aliases"]``. + """ + fwd_qualname = f"{_TE_OP_NAMESPACE}::{fwd_op_name}" + + def _setup_context(ctx, inputs, output): + ctx._te_fwd_tensor_list_lengths = { + i: len(value) for i, value in enumerate(inputs) if isinstance(value, list) + } + kwargs = dict(zip(fwd_arg_names, inputs)) + fwd_obj = _unpack(fwd_arg_type, kwargs, fwd_buckets) + + user_fakes, saved_fakes, ctx_attrs = _split_fwd_fake_result(fwd_fake_impl(fwd_obj)) + pairs = _fwd_arg_alias_pairs(fwd_obj, fwd_field_names) + saved_aliases = tuple(_alias_name_for(t, pairs) for t in saved_fakes) + ctx_attrs = dict(ctx_attrs) + ctx_attrs["saved_tensor_aliases"] = saved_aliases + + cursor = 0 + user_outputs: List[Any] = [] + for template in user_fakes: + n = _template_slot_count(template) + chunk = [_decode_none(t) for t in output[cursor:cursor + n]] + cursor += n + user_outputs.append(_template_reassemble(template, chunk)) + + tensors_to_save_from_forward_list: List[Any] = [] + for template, alias in zip(saved_fakes, saved_aliases): + aliased = alias is not None + n = _template_slot_count(template, aliased=aliased) + chunk = [_decode_none(t) for t in output[cursor:cursor + n]] + cursor += n + tensors_to_save_from_forward_list.append( + _template_reassemble(template, chunk, aliased=aliased) + ) + tensors_to_save_from_forward = tuple(tensors_to_save_from_forward_list) + + bwd_obj = backward_obj_type() + tensors_to_save_from_setup = setup_context_user( + bwd_obj, + fwd_obj, + user_outputs[0] if len(user_fakes) == 1 else tuple(user_outputs), + ctx_attrs, + tensors_to_save_from_forward, + ) + tensors_to_save, tensor_objects = _prepare_for_saving(tensors_to_save_from_setup) + ctx.tensor_objects = tensor_objects + ctx.save_for_backward(*tensors_to_save) + ctx.bwd_obj = bwd_obj + + def _autograd_backward(ctx, *grad_outputs): + bwd_obj = ctx.bwd_obj + if hasattr(bwd_obj, "setup_saved_tensors"): + bwd_obj.setup_saved_tensors(ctx) + ctx.tensor_objects = None + per_output_grads = grad_outputs[0] + bwd_obj.grad_output = _decode_none(per_output_grads[0]) + kwargs = _pack(bwd_obj, bwd_buckets) + bwd_args_flat = [kwargs[name] for name in bwd_arg_names] + bwd_op = getattr(getattr(torch.ops, _TE_OP_NAMESPACE), bwd_op_name) + grads = [_decode_none(g) for g in bwd_op(*bwd_args_flat)] + out: List[Any] = list(fwd_slot_defaults) + tensor_list_lengths = getattr(ctx, "_te_fwd_tensor_list_lengths", {}) + # Pad every ``Tensor[]`` slot with ``None`` entries matching the + # corresponding forward input length. AOT's pytree check on the + # backward return rejects an empty list where the forward input + # was a non-empty list -- the list structure must match + # element-for-element. Grad-target slots below overwrite the + # first entry with the actual gradient. + for pos, length in tensor_list_lengths.items(): + if isinstance(out[pos], list): + out[pos] = [None] * length + for (pos, as_list), g in zip(grad_targets, grads): + if as_list: + length = tensor_list_lengths.get(pos, 1) + out[pos] = [g] + [None] * (length - 1) + else: + out[pos] = g + return tuple(out) + + torch.library.register_autograd( + fwd_qualname, + _autograd_backward, + setup_context=_setup_context, + lib=_TE_LIB, + ) + + +def _register_outer_forwarder( + *, + outer_op_name: str, + inner_op_name: str, + buckets: Optional[List[_Bucket]] = None, + subclass_list: Optional[List[type]] = None, +) -> None: + """Register the outer op's default kernel + fake. + + Both kernel and fake forward to the inner op, optionally with an + in-place input flatten step for any registered subclass arg (so the + inner op's plain-tensor schema is satisfied). Outputs travel + untouched in their flat ``Tensor[]`` shape -- the user-facing + wrapping back into subclasses / storage happens in + :func:`forward_fn` via :class:`_ToSubclassFn`. + """ + inner_op = getattr(getattr(torch.ops, _TE_OP_NAMESPACE), inner_op_name) + + input_flatten_enabled = bool(subclass_list) and buckets is not None + + if input_flatten_enabled: + slot_offsets = _collect_universal_slot_offsets(buckets) + + def _flatten_all(new_args: List[Any]) -> None: + for sub in subclass_list: + _flatten_subclass_into_slots(new_args, slot_offsets, sub) + + def _outer_kernel(*flat: Any) -> List[torch.Tensor]: + new_args = list(flat) + _flatten_all(new_args) + return inner_op(*new_args) + + def _outer_fake(*flat: Any) -> List[torch.Tensor]: + new_args = list(flat) + _flatten_all(new_args) + return inner_op(*new_args) + else: + def _outer_kernel(*flat: Any) -> List[torch.Tensor]: + return inner_op(*flat) + + def _outer_fake(*flat: Any) -> List[torch.Tensor]: + return inner_op(*flat) + + _TE_LIB.impl(outer_op_name, _outer_kernel, "CompositeExplicitAutograd") + torch.library.register_fake( + f"{_TE_OP_NAMESPACE}::{outer_op_name}", _outer_fake, lib=_TE_LIB + ) + + +def _all_quantized_tensor_subclasses() -> List[type]: + """Return every imported ``QuantizedTensor`` wrapper subclass. + + Imports the ``transformer_engine.pytorch.tensor`` package as a side + effect so that all concrete wrapper subclasses (``Float8Tensor``, + ``MXFP8Tensor``, ``Float8BlockwiseQTensor``, ``NVFP4Tensor``) get + registered with Python's subclass tracker before we walk + ``QuantizedTensor.__subclasses__()`` recursively. The lazy import + keeps ``dynamo.py`` itself free of top-level ``tensor`` imports + (which would form a cycle through the in-function ``dynamo`` + imports inside the tensor modules), while still giving every + custom op the full subclass set at registration time. + """ + import transformer_engine.pytorch.tensor # noqa: F401 -- side-effect: registers subclasses + from transformer_engine.pytorch.quantized_tensor import QuantizedTensor + + seen: List[type] = [] + stack: List[type] = list(QuantizedTensor.__subclasses__()) + while stack: + cls = stack.pop() + if cls in seen: + continue + seen.append(cls) + stack.extend(cls.__subclasses__()) + return seen + + +def _te_register_custom_op( + *, + op_name: str, + input_tensors_for_grad: List[str], + fwd_arg_type: type, + fwd_impl: Callable[[Any], Any], + setup_context: Callable[..., None], + backward_arg_type: type, + backward_obj: type, + backward_impl: Callable[[Any], Any], + fwd_fake_impl: Callable[[Any], Tuple[Any, ...]], + bwd_fake_impl: Callable[[Any], Tuple[Any, ...]], +) -> Callable[..., Any]: + """Register a TE module's forward + backward as a single torch custom op. + + Parameters + ---------- + op_name + Op name used when registering with ``torch.library``. The + namespace is fixed at module level (:data:`_TE_OP_NAMESPACE`). + input_tensors_for_grad + Names of forward-arg-type fields for which ``backward_impl`` + returns gradients, in the same order. The wrapper uses this to + pad the autograd return tuple with ``None`` for every input not + listed here, so torch sees one grad slot per forward input as + required by ``register_autograd``. + fwd_arg_type + Dataclass type aggregating all forward inputs (e.g. + ``LinearFwdArgs``). Used to (re)build the structured argument + from the flat tensor / non-tensor inputs accepted by the custom op. + fwd_impl + Eager forward implementation. Receives a single argument of type + ``fwd_arg_type`` and must return a tuple of the form + ``(*output_tensors, tensors_to_save, tensor_objects, ctx_attrs)`` + where: + + * ``output_tensors`` -- one or more :class:`torch.Tensor` outputs + returned to the caller. + * ``tensors_to_save`` -- flat list of :class:`torch.Tensor` to be + stashed via ``ctx.save_for_backward``. + * ``tensor_objects`` -- the metadata object produced by + :func:`prepare_for_saving`, paired with ``tensors_to_save`` to + let the backward reconstruct quantized / structured tensors. + * ``ctx_attrs`` -- non-tensor state to attach to the autograd + context, restricted to values that cannot be derived from the + forward args inside ``setup_context``. + setup_context + Eager autograd ``setup_context`` analogue. Receives a freshly + constructed ``backward_obj`` instance, the forward args, the + forward output, and ``ctx_attrs`` produced by ``fwd_impl``; + is responsible for populating the backward-state object so that + ``backward_impl`` can later consume it. + backward_arg_type + Type accepted by ``backward_impl``. May differ from ``backward_obj`` + if the backward op needs a wrapped / opaque view of the state. + backward_obj + Dataclass / class used to instantiate a fresh backward-state + container at the end of the forward pass (typically the same as + ``backward_arg_type``). + backward_impl + Eager backward implementation. Receives a single argument of type + ``backward_arg_type`` and returns the gradient tuple. + fwd_fake_impl + Forward fake implementation: ``fn(fwd_obj) -> (*user_outputs, + tensors_to_save, tensor_objects, ctx_attrs)`` -- the same tuple + shape as ``fwd_impl``, but built from *fake* values instead of + running the real kernel. Each slot is one of: + + * ``quantizer.make_fake_empty(shape, dtype, device)`` -- a + Dynamo-safe quantized wrapper (assembled via + ``__tensor_unflatten__`` with a snapshot-free meta). + * ``quantizer.make_empty(shape, dtype, device)`` -- a quantized + storage (e.g. an FP8 weight workspace). + * ``torch.empty(shape, dtype, device)`` -- a plain tensor. + * the actual forward-arg tensor -- for a saved slot that aliases + a forward input (detected by identity). + * ``None`` -- an absent output / saved slot. + + This single callable drives both consumers: ``forward_fn`` / + ``setup_context`` use its fake values directly as reassembly + templates (:func:`_template_slot_count` / + :func:`_template_reassemble`), and + :func:`_fwd_register_fake_from_fake_impl` wires it (with aliased + saved slots nulled) as the op's + :func:`torch.library.register_fake`. The whole callable must be + Dynamo-traceable under ``fullgraph=True``. + bwd_fake_impl + Backward fake implementation: ``fn(bwd_obj) -> grad_tuple``, one + fake grad per gradient output in the same order as + ``backward_impl``'s return tuple (``None`` for missing grads, + ``torch.empty`` for plain, ``quantizer.make_empty`` for + quantized). Wired directly as the backward op's + ``register_fake`` -- backward grads never round-trip through the + op payload, so no layout adapter is needed. + + Returns + ------- + Callable + A function ``forward_fn(fwd_arg_type_instance)`` that dispatches + through the registered custom op, returning the user-facing + outputs (single tensor if the impl produced exactly one + user-facing output, otherwise a tuple). Use under + ``torch.compiler.is_compiling()`` as a drop-in for + ``Function.apply``. + """ + + outer_fwd_name = op_name + outer_bwd_name = f"{op_name}_backward" + # Auto-discover every imported ``QuantizedTensor`` wrapper subclass + # so callers never have to enumerate them. Each subclass gets a + # ``register_torch_dispatch`` rule on the outer op (see below) and + # is flattened into plain tensors before the inner op runs. + subclass_list = _all_quantized_tensor_subclasses() + + # Precompute the bucket list once per arg type and capture it in + # the registered closures. Re-deriving the bucket list inside a + # compiled call would force :func:`_get_buckets` to read + # ``cls.__dict__`` from inside a Dynamo-traced function, which + # triggers a "mappingproxy affected by dictionary mutation" graph + # break under ``fullgraph=True``. + fwd_buckets: List[_Bucket] = _get_buckets(fwd_arg_type) + bwd_buckets: List[_Bucket] = _get_buckets(backward_arg_type) + + fwd_schema_args, fwd_arg_names = _build_schema(fwd_buckets) + bwd_schema_args, bwd_arg_names = _build_schema(bwd_buckets) + + num_grad_inputs = len(input_tensors_for_grad) + fwd_slot_defaults, grad_targets = _resolve_grad_targets( + fwd_buckets, fwd_arg_type, input_tensors_for_grad + ) + + # Two-tier layout when at least one ``QuantizedTensor`` subclass is + # imported (the common case -- ``_all_quantized_tensor_subclasses`` + # discovers them automatically): + # inner = ``{op_name}_base`` -- real impl, sees only plain tensors + # and the storage-flatten metadata. + # outer = ``{op_name}`` -- user-facing op that either falls through + # to the inner op (plain-tensor path) or is rewritten by a + # ``register_torch_dispatch`` rule (subclass path) into a + # call to the inner op with subclass tensors flattened in + # place. Both tiers carry their own ``register_autograd`` + # bridge. + # Single-tier fallback: if no ``QuantizedTensor`` subclasses have + # been imported (e.g. minimal embedded build) only the outer pair + # is defined and it owns the real impl directly. + inner_fwd_name = f"{op_name}_base" if subclass_list else outer_fwd_name + inner_bwd_name = f"{outer_bwd_name}_base" if subclass_list else outer_bwd_name + + # Forward op concatenates user outputs and tensors_to_save into a + # single ``Tensor[]`` return so that autograd's ``setup_context`` can + # stash the saved-for-backward tensors without re-running the eager + # impl. The schema is non-nullable (``Tensor[]``, not ``Tensor?[]``) + # because ``torch.library.register_autograd`` does not propagate + # ``grad_fn`` to a nullable list output. ``None`` entries on either + # side are smuggled through via :func:`_encode_none` / + # :func:`_decode_none` sentinels. + _TE_LIB.define(f"{inner_fwd_name}{fwd_schema_args} -> Tensor[]") + _TE_LIB.define(f"{inner_bwd_name}{bwd_schema_args} -> Tensor[]") + if subclass_list: + # Outer fwd / outer bwd are user-facing entry points. The + # outer fwd is the target of ``register_torch_dispatch`` for + # the forward subclass path; outer bwd is the target for the + # backward subclass path. Both forward to the corresponding + # inner op when no rule matches (plain-tensor / pure-storage + # path). + _TE_LIB.define(f"{outer_fwd_name}{fwd_schema_args} -> Tensor[]") + _TE_LIB.define(f"{outer_bwd_name}{bwd_schema_args} -> Tensor[]") + + # Inner pair owns the real implementation. The fwd & bwd kernels + # are registered directly against the user-supplied impls; the + # autograd bridge below wires the inner fwd op's backward to call + # the inner bwd op. + inner_fwd_qualname = f"{_TE_OP_NAMESPACE}::{inner_fwd_name}" + inner_bwd_qualname = f"{_TE_OP_NAMESPACE}::{inner_bwd_name}" + + # The module supplies its output layout as a forward ``fake_impl`` + # (fake values in the eager-impl tuple shape). ``forward_fn`` / + # ``setup_context`` consume it directly as reassembly templates; the + # forward ``register_fake`` kernel wraps it to null aliased saved slots + # (so the fake flat ``Tensor[]`` matches the eager impl). The backward + # ``fake_impl`` is the backward ``register_fake`` directly. ``field + # names`` are precomputed here (reading ``dataclasses.fields`` in-trace + # would graph-break) for the alias-by-identity detection. + fwd_field_names = [f.name for f in dataclasses.fields(fwd_arg_type)] + fwd_register_fake = _fwd_register_fake_from_fake_impl(fwd_fake_impl, fwd_field_names) + + _register_kernel( + op_name=inner_fwd_name, + op_qualname=inner_fwd_qualname, + arg_type=fwd_arg_type, + arg_names=fwd_arg_names, + buckets=fwd_buckets, + impl=fwd_impl, + fake_impl=fwd_register_fake, + format_result=_format_fwd_result, + ) + _register_kernel( + op_name=inner_bwd_name, + op_qualname=inner_bwd_qualname, + arg_type=backward_arg_type, + arg_names=bwd_arg_names, + buckets=bwd_buckets, + impl=backward_impl, + fake_impl=bwd_fake_impl, + format_result=lambda g: _format_bwd_result(g, num_grad_inputs, inner_bwd_qualname), + ) + _register_autograd_for_op( + fwd_op_name=inner_fwd_name, + bwd_op_name=inner_bwd_name, + fwd_arg_type=fwd_arg_type, + fwd_arg_names=fwd_arg_names, + fwd_buckets=fwd_buckets, + bwd_arg_names=bwd_arg_names, + bwd_buckets=bwd_buckets, + fwd_slot_defaults=fwd_slot_defaults, + grad_targets=grad_targets, + setup_context_user=setup_context, + backward_obj_type=backward_obj, + fwd_fake_impl=fwd_fake_impl, + fwd_field_names=fwd_field_names, + ) + + if subclass_list: + # Outer tier (thin shell): default kernels forward to inner + # plus a ``register_torch_dispatch`` rule per subclass that + # flattens the wrapper in place before forwarding. Carries + # its own autograd bridge so the user-facing subclass tensor + # (e.g. a ``Float8Tensor`` parameter) stays on the autograd + # graph and receives a ``.grad``. + _register_outer_forwarder( + outer_op_name=outer_fwd_name, + inner_op_name=inner_fwd_name, + buckets=fwd_buckets, + subclass_list=list(subclass_list), + ) + _register_outer_forwarder( + outer_op_name=outer_bwd_name, + inner_op_name=inner_bwd_name, + ) + _register_autograd_for_op( + fwd_op_name=outer_fwd_name, + bwd_op_name=outer_bwd_name, + fwd_arg_type=fwd_arg_type, + fwd_arg_names=fwd_arg_names, + fwd_buckets=fwd_buckets, + bwd_arg_names=bwd_arg_names, + bwd_buckets=bwd_buckets, + fwd_slot_defaults=fwd_slot_defaults, + grad_targets=grad_targets, + setup_context_user=setup_context, + backward_obj_type=backward_obj, + fwd_fake_impl=fwd_fake_impl, + fwd_field_names=fwd_field_names, + ) + + fwd_slot_offsets = _collect_universal_slot_offsets(fwd_buckets) + bwd_slot_offsets = _collect_universal_slot_offsets(bwd_buckets) + inner_fwd_op = getattr(getattr(torch.ops, _TE_OP_NAMESPACE), inner_fwd_name) + inner_bwd_op = getattr(getattr(torch.ops, _TE_OP_NAMESPACE), inner_bwd_name) + outer_fwd_op = getattr(getattr(torch.ops, _TE_OP_NAMESPACE), outer_fwd_name) + outer_bwd_op = getattr(getattr(torch.ops, _TE_OP_NAMESPACE), outer_bwd_name) + outer_fwd_qualname = f"{_TE_OP_NAMESPACE}::{outer_fwd_name}" + outer_bwd_qualname = f"{_TE_OP_NAMESPACE}::{outer_bwd_name}" + + def _flatten_all_subclasses(new_args: List[Any], slot_offsets: List[int]) -> None: + for sub in subclass_list: + _flatten_subclass_into_slots(new_args, slot_offsets, sub) + + def _fwd_rule(mode, func, types, args, kwargs): + del mode, func, types, kwargs + new_args = list(args) + _flatten_all_subclasses(new_args, fwd_slot_offsets) + return inner_fwd_op(*new_args) + + def _bwd_rule(mode, func, types, args, kwargs): + del mode, func, types, kwargs + new_args = list(args) + _flatten_all_subclasses(new_args, bwd_slot_offsets) + return inner_bwd_op(*new_args) + + # Per-subclass dispatch rule: any registered subclass arg + # passed to the outer op (e.g. Dynamo lifting a + # ``Float8Tensor`` weight into the FX graph) is flattened + # into its storage layout before forwarding to the inner op, + # which only ever sees plain tensors. + for sub in subclass_list: + torch.library.register_torch_dispatch( + outer_fwd_qualname, sub, _fwd_rule, lib=_TE_LIB + ) + torch.library.register_torch_dispatch( + outer_bwd_qualname, sub, _bwd_rule, lib=_TE_LIB + ) + + # ``QuantizedTensor.__torch_dispatch__`` falls back to + # dequantizing all subclass args for any op it does not + # recognise, which would defeat our + # ``register_torch_dispatch`` rules and would also crash on + # FakeTensors (``tex.dequantize`` needs ``data_ptr``). Mark + # every op we register through this helper -- both tiers and + # both directions -- as passthroughs so QuantizedTensor + # delegates straight to ``super().__torch_dispatch__``. + from transformer_engine.pytorch.quantized_tensor import ( + _quantized_tensor_passthrough_ops, + ) + _quantized_tensor_passthrough_ops.add(outer_fwd_op.default) + _quantized_tensor_passthrough_ops.add(outer_bwd_op.default) + _quantized_tensor_passthrough_ops.add(inner_fwd_op.default) + _quantized_tensor_passthrough_ops.add(inner_bwd_op.default) + + fwd_op = getattr(getattr(torch.ops, _TE_OP_NAMESPACE), outer_fwd_name) + + def forward_fn(fwd_args): + user_fakes, _saved_fakes, _ctx_attrs = _split_fwd_fake_result( + fwd_fake_impl(fwd_args) + ) + kwargs = _pack(fwd_args, fwd_buckets) + flat_in = [kwargs[name] for name in fwd_arg_names] + result = fwd_op(*flat_in) + + # Slice the flat result using the fake outputs as templates. Subclass + # templates route through :class:`_ToSubclassFn` to keep the wrap on + # the autograd graph; plain tensors / storage classes are + # reconstructed directly. User outputs never alias a forward arg. + cursor = 0 + outputs: List[Any] = [] + for template in user_fakes: + n = _template_slot_count(template) + chunk = [_decode_none(t) for t in result[cursor:cursor + n]] + cursor += n + outputs.append(_template_reassemble(template, chunk, with_autograd=True)) + + if len(outputs) == 1: + return outputs[0] + return tuple(outputs) + + return forward_fn diff --git a/transformer_engine/pytorch/fp8_dtype.py b/transformer_engine/pytorch/fp8_dtype.py new file mode 100644 index 0000000000..88b2b5c0c1 --- /dev/null +++ b/transformer_engine/pytorch/fp8_dtype.py @@ -0,0 +1,93 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Dynamo-friendly mirror of ``transformer_engine_torch.DType``. + +The C++-binded ``transformer_engine_torch.DType`` enum is opaque to +TorchDynamo (see ``UserDefinedObjectVariable(DType)`` graph-break under +``fullgraph=True``): Dynamo cannot proxy a pybind11 enum value as a +constant in the FX graph it builds for tensor-subclass constructors +(e.g. :class:`Float8Tensor`). + +:class:`FP8DType` is a Python :class:`enum.IntEnum` that mirrors +``tex.DType`` 1:1 by integer value. Because :class:`IntEnum` derives +from :class:`int`, Dynamo recognises it as a ``ConstantVariable`` and +captures it as a static constant on subclass-constructor calls inside +a compiled region. Conversion to/from the C++ enum is one +``int(...)`` call. +""" +from __future__ import annotations +from enum import IntEnum + +import transformer_engine_torch as tex + + +class FP8DType(IntEnum): + """Python mirror of :class:`transformer_engine_torch.DType` (int values). + + Values match :class:`tex.DType` 1:1 so that ``int(FP8DType.x) == + int(tex.DType.x)`` for every member. Use :func:`to_tex` to bridge + back to the C++ enum at pybind boundaries. + """ + + kByte = int(tex.DType.kByte) + kInt32 = int(tex.DType.kInt32) + kFloat32 = int(tex.DType.kFloat32) + kFloat16 = int(tex.DType.kFloat16) + kBFloat16 = int(tex.DType.kBFloat16) + kFloat8E4M3 = int(tex.DType.kFloat8E4M3) + kFloat8E5M2 = int(tex.DType.kFloat8E5M2) + kFloat4E2M1 = int(tex.DType.kFloat4E2M1) + + +# Precomputed at module load so Dynamo doesn't have to trace +# ``IntEnum.__new__`` / ``tex.DType.__int__`` inside compiled regions +# (both recurse through Python's internal inspect machinery and exhaust +# Dynamo's frame stack). +_TEX_TO_FP8DTYPE = {member.value: member for member in FP8DType} +_TEX_TO_FP8DTYPE_BY_TEX = {tex.DType(v): m for v, m in _TEX_TO_FP8DTYPE.items()} + + +def to_tex(d) -> tex.DType: + """Coerce ``d`` (``FP8DType`` / ``tex.DType`` / int) to ``tex.DType``.""" + if isinstance(d, tex.DType): + return d + return tex.DType(int(d)) + + +def from_tex(d: tex.DType) -> FP8DType: + """Coerce a ``tex.DType`` (or int matching one of its enum values) to + :class:`FP8DType` via a precomputed lookup table. + """ + if isinstance(d, FP8DType): + return d + if isinstance(d, tex.DType): + return _TEX_TO_FP8DTYPE_BY_TEX[d] + return _TEX_TO_FP8DTYPE[int(d)] + + +# Register ``tex.DType`` as a torch.compile value-opaque type so it +# can flow through Dynamo as a constant inside ``__tensor_flatten__`` +# meta dicts and other traced metadata payloads. Without this, +# Dynamo trips on ``UserDefinedObjectVariable(DType)`` because the +# pybind11 enum carries a custom ``__hash__``. ``__fx_repr__`` is +# injected once here so the FX codegen can serialize literal values +# as ``TE_DType()``. Gated by a try/except so importing this +# module remains safe on older PyTorch versions that lack the +# private ``opaque_object`` API. +try: + from torch._library.opaque_object import ( + is_opaque_value_type as _is_opaque_value_type, + register_opaque_type as _register_opaque_type, + ) + + if not hasattr(tex.DType, "__fx_repr__"): + tex.DType.__fx_repr__ = lambda self: ( + f"TE_DType({int(self)})", + {"TE_DType": tex.DType}, + ) + if not _is_opaque_value_type(tex.DType): + _register_opaque_type(tex.DType, typ="value", members={}) +except Exception: # pragma: no cover - older torch / partial init + pass diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 746177ec78..c4691aa645 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -972,7 +972,11 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: return if recipe.nvfp4() and isinstance(recipe_state, NVFP4BlockScalingRecipeState): return - if recipe.custom() and isinstance(recipe_state, CustomRecipeState): + if ( + recipe.custom() + and isinstance(recipe_state, CustomRecipeState) + and recipe_state.recipe is recipe + ): return # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and @@ -1859,6 +1863,12 @@ def _check_weight_tensor_recipe_correspondence(self) -> None: return if not hasattr(self, "weight_names") or not self.weight_names: return + # Skip under ``torch.compile`` -- the check is a one-off + # runtime guard that calls ``tensor._get_quantizer()`` (returns + # a ``Quantizer``, not a Tensor) and Dynamo cannot trace + # quantizer objects flowing through ``call_method``. + if torch.compiler.is_compiling(): + return recipe = self.fp8_meta["recipe"] weight_tensors = [getattr(self, name) for name in self.weight_names] diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index dcbb9eaf93..dfa0bb6b51 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -26,6 +26,7 @@ _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD, + _is_weight_workspace_valid, ) from ._common import noop_cat, WeightGradStore from ..quantization import FP8GlobalStateManager, QuantizerRole @@ -56,16 +57,22 @@ general_gemm, ) from ..constants import FP8BwdTensorIdx, FP8FwdTensorIdx, GemmParallelModes, dist_group_type -from ..jit import no_torch_dynamo +from ..dynamo import ( + _te_register_custom_op, +) from ..graph import is_graph_capturing from ..quantized_tensor import ( QuantizedTensor, QuantizedTensorStorage, Quantizer, + TensorOrQuantized, prepare_for_saving, - restore_from_func_ctx, + restore_from_saved, +) +from ..tensor.float8_tensor import ( + Float8CurrentScalingQuantizer, + Float8Quantizer, ) -from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.utils import clear_columnwise_cache, is_custom from ..export import is_in_onnx_export_mode, assert_warmed_up @@ -80,20 +87,17 @@ __all__ = ["Linear"] -TensorOrQuantized = Union[torch.Tensor, QuantizedTensorStorage] - - @dataclass(slots=True) class LinearFwdArgs: """Single-argument bag for the forward path of :class:`_Linear`.""" # --- Differentiable tensors (also passed positionally to autograd) --- weight: TensorOrQuantized - inp: torch.Tensor + inp: TensorOrQuantized bias: Optional[torch.Tensor] # --- Non-differentiable cached tensors --- - weight_workspace: Optional[torch.Tensor] + weight_workspace: Optional[QuantizedTensorStorage] # --- requires_grad flags (cached so backward does not re-query) --- input_requires_grad: bool @@ -181,7 +185,8 @@ class LinearBwdArgs: # --- Numerical / dtype config --- activation_dtype: Optional[torch.dtype] = None fp8: bool = False - fp8_recipe: Optional[Recipe] = None + use_split_accumulator_dgrad: bool = _2X_ACC_DGRAD + use_split_accumulator_wgrad: bool = _2X_ACC_WGRAD backward_override: Optional[str] = None is_weight_param_quantized: bool = False custom: bool = False @@ -225,16 +230,23 @@ class LinearBwdArgs: # --- Per-backward scratch state (populated inside _linear_backward) --- ub_obj_gradout: Optional[Any] = None - def setup_saved_tensors(self, ctx: torch.autograd.function.FunctionCtx) -> None: - """Pull saved tensors from ``ctx`` into the fields backward consumes.""" + def setup_saved_tensors(self, ctx) -> None: + """Restore saved tensors into the fields consumed by backward. + + Accepts both a ``torch.autograd.Function`` ctx (eager path) and a + ``torch.library.register_autograd`` ctx (compile path); both expose + ``saved_tensors`` and the ``tensor_objects`` attribute we attach + during forward. + """ ( self.inputmat, self.weight_fp8, self.saved_weight, self.bias, - ) = restore_from_func_ctx( - ctx - ) # pylint: disable=unbalanced-tuple-unpacking + ) = restore_from_saved( # pylint: disable=unbalanced-tuple-unpacking + ctx.tensor_objects, + list(ctx.saved_tensors), + ) def _check_fp8_reduce_and_update(): @@ -295,7 +307,19 @@ def _linear_forward_impl( # Configure tensor-parallel communication tp_world_size = get_distributed_world_size(tp_group) - backward_needs_input = is_grad_enabled and weight.requires_grad + # NOTE: prefer the explicit ``args.weight_requires_grad`` flag over + # ``weight.requires_grad`` so we stay consistent with the fake impl + # under ``torch.compile``: when the outer op flattens a + # ``Float8Tensor`` wrapper into a ``Float8TensorStorage`` for the + # inner op, the wrapper's ``requires_grad`` is observed inside the + # autograd Function and reads as ``False`` (autograd detaches its + # forward inputs), so the requires-grad bit baked into the + # storage metadata snapshot ends up ``False`` too. The fake impl + # uses ``args.weight_requires_grad`` (populated at outer-call site + # from the live ``nn.Parameter``) so the real impl must too, + # otherwise their ``backward_needs_input`` flags diverge and + # ``tensors_to_save_from_forward`` ends up with different lengths. + backward_needs_input = is_grad_enabled and args.weight_requires_grad with_input_all_gather_nccl = ( parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop ) @@ -645,7 +669,12 @@ def _linear_setup_ctx( # Numerical / dtype config bwd_args.activation_dtype = fwd_args.activation_dtype bwd_args.fp8 = fp8 - bwd_args.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + if fp8: + _bwd_recipe = FP8GlobalStateManager.get_fp8_recipe() + if hasattr(_bwd_recipe, "fp8_gemm_dgrad"): + bwd_args.use_split_accumulator_dgrad = _bwd_recipe.fp8_gemm_dgrad.use_split_accumulator + if hasattr(_bwd_recipe, "fp8_gemm_wgrad"): + bwd_args.use_split_accumulator_wgrad = _bwd_recipe.fp8_gemm_wgrad.use_split_accumulator bwd_args.backward_override = backward_override bwd_args.is_weight_param_quantized = isinstance(weight, QuantizedTensorStorage) bwd_args.custom = fwd_args.custom @@ -946,12 +975,7 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. ): weight_fp8.update_usage(columnwise_usage=True) - # Choose whether to use GEMM kernel with split accumulator - use_split_accumulator = _2X_ACC_DGRAD - if bwd_args.fp8: - recipe = bwd_args.fp8_recipe - if hasattr(recipe, "fp8_gemm_dgrad"): - use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator + use_split_accumulator = bwd_args.use_split_accumulator_dgrad # Update grad input quantizer if grad_input_quantizer is not None: @@ -1094,12 +1118,7 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. grad_output_quantizer.set_usage(rowwise=False, columnwise=True) grad_output = grad_output_quantizer(grad_output) - # Figure out whether to use split accumulator - use_split_accumulator = _2X_ACC_WGRAD - if bwd_args.fp8: - recipe = bwd_args.fp8_recipe - if hasattr(recipe, "fp8_gemm_wgrad"): - use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator + use_split_accumulator = bwd_args.use_split_accumulator_wgrad # Figure out whether to output wgrad GEMM directly into main grad if bwd_args.is_first_microbatch is not None: @@ -1249,6 +1268,334 @@ def wgrad_gemm( ) +# ---------------------------------------------------------------------------- +# Compile-tier wrappers: forward / backward ``fake_impl`` + ``_te_register_custom_op`` +# registration. The custom op lets ``torch.compile`` trace through linear +# forward + backward as a single graph node without entering the eager +# ``_Linear`` autograd.Function machinery. Selected by :meth:`Linear.forward` +# when ``torch.compiler.is_compiling()`` is true. +# ---------------------------------------------------------------------------- +def _linear_backward_fake_impl( + args: LinearBwdArgs, +) -> Tuple[Any, Any, Any]: + """Backward fake-impl for :func:`_linear_backward`. + + Returns the ``(wgrad, dgrad, grad_bias)`` gradient triple built from + *fake* values (``None`` for absent grads, ``torch.empty`` for plain, + ``quantizer.make_empty`` for quantized ones), in the same order as + :func:`_linear_backward`'s return tuple. Wired directly as the + backward op's ``register_fake`` -- it runs under fake-prop (not the + Dynamo trace), so ``make_empty`` is fine here. ``set_usage`` on + ``grad_input_quantizer`` is preserved because it influences + ``dgrad``'s allocation. Manual TE FSDP is unsupported; FSDP2 / MCore + FSDP go through the standard path. + """ + + if args.fsdp_group is not None: + raise NotImplementedError( + "Fake Linear backward does not support manual TE FSDP " + "(fsdp_group is not None); use FSDP2 or MCore FSDP." + ) + + assert args.saved_weight is not None and args.grad_output is not None + out_features, in_features = args.saved_weight.shape + + if args.grad_input_quantizer is not None: + args.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) + + activation_dtype = args.activation_dtype + device = args.grad_output.device + + def grad(shape, quantizer): + if shape is None: + return None + if quantizer is not None: + return quantizer.make_empty(list(shape), dtype=activation_dtype, device=device) + return torch.empty(tuple(shape), dtype=activation_dtype, device=device) + + wgrad = ( + grad((out_features, in_features), args.grad_weight_quantizer) + if args.requires_wgrad and not args.fuse_wgrad_accumulation + else None + ) + dgrad = grad(args.inp_shape, args.grad_input_quantizer) if args.requires_dgrad else None + grad_bias = grad((out_features,), None) if args.use_bias and args.requires_wgrad else None + + return wgrad, dgrad, grad_bias + + +def _linear_forward_fake_impl( + args: LinearFwdArgs, +) -> Tuple[Any, Any, Any, Any, Dict[str, Any]]: + """Forward fake-impl for :func:`_linear_forward_impl`. + + Returns ``(out, new_weight_workspace, tensors_to_save, None, + ctx_attrs)`` -- the same tuple shape as the eager impl, but built + from *fake* values (``make_fake_empty`` wrappers / ``make_empty`` + storages / ``torch.empty`` plains / aliased forward args / ``None``). + The ``fake_impl`` -> layout adapter in + :mod:`transformer_engine.pytorch.dynamo` reads the slot layout off + these fake values (and nulls aliased saved slots for the + ``register_fake`` kernel). + + All ``set_usage`` side effects on the live quantizers happen here + and are observed by both the real fwd impl and backward. + """ + fp8 = args.fp8 + debug = args.debug + fp8_or_debug = fp8 or debug + activation_dtype = args.activation_dtype + output_quantizer = args.output_quantizer + input_quantizer = args.input_quantizer + weight_quantizer = args.weight_quantizer + weight = args.weight + inp = args.inp + bias = args.bias + + save_original_input = args.save_original_input + if args.backward_override == "high_precision": + save_original_input = True + + out_features, in_features = weight.shape + assert inp.shape[-1] == in_features, "GEMM not possible" + + tp_world_size = get_distributed_world_size(args.tp_group) + backward_needs_input = args.is_grad_enabled and args.weight_requires_grad + with_input_all_gather_nccl = ( + args.parallel_mode == "column" + and args.sequence_parallel + and not args.ub_overlap_ag_fprop + ) + + # Input pipeline -- mirror ``_linear_forward_impl``'s ``set_usage`` + # calls and classify the ``saved_inputmat`` slot end-state: + # aliased to ``args.inp``, fresh quantized storage, or plain cast. + inputmat_is_storage = False + inputmat_aliases_inp = False + own_quantized_input = False + inputmat_total_shape: List[int] = list(inp.shape) + + if with_input_all_gather_nccl or args.ub_overlap_ag_fprop: + if fp8_or_debug: + if input_quantizer is None: + raise ValueError("Missing quantizer for input tensor") + if not isinstance(inp, QuantizedTensorStorage) and not args.custom: + own_quantized_input = True + input_quantizer.set_usage( + rowwise=True, + columnwise=backward_needs_input and args.backward_override is None, + ) + if isinstance( + input_quantizer, (Float8CurrentScalingQuantizer, Float8Quantizer) + ): + input_quantizer.set_usage(columnwise=False) + if save_original_input: + input_quantizer.set_usage(columnwise=False) + own_quantized_input = False + inputmat_is_storage = True + else: + inputmat_aliases_inp = inp.dtype == activation_dtype + # All-gather inflates the leading dim of the GEMM-input shape. + inputmat_total_shape = list(inp.shape) + inputmat_total_shape[0] *= tp_world_size + else: + if fp8_or_debug: + if isinstance(inp, QuantizedTensorStorage): + inp.update_usage(rowwise_usage=True) + inputmat_is_storage = True + inputmat_aliases_inp = True + else: + if input_quantizer is None: + raise ValueError("Missing quantizer for input tensor") + input_quantizer.set_usage( + rowwise=True, + columnwise=( + backward_needs_input + and not save_original_input + and args.backward_override is None + ), + ) + inputmat_is_storage = True + own_quantized_input = True + else: + inputmat_aliases_inp = inp.dtype == activation_dtype + + # ``save_original_input`` / ``backward_override="high_precision"`` + # flip ``inputmat`` back to ``args.inp`` at the tail of the impl; + # mirror that here so the saved slot ends up aliased. + if save_original_input: + inputmat_aliases_inp = True + inputmat_is_storage = False + + # Weight pipeline -- mirror ``quantize_weight`` / ``cast_if_needed``. + # ``new_weight_workspace`` is a fresh fake storage only on the + # cache-miss + ``cache_weight`` combination, else ``None``. + new_weight_workspace: Any = None + weightmat_is_storage = False + weightmat_aliases_weight = False + if fp8_or_debug: + if weight_quantizer is not None and ( + not isinstance(weight, QuantizedTensor) or debug + ): + columnwise_usage = ( + args.is_grad_enabled and args.input_requires_grad and not args.is_fsdp2 + ) + if args.backward_override is not None: + columnwise_usage = False + if not columnwise_usage: + columnwise_usage = ( + is_fp8_activation_recompute_enabled() + and not in_fp8_activation_recompute_phase() + ) + weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + elif isinstance(weight, QuantizedTensor): + weight_quantizer = weight._quantizer + + if isinstance(weight, QuantizedTensorStorage): + # Primary-quantized weight: the impl reuses it as ``weightmat``. + weightmat_is_storage = True + weightmat_aliases_weight = True + else: + weightmat_is_storage = True + workspace = args.weight_workspace + if workspace is not None and not _is_weight_workspace_valid( + workspace, weight_quantizer + ): + workspace = None + if workspace is None and args.cache_weight: + # Fresh FP8 weight workspace -- a ``*TensorStorage`` + # (``weight_quantizer`` is ``internal``). + new_weight_workspace = weight_quantizer.make_empty( + list(weight.shape), dtype=activation_dtype, device=weight.device + ) + else: + weightmat_aliases_weight = weight.dtype == activation_dtype + + if output_quantizer is not None: + output_quantizer.set_usage(rowwise=True, columnwise=False) + + # Post-comm output shape (the value that leaves the op). + gemm_out_shape: List[int] = list(inputmat_total_shape[:-1]) + [out_features] + if args.ub_overlap_rs_fprop: + out_shape: List[int] = list(inp.shape) + out_shape[0] //= tp_world_size + out_shape[-1] = out_features + elif args.parallel_mode == "row" and args.tp_size > 1 and args.sequence_parallel: + out_shape = list(gemm_out_shape) + out_shape[0] //= tp_world_size + else: + out_shape = list(gemm_out_shape) + + # User-output [0] -- the GEMM result. ``Float8Tensor`` is the only + # quantized wrapper this op produces directly; other quantizer + # families flow their workspace through ``new_weight_workspace`` + # instead. The quantized output uses ``make_fake_empty`` -- the + # Dynamo-safe wrapper allocator (``make_empty`` cannot build a + # wrapper in-trace because it proxies the live quantizer). + if output_quantizer is not None: + out = output_quantizer.make_fake_empty( + tuple(out_shape), dtype=activation_dtype, device=inp.device + ) + else: + out = torch.empty(tuple(out_shape), dtype=activation_dtype, device=inp.device) + + saved_values: List[Any] = [] + + if args.is_grad_enabled: + # Post-forward ``set_usage`` -- mirrors ``_linear_forward_impl`` + # so backward observes the same row/col layout on the input + # quantizer the impl ended up with. + if ( + backward_needs_input + and own_quantized_input + and inputmat_is_storage + and not save_original_input + ): + if args.backward_override is not None: + input_quantizer.set_usage(rowwise=True, columnwise=False) + elif ( + args.backward_input_needs_gather + and weight_quantizer.supports_only_rowwise_all_gather() + ): + input_quantizer.set_usage(rowwise=True, columnwise=False) + else: + input_quantizer.set_usage(rowwise=False, columnwise=True) + + # Slot 0 -- ``saved_inputmat``: absent / aliased to ``inp`` / + # fresh quantized storage / plain cast (mutually exclusive). + # An aliased slot returns the actual forward arg ``inp``; the + # adapter detects the identity and nulls it in the payload. + if not backward_needs_input: + saved_values.append(None) + elif inputmat_aliases_inp: + saved_values.append(inp) + elif inputmat_is_storage: + saved_values.append( + input_quantizer.make_empty( + list(inp.shape), dtype=activation_dtype, device=inp.device + ) + ) + else: + saved_values.append( + torch.empty(tuple(inp.shape), dtype=activation_dtype, device=inp.device) + ) + + # Slot 1 -- ``wt_save``. The saved storage's quantizer must + # match the one the impl uses for re-quantization, which is + # ``weight._quantizer`` for already-quantized weights. FSDP2 + # re-quantizes from the all-gathered weight on backward, so + # the slot is absent in that case. + weight_quantizer_for_save = ( + weight._quantizer + if isinstance(weight, QuantizedTensor) + else args.weight_quantizer + ) + if weightmat_aliases_weight: + saved_values.append(weight) + elif args.is_fsdp2: + saved_values.append(None) + elif weightmat_is_storage: + saved_values.append( + weight_quantizer_for_save.make_empty( + list(weight.shape), dtype=activation_dtype, device=weight.device + ) + ) + else: + saved_values.append( + torch.empty(tuple(weight.shape), dtype=activation_dtype, device=weight.device) + ) + + # Slot 2 -- ``saved_weight`` (always aliased to ``weight``). + # Slot 3 -- ``saved_bias`` (aliased to ``bias`` or absent). + saved_values.append(weight) + saved_values.append(bias if bias is not None else None) + + if args.fsdp_group is not None and args.is_grad_enabled: + raise NotImplementedError( + "Compile-time Linear forward does not support manual TE FSDP " + "(fsdp_group is not None); use FSDP2 or MCore FSDP." + ) + + ctx_attrs: Dict[str, Any] = {"fsdp_shapes": []} + + tensors_to_save = tuple(saved_values) if args.is_grad_enabled else None + return out, new_weight_workspace, tensors_to_save, None, ctx_attrs + + +_linear_compiled_op = _te_register_custom_op( + op_name="linear", + input_tensors_for_grad=["weight", "inp", "bias"], + fwd_arg_type=LinearFwdArgs, + fwd_impl=_linear_forward_impl, + fwd_fake_impl=_linear_forward_fake_impl, + setup_context=_linear_setup_ctx, + backward_arg_type=LinearBwdArgs, + backward_obj=LinearBwdArgs, + backward_impl=_linear_backward, + bwd_fake_impl=_linear_backward_fake_impl, +) + + class _Linear(torch.autograd.Function): """Linear semi-top level module Calls custom cuda extensions. @@ -1317,6 +1664,7 @@ def backward( bwd_args: LinearBwdArgs = ctx.backward_objects bwd_args.grad_output = grad_output bwd_args.setup_saved_tensors(ctx) + ctx.tensor_objects = None nvtx_label = "transformer_engine._Linear.backward" if bwd_args.ub_name is not None: nvtx_label = f"{nvtx_label}.{bwd_args.ub_name}" @@ -1716,7 +2064,6 @@ def reset_parameters(self, defer_init=False): elif self.parallel_mode == "column": set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1) - @no_torch_dynamo() def forward( self, inp: torch.Tensor, @@ -1795,12 +2142,18 @@ def forward( grad_output_quantizer, ) = quantizers - if is_grad_enabled: - linear_fn = _Linear.apply - autograd_ctx = [] - else: - linear_fn = _Linear.forward - autograd_ctx = [None] + # Under torch.compile we always dispatch through the registered + # custom op (it only takes ``fwd_args``); torch.library handles the + # no-grad case automatically. Otherwise fall back to the eager + # torch.autograd.Function (or its bare forward when grad is off). + use_compiled_op = torch.compiler.is_compiling() + if not use_compiled_op: + if is_grad_enabled: + linear_fn = _Linear.apply + autograd_ctx = [] + else: + linear_fn = _Linear.forward + autograd_ctx = [None] cache_name = None if (is_first_microbatch is None or self.is_fsdp2) else "weight" weight_workspace = ( @@ -1895,13 +2248,16 @@ def forward( cpu_offloading=is_cpu_offload_enabled(), is_grad_enabled=is_grad_enabled, ) - out, new_weight_workspace = linear_fn( - *autograd_ctx, - weight_tensor, - inp, - linear_bias_tensor, - fwd_args, - ) + if use_compiled_op: + out, new_weight_workspace = _linear_compiled_op(fwd_args) + else: + out, new_weight_workspace = linear_fn( + *autograd_ctx, + weight_tensor, + inp, + linear_bias_tensor, + fwd_args, + ) if new_weight_workspace is not None and cache_name is not None: if isinstance(new_weight_workspace, torch.Tensor): diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 0c40723517..da999c3a5a 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -32,6 +32,14 @@ from .jit import jit_fuser +# Trace-friendly TE DType ids (Python ints). Materialized once at +# import time so that hot paths (RecipeState init, get_fp8_te_dtype_id) +# never touch the pybind11 enum, which Dynamo cannot trace. +_TE_DTYPE_ID_FLOAT8_E4M3 = int(tex.DType.kFloat8E4M3) +_TE_DTYPE_ID_FLOAT8_E5M2 = int(tex.DType.kFloat8E5M2) +_TE_DTYPE_ID_FLOAT4_E2M1 = int(tex.DType.kFloat4E2M1) + + __all__ = [ "autocast", "quantized_model_init", @@ -286,6 +294,17 @@ def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType return tex.DType.kFloat8E5M2 +def get_fp8_te_dtype_id(fp8_recipe: Recipe, fprop_tensor: bool = True) -> int: + """Trace-friendly variant of :func:`get_fp8_te_dtype` returning the + integer id of the TE ``DType`` enum. Use this on any code path that + may be traced by ``torch.compile``.""" + if fp8_recipe.fp8_format == Format.E4M3 or ( + fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor + ): + return _TE_DTYPE_ID_FLOAT8_E4M3 + return _TE_DTYPE_ID_FLOAT8_E5M2 + + def get_fp4_te_dtype(fp4_recipe: Recipe) -> tex.DType: """Get fp4 data type according to recipe and tensor""" if fp4_recipe.fp4_format == Format.E2M1: @@ -293,6 +312,14 @@ def get_fp4_te_dtype(fp4_recipe: Recipe) -> tex.DType: raise ValueError(f"Unsupported FP4 format: {fp4_recipe.fp4_format}") +def get_fp4_te_dtype_id(fp4_recipe: Recipe) -> int: + """Trace-friendly variant of :func:`get_fp4_te_dtype` returning the + integer id of the TE ``DType`` enum.""" + if fp4_recipe.fp4_format == Format.E2M1: + return _TE_DTYPE_ID_FLOAT4_E2M1 + raise ValueError(f"Unsupported FP4 format: {fp4_recipe.fp4_format}") + + def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: """Get max representible FP8 value.""" if fp8_recipe.fp8_format == Format.E4M3 or ( @@ -1404,7 +1431,7 @@ def __init__( self.mode = mode self.num_quantizers = num_quantizers self.roles = roles - self.dtype = get_fp8_te_dtype(recipe, mode == "forward") + self.dtype = get_fp8_te_dtype_id(recipe, mode == "forward") # Allocate buffers if device is None: @@ -1453,7 +1480,7 @@ def __init__( self.mode = mode self.num_quantizers = num_quantizers self.roles = roles - self.dtype = get_fp8_te_dtype(recipe, mode == "forward") + self.dtype = get_fp8_te_dtype_id(recipe, mode == "forward") # Allocate buffers if device is None: @@ -1496,7 +1523,7 @@ def __init__( self.mode = mode self.num_quantizers = num_quantizers self.roles = roles - self.dtype = get_fp8_te_dtype(recipe, mode == "forward") + self.dtype = get_fp8_te_dtype_id(recipe, mode == "forward") # Allocate buffers if device is None: @@ -1536,9 +1563,9 @@ def __init__( self.mode = mode self.num_quantizers = num_quantizers self.roles = roles - self.qx_dtype = get_fp8_te_dtype(recipe, True) - self.qw_dtype = get_fp8_te_dtype(recipe, True) - self.qgrad_dtype = get_fp8_te_dtype(recipe, False) + self.qx_dtype = get_fp8_te_dtype_id(recipe, True) + self.qw_dtype = get_fp8_te_dtype_id(recipe, True) + self.qgrad_dtype = get_fp8_te_dtype_id(recipe, False) # Allocate buffers if device is None: @@ -1621,7 +1648,7 @@ def __init__( self.mode = mode self.num_quantizers = num_quantizers self.roles = roles - self.dtype = get_fp4_te_dtype(recipe) + self.dtype = get_fp4_te_dtype_id(recipe) # Allocate buffers if device is None: @@ -1837,12 +1864,17 @@ def make_quantizers(self) -> list: roles = self.roles if roles is None: - warnings.warn( - "CustomRecipeState: no QuantizerRole list provided by the module/op. " - "Falling back to bare QuantizerRole() defaults. " - "Override get_quantizer_roles() to provide meaningful roles.", - stacklevel=2, - ) + # Dynamo cannot trace the Python builtin ``_warnings.warn``, + # which graph-breaks any ``fullgraph=True`` compile that + # eventually calls ``make_quantizers``. The warning is + # informational only and is safe to skip under compile. + if not torch.compiler.is_compiling(): + warnings.warn( + "CustomRecipeState: no QuantizerRole list provided by the module/op. " + "Falling back to bare QuantizerRole() defaults. " + "Override get_quantizer_roles() to provide meaningful roles.", + stacklevel=2, + ) roles = [QuantizerRole() for _ in range(self.num_quantizers)] # qfactory must return a Quantizer or QuantizerRequest for every slot. diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index a7722f777e..8d718b3b12 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -5,7 +5,7 @@ """Pure Python base classes for quantization.""" from __future__ import annotations -from typing import Optional, Tuple, Iterable, Any, Dict, Union +from typing import Optional, Tuple, Iterable, Any, Dict, List, Union import abc import warnings import math @@ -14,6 +14,7 @@ from torch.utils._pytree import tree_map from transformer_engine.common.recipe import Recipe +from transformer_engine.pytorch.fp8_dtype import from_tex from transformer_engine.pytorch.tensor._quantization_helpers import ( _QuantizeFunc, _IdentityFunc, @@ -21,6 +22,80 @@ ) +# Maps a Quantizer subclass's ``__qualname__`` to the class object. Populated +# lazily via :meth:`Quantizer.__init_subclass__` and consumed by +# :meth:`Quantizer._unflatten` to dispatch reconstruction to the right +# subclass when a TE custom op is unpacked under ``torch.compile``. +_QUANTIZER_REGISTRY: Dict[str, type] = {} + + +def _quantizer_subclass_snapshot( + quantizer: Optional["Quantizer"], +) -> Optional[Tuple[Tuple[str, Any], ...]]: + """Return a Dynamo-guard-stable snapshot of a quantizer, or ``None``. + + Used by tensor subclasses (e.g. :class:`Float8Tensor`) to embed a + tensor-free, comparable representation of their live + :class:`Quantizer` in the ``meta`` dict returned from + ``__tensor_flatten__``. PyTorch's tensor-subclass metadata guard + diff-checks that dict via ``dict.__eq__`` on every entry into the + compiled region, so values that resolve to elementwise tensor + comparison or identity-only equality (live ``torch.Tensor`` + objects, ``ProcessGroup``, the live quantizer instance itself) + cannot appear there. + + The snapshot is a sorted tuple of ``(key, value)`` pairs derived + from ``quantizer._flatten()`` whenever the quantizer's state is + fully expressible without tensors (an empty trailing tensor list + in the ``_flatten`` triplet). Quantizers carrying tensors in their + state (e.g. :class:`Float8Quantizer`'s ``scale`` / ``amax``) and + quantizers that don't implement ``_flatten`` produce ``None``; + in that case the subclass's ``__tensor_unflatten__`` will + rebuild the wrapper with ``quantizer=None`` and any code that + needs the live quantizer must source it from the bucket-level + opaque metadata flowing through the inner custom op. + """ + if quantizer is None: + return None + try: + meta, _pg, tensors = quantizer._flatten() + except NotImplementedError: + return None + if tensors: + return None + if hasattr(meta, "_data"): + meta_dict = meta._data + elif isinstance(meta, dict): + meta_dict = meta + else: + return None + return tuple(sorted(meta_dict.items(), key=lambda kv: kv[0])) + + +def _quantizer_from_subclass_snapshot( + snapshot: Optional[Tuple[Tuple[str, Any], ...]], +) -> Optional["Quantizer"]: + """Inverse of :func:`_quantizer_subclass_snapshot`. + + Rebuilds the quantizer from the qualname stored in the snapshot's + ``"_qcls"`` entry, dispatching via :func:`Quantizer._unflatten` + (and so via the right subclass's ``_do_unflatten``). The + reconstructed quantizer's process-group reference is always + ``None`` -- live ``ProcessGroup`` objects cannot survive the + snapshot round trip; callers that need a real process group + obtain it via the bucket-level opaque metadata instead. + """ + if snapshot is None: + return None + meta_dict = dict(snapshot) + return Quantizer._unflatten(meta_dict, None, []) + +# Same idea for lightweight QuantizedTensorStorage shells. Populated via +# :meth:`QuantizedTensorStorage.__init_subclass__` and consumed by +# :meth:`QuantizedTensorStorage._torch_compile_unflatten`. +_STORAGE_REGISTRY: Dict[str, type] = {} + + # Custom ops that should pass through __torch_dispatch__ without unwrapping # QuantizedTensor subclasses (e.g. Float8Tensor). Register ops here that # handle quantized tensors internally. @@ -46,6 +121,46 @@ class QuantizedTensorStorage: _dtype: torch.dtype _quantizer: Optional[Quantizer] + # ------------------------------------------------------------------ # + # Declarative schema for the unified flatten / unflatten machinery # + # (consumed by both the storage ``_torch_compile_flatten`` protocol # + # and ``QuantizedTensor``'s PyTorch ``__tensor_flatten__`` helper). # + # ------------------------------------------------------------------ # + + # Names of optional tensor attributes on the instance, in canonical + # order. Each name must be an attribute on ``self`` and must be + # accepted as a kwarg by ``cls(**kwargs)`` (potentially after + # remapping through :attr:`_FLATTEN_CTOR_KWARG`). + _FLATTEN_TENSOR_ATTRS: Tuple[str, ...] = () + + # Maps each entry in :attr:`_FLATTEN_TENSOR_ATTRS` to one of + # ``"rowwise"`` / ``"columnwise"`` / ``"always"``. Consumed by + # :meth:`Quantizer.create_storage_metadata` to translate a live + # quantizer's ``rowwise_usage`` / ``columnwise_usage`` flags into + # per-attribute presence (``has_*``) flags at output-spec time. + # Unmapped attributes default to ``"always"``. + _FLATTEN_TENSOR_USAGE: Dict[str, str] = {} + + # Names of value-stable scalar / enum attributes needed to round-trip + # the instance. Same naming / kwarg conventions as + # :attr:`_FLATTEN_TENSOR_ATTRS`. + _FLATTEN_META_ATTRS: Tuple[str, ...] = () + + # Map from attribute name to constructor kwarg name, used when they + # differ (e.g. ``_data`` -> ``data``). Identity by default. + _FLATTEN_CTOR_KWARG: Dict[str, str] = {} + + @classmethod + def _flatten_meta_overrides(cls, meta: Dict[str, Any]) -> Dict[str, Any]: + """Hook for last-mile meta value massaging before unflatten dispatches + to ``cls(**kwargs)``. Default: no-op. + + Used today by :class:`Float8Tensor` to bridge :class:`FP8DType` + (carried by the subclass output spec) back to the native + ``tex.DType`` accepted by pybind-bound kernels. + """ + return meta + def update_usage( self, rowwise_usage: Optional[bool] = None, @@ -130,6 +245,144 @@ def copy_from_storage(self, src: QuantizedTensorStorage) -> None: f"{self.__class__.__name__} class does not implement copy_from_storage function" ) + # ------------------------------------------------------------------ # + # torch.compile flatten / unflatten protocol + # ------------------------------------------------------------------ # + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + _STORAGE_REGISTRY[cls.__qualname__] = cls + + def __eq__(self, other: object) -> bool: + return self is other + + def __hash__(self) -> int: + return id(self) + + @classmethod + def _flatten_ctor_kw(cls, attr_name: str) -> str: + """Return the constructor kwarg name corresponding to ``attr_name``. + + Identity unless overridden via :attr:`_FLATTEN_CTOR_KWARG`. + """ + return cls._FLATTEN_CTOR_KWARG.get(attr_name, attr_name) + + @staticmethod + def _flatten_presence_key(attr_name: str) -> str: + """Return the ``has_*`` meta key indicating whether ``attr_name`` is + present in the flattened payload. Derived from the attribute name + (with the leading underscore stripped) so the static metadata + constructors in ``float8_tensor.py`` etc. don't need to know about + :attr:`_FLATTEN_CTOR_KWARG` remapping. + """ + return f"has_{attr_name.lstrip('_')}" + + def _torch_compile_flatten( + self, + ) -> Tuple[Any, Optional["torch.distributed.ProcessGroup"], List[torch.Tensor]]: + """Pack storage state into the ``(meta, pg, tensors)`` triplet + consumed by :mod:`transformer_engine.pytorch.dynamo`. + + Generic implementation driven by :attr:`_FLATTEN_TENSOR_ATTRS`, + :attr:`_FLATTEN_META_ATTRS`, and :attr:`_FLATTEN_CTOR_KWARG`. + Quantizer-with-tensors (e.g. :class:`Float8Quantizer`'s + ``scale`` / ``amax``) is round-tripped via + :meth:`Quantizer._flatten`; quantizer tensors trail the + storage's own tensors in the flat list. + """ + from transformer_engine.pytorch.dynamo import ( # pylint: disable=import-outside-toplevel + OpaqueSimpleMetadata, + ) + + tensors: List[torch.Tensor] = [] + meta_dict: Dict[str, Any] = {"_qstorage_cls": type(self).__qualname__} + # Tensor-wrapper fields are only relevant when ``self`` is a live + # ``torch.Tensor`` (e.g. ``Float8Tensor`` flattened directly into a + # storage payload by ``_flatten_subclass_into_slots``); a bare + # storage shell has no outer shape / requires_grad / device. + if isinstance(self, torch.Tensor): + meta_dict.update( + { + "is_tensor": True, + "shape": torch.Size(self.shape), + "requires_grad": self.requires_grad, + "device": self.device, + } + ) + for attr in self._FLATTEN_META_ATTRS: + meta_dict[self._flatten_ctor_kw(attr)] = getattr(self, attr) + for attr in self._FLATTEN_TENSOR_ATTRS: + tensor = getattr(self, attr) + present = tensor is not None + meta_dict[self._flatten_presence_key(attr)] = present + if present: + tensors.append(tensor) + quantizer_meta = None + process_group = None + if self._quantizer is not None: + quantizer_meta, process_group, q_tensors = self._quantizer._flatten() + tensors.extend(q_tensors) + meta_dict["quantizer_meta"] = quantizer_meta + return OpaqueSimpleMetadata(meta_dict), process_group, tensors + + @classmethod + def _torch_compile_do_unflatten( + cls, + meta: Any, + process_group: Optional["torch.distributed.ProcessGroup"], + tensors: List[torch.Tensor], + ) -> "QuantizedTensorStorage": + """Reconstruct ``cls`` from a triplet produced by + :meth:`_torch_compile_flatten`. Generic; driven by the same + ``_FLATTEN_*`` declarations. + """ + meta = cls._flatten_meta_overrides(meta) + tensor_iter = iter(tensors) + kwargs: Dict[str, Any] = {} + for attr in cls._FLATTEN_TENSOR_ATTRS: + kw = cls._flatten_ctor_kw(attr) + kwargs[kw] = next(tensor_iter) if meta[cls._flatten_presence_key(attr)] else None + quantizer = None + if meta["quantizer_meta"] is not None: + quantizer = Quantizer._unflatten( + meta["quantizer_meta"], process_group, list(tensor_iter) + ) + for attr in cls._FLATTEN_META_ATTRS: + kw = cls._flatten_ctor_kw(attr) + kwargs[kw] = meta[kw] + kwargs["quantizer"] = quantizer + if meta.get("is_tensor", False): + kwargs.update( + { + "shape": meta["shape"], + "dtype": kwargs["fake_dtype"], + "requires_grad": meta["requires_grad"], + "device": meta["device"], + } + ) + return cls(**kwargs) + + @classmethod + def _torch_compile_unflatten( + cls, + meta: Any, + process_group: Optional["torch.distributed.ProcessGroup"], + tensors: List[torch.Tensor], + ) -> "QuantizedTensorStorage": + """Dispatch to the right storage subclass based on metadata.""" + storage_cls = meta["_qstorage_cls"] + target = _STORAGE_REGISTRY.get(storage_cls) + if target is None: + raise ValueError( + f"No QuantizedTensorStorage subclass registered under " + f"qualname {storage_cls!r}; known: {sorted(_STORAGE_REGISTRY)}" + ) + return target._torch_compile_do_unflatten(meta, process_group, tensors) + + + +TensorOrQuantized = Union[torch.Tensor, QuantizedTensorStorage] + def prepare_for_saving( *tensors: Union[torch.Tensor, QuantizedTensorStorage], @@ -218,6 +471,13 @@ class Quantizer(abc.ABC): """ rowwise_usage: bool + # The :class:`QuantizedTensorStorage` subclass produced by this + # quantizer's quantize / make_empty path. Consumed by + # :meth:`create_storage_metadata` to declare a ``("storage", ...)`` + # output payload that round-trips through the generic + # :meth:`QuantizedTensorStorage._torch_compile_do_unflatten`. + _storage_cls: type["QuantizedTensorStorage"] + """Whether to construct quantized tensors with "column-wise usage" Hand-wave explanation: Consider the matrix multiplication C = A^T @@ -378,6 +638,266 @@ def get_usages(self) -> Dict[str, bool]: "columnwise": self.columnwise_usage, } + # ------------------------------------------------------------------ # + # torch.compile flatten / unflatten protocol + # ------------------------------------------------------------------ # + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + # Auto-register every Quantizer subclass so ``_unflatten`` can + # dispatch back to it by ``__qualname__``. + _QUANTIZER_REGISTRY[cls.__qualname__] = cls + + # ---- Declarative schema for the generic :meth:`_flatten` / ---- # + # ---- :meth:`_do_unflatten` implementations below. ---- # + + # ``__init__`` kwarg name for ``self.dtype`` (e.g. ``"fp8_dtype"``, + # ``"fp4_dtype"``). + _DTYPE_INIT_KWARG: str = "fp8_dtype" + + # Scalar attribute names (besides ``dtype`` / ``rowwise_usage`` / + # ``columnwise_usage``) threaded through ``__init__``. The kwarg name + # is assumed to match the attribute name. + _INIT_META_ATTRS: Tuple[str, ...] = () + + # Scalar attribute names (besides ``internal`` / ``optimize_for_gemm``) + # set on the instance after ``__init__``. + _POST_INIT_META_ATTRS: Tuple[str, ...] = () + + # Tensor attribute names threaded through ``__init__``, in flatten + # order. + _INIT_TENSOR_ATTRS: Tuple[str, ...] = () + + # Tensor attribute names set on the instance after ``__init__``. + _POST_INIT_TENSOR_ATTRS: Tuple[str, ...] = () + + # Attribute name on ``self`` holding the (optional) ``ProcessGroup``, + # or ``None`` if the quantizer has no PG. + _PG_ATTR: Optional[str] = None + # ``__init__`` kwarg name to thread the PG through. ``None`` means + # set ``_PG_ATTR`` directly after ``__init__``. + _PG_INIT_KWARG: Optional[str] = None + + # Hardcoded ``__init__`` kwargs not derived from meta (e.g. + # ``device=torch.device("cuda")`` for ``Float8CurrentScalingQuantizer``). + _FIXED_INIT_KWARGS: Dict[str, Any] = {} + + def _flatten( + self, + ) -> Tuple[Any, Optional["torch.distributed.ProcessGroup"], List[torch.Tensor]]: + """Pack this quantizer's state into the + ``(meta, process_group, tensors)`` triplet expected by the + flattenable bucket in :mod:`transformer_engine.pytorch.dynamo`. + + Generic implementation driven by the declarative schema attrs above. + Subclasses only declare which scalars / tensors go through + ``__init__`` vs. are set post-init; the base class round-trips + ``dtype`` / ``rowwise_usage`` / ``columnwise_usage`` and + ``internal`` / ``optimize_for_gemm`` on every quantizer. + """ + from .dynamo import OpaqueSimpleMetadata # pylint: disable=import-outside-toplevel + + cls = type(self) + meta_dict: Dict[str, Any] = { + "_qcls": cls.__qualname__, + "dtype": self.dtype, + "rowwise_usage": self.rowwise_usage, + "columnwise_usage": self.columnwise_usage, + "internal": self.internal, + "optimize_for_gemm": self.optimize_for_gemm, + } + for attr in (*cls._INIT_META_ATTRS, *cls._POST_INIT_META_ATTRS): + meta_dict[attr] = getattr(self, attr) + tensors = [ + getattr(self, attr) + for attr in (*cls._INIT_TENSOR_ATTRS, *cls._POST_INIT_TENSOR_ATTRS) + ] + pg = getattr(self, cls._PG_ATTR) if cls._PG_ATTR else None + return OpaqueSimpleMetadata(meta_dict), pg, tensors + + @classmethod + def _do_unflatten( + cls, + meta: Any, + process_group: Optional["torch.distributed.ProcessGroup"], + tensors: List[torch.Tensor], + ) -> "Quantizer": + """Reconstruct an instance of ``cls`` from the triplet returned by a + previous :meth:`_flatten` on the same subclass. Generic; driven + by the declarative schema attrs. + """ + init_kwargs: Dict[str, Any] = { + cls._DTYPE_INIT_KWARG: meta["dtype"], + "rowwise": meta["rowwise_usage"], + "columnwise": meta["columnwise_usage"], + } + for attr in cls._INIT_META_ATTRS: + init_kwargs[attr] = meta[attr] + if cls._PG_INIT_KWARG is not None: + init_kwargs[cls._PG_INIT_KWARG] = process_group + init_kwargs.update(cls._FIXED_INIT_KWARGS) + tensor_iter = iter(tensors) + for attr in cls._INIT_TENSOR_ATTRS: + init_kwargs[attr] = next(tensor_iter) + q = cls(**init_kwargs) + q.internal = meta["internal"] + q.optimize_for_gemm = meta["optimize_for_gemm"] + for attr in cls._POST_INIT_META_ATTRS: + setattr(q, attr, meta[attr]) + for attr in cls._POST_INIT_TENSOR_ATTRS: + setattr(q, attr, next(tensor_iter)) + if cls._PG_ATTR is not None and cls._PG_INIT_KWARG is None: + setattr(q, cls._PG_ATTR, process_group) + return q + + @classmethod + def _unflatten( + cls, + meta: Any, + process_group: Optional["torch.distributed.ProcessGroup"], + tensors: List[torch.Tensor], + ) -> "Quantizer": + """Dispatch to the right subclass's :meth:`_do_unflatten` based on + the ``"_qcls"`` qualname stored in ``meta``. + """ + qcls = meta["_qcls"] + target = _QUANTIZER_REGISTRY.get(qcls) + if target is None: + raise ValueError( + f"No Quantizer subclass registered under qualname {qcls!r}; " + f"known: {sorted(_QUANTIZER_REGISTRY)}" + ) + return target._do_unflatten(meta, process_group, tensors) + + def _storage_scalars(self) -> Dict[str, Any]: + """Per-quantizer scalar fields for the storage's ``_FLATTEN_META_ATTRS``. + + Keys are constructor kwarg names (matching the values of + :attr:`QuantizedTensorStorage._FLATTEN_CTOR_KWARG`). ``fake_dtype`` + is supplied separately by :meth:`create_storage_metadata`; subclasses + only need to return their quantizer-specific scalars (e.g. + ``fp8_dtype``, ``with_gemm_swizzled_scales``). + """ + raise NotImplementedError( + f"{type(self).__name__} class does not implement _storage_scalars; " + "required for torch.compile output specs that emit a " + "QuantizedTensorStorage." + ) + + # Keys in :meth:`_storage_scalars` whose values are pybind enums + # (``transformer_engine_torch.DType``) and must be converted to the + # Python ``FP8DType`` proxy for :meth:`create_metadata`. The opaque + # registration in :mod:`fp8_dtype` is enough to flow ``tex.DType`` + # through Dynamo as an FX constant, but + # :meth:`autograd.Function.apply` -- used by + # :func:`_ToSubclassFn.reassemble_with_autograd` -- still rejects + # opaque values via its proxy-conversion check. The reverse + # conversion lives in the tensor subclass's + # :meth:`_flatten_meta_overrides`. + _SUBCLASS_META_TEX_KEYS: Tuple[str, ...] = ("fp8_dtype",) + + def create_metadata( + self, + *, + fake_dtype: torch.dtype, + requires_grad: bool = False, + ) -> Tuple[Tuple[str, ...], Dict[str, Any]]: + """Return ``(inner_names, meta)`` for :meth:`QuantizedTensor.__tensor_unflatten__`. + + Generic implementation driven by the storage class's + :attr:`QuantizedTensorStorage._FLATTEN_TENSOR_ATTRS` / + :attr:`QuantizedTensorStorage._FLATTEN_TENSOR_USAGE` plus this + quantizer's :meth:`_storage_scalars`. ``inner_names`` follows the + declaration order of ``_FLATTEN_TENSOR_ATTRS`` so it matches the + slot order produced by :meth:`QuantizedTensor.__tensor_flatten__` + in :func:`dynamo._flatten_value_into`. + + ``quantizer_snapshot`` is forced to ``None`` on this path: + rebuilding a live :class:`Quantizer` inside + ``__tensor_unflatten__`` would force Dynamo to trace the + constructor, which routinely trips + ``UserDefinedObjectVariable(...Quantizer)``. + """ + storage_cls = type(self)._storage_cls + usage_flag = { + "rowwise": self.rowwise_usage, + "columnwise": self.columnwise_usage, + "always": True, + } + inner_names = tuple( + attr + for attr in storage_cls._FLATTEN_TENSOR_ATTRS + if usage_flag[storage_cls._FLATTEN_TENSOR_USAGE.get(attr, "always")] + ) + scalars = self._storage_scalars() + for key in self._SUBCLASS_META_TEX_KEYS: + if key in scalars: + scalars[key] = from_tex(scalars[key]) + meta: Dict[str, Any] = { + **scalars, + "fake_dtype": fake_dtype, + "quantizer_snapshot": None, + "requires_grad": requires_grad, + } + return inner_names, meta + + def create_storage_metadata( + self, + *, + shape: Iterable[int], + fake_dtype: torch.dtype, + device: Optional[torch.device] = None, + ) -> Tuple[type["QuantizedTensorStorage"], Any, Optional[Any], int]: + """Return ``(cls, meta, process_group, tensor_count)`` describing + the ``("storage", ...)`` payload of a Dynamo output spec. + + The Dynamo layer hands the trailing + ``(meta, process_group, tensors[: tensor_count])`` triple to + :meth:`QuantizedTensorStorage._torch_compile_do_unflatten` to + reconstruct the freshly-quantized storage on the consumer side. + + Driven entirely by the storage's ``_FLATTEN_*`` schema plus a + per-quantizer :meth:`_storage_scalars` hook; ``has_*`` flags are + derived from ``rowwise_usage`` / ``columnwise_usage`` and the + storage's :attr:`QuantizedTensorStorage._FLATTEN_TENSOR_USAGE` + map. Quantizers with tensor state (e.g. :class:`Float8Quantizer`'s + ``scale`` / ``amax``) append those tensors after the storage's own + slots; :meth:`Quantizer._flatten` provides both the count and the + ``quantizer_meta`` payload needed to rebuild the quantizer. + """ + from .dynamo import OpaqueSimpleMetadata # pylint: disable=import-outside-toplevel + + if device is None: + device = torch.device("cuda") + del device, shape # storage-only path: no outer tensor view + storage_cls = type(self)._storage_cls + usage_flag = { + "rowwise": self.rowwise_usage, + "columnwise": self.columnwise_usage, + "always": True, + } + has_flags: Dict[str, bool] = {} + tensor_count = 0 + for attr in storage_cls._FLATTEN_TENSOR_ATTRS: + usage = storage_cls._FLATTEN_TENSOR_USAGE.get(attr, "always") + flag = usage_flag[usage] + has_flags[storage_cls._flatten_presence_key(attr)] = flag + if flag: + tensor_count += 1 + quantizer_meta, _, quantizer_tensors = self._flatten() + tensor_count += len(quantizer_tensors) + scalars = self._storage_scalars() + scalars["fake_dtype"] = fake_dtype + meta = OpaqueSimpleMetadata( + { + "_qstorage_cls": storage_cls.__qualname__, + **scalars, + **has_flags, + "quantizer_meta": quantizer_meta, + } + ) + return storage_cls, meta, None, tensor_count + class QuantizedTensor(torch.Tensor): """Abstract base class for tensor with quantized data @@ -686,13 +1206,13 @@ def maybe_update_inplace(arg, new_arg, schema_arg): out = super().__torch_dispatch__(func, types, args, kwargs) return out - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - - # Do not force the QuantizedTensor type on the returned tensor - return torch._C._disabled_torch_function_impl(func, types, args, kwargs) + # Set as a class-level attribute rather than a ``@classmethod`` so that + # Dynamo recognises the canonical "torch_function disabled" idiom + # and can trace through custom-op calls that receive a + # QuantizedTensor subclass as an argument. As a method override, + # Dynamo bails with "cannot trace builtin + # torch._C._disabled_torch_function_impl". + __torch_function__ = torch._C._disabled_torch_function_impl def contiguous( self, memory_format: torch.memory_format = torch.contiguous_format @@ -713,6 +1233,65 @@ def get_metadata(self) -> Dict[str, Any]: f"{self.__class__.__name__} class does not implement get_metadata function" ) + # ------------------------------------------------------------------ # + # PyTorch wrapper-subclass flatten / unflatten # + # ------------------------------------------------------------------ # + # + # Driven by the same ``_FLATTEN_*_ATTRS`` / ``_FLATTEN_CTOR_KWARG`` + # declarations as :meth:`QuantizedTensorStorage._torch_compile_flatten`, + # plus the :meth:`_flatten_meta_overrides` hook (Float8Tensor uses it + # to bridge :class:`FP8DType` <-> ``tex.DType``). + # + # Per-subclass differences vs the storage path: PyTorch's protocol + # carries only attributes living on ``self`` (no quantizer tensors, + # no process group). Quantizers whose state contains tensors (e.g. + # :class:`Float8Quantizer`'s ``scale`` / ``amax``, + # :class:`NVFP4Quantizer`'s ``rht_matrix``) therefore round-trip via + # :func:`_quantizer_subclass_snapshot`, which bails to ``None``; the + # reconstructed tensor's ``_quantizer`` is ``None`` and downstream + # code that needs the live quantizer sources it from the bucket-level + # opaque metadata flowing alongside the inner op. + + def __tensor_flatten__(self) -> Tuple[list, dict]: + if not type(self)._FLATTEN_TENSOR_ATTRS: + raise NotImplementedError( + f"{type(self).__name__} did not declare _FLATTEN_TENSOR_ATTRS" + ) + inner: list = [ + attr for attr in self._FLATTEN_TENSOR_ATTRS if getattr(self, attr) is not None + ] + meta: Dict[str, Any] = { + "quantizer_snapshot": _quantizer_subclass_snapshot(self._quantizer), + "requires_grad": self.requires_grad, + } + for attr in self._FLATTEN_META_ATTRS: + meta[self._flatten_ctor_kw(attr)] = getattr(self, attr) + return inner, meta + + @classmethod + def __tensor_unflatten__( + cls, + inner_tensors: dict, + meta: dict, + outer_size, + outer_stride, + ) -> "QuantizedTensor": + meta = cls._flatten_meta_overrides(meta) + quantizer = _quantizer_from_subclass_snapshot(meta.get("quantizer_snapshot")) + kwargs: Dict[str, Any] = { + "shape": outer_size, + "dtype": meta["fake_dtype"], + "requires_grad": meta.get("requires_grad", False), + "quantizer": quantizer, + } + for attr in cls._FLATTEN_TENSOR_ATTRS: + kw = cls._flatten_ctor_kw(attr) + kwargs[kw] = inner_tensors.get(attr) + for attr in cls._FLATTEN_META_ATTRS: + kw = cls._flatten_ctor_kw(attr) + kwargs[kw] = meta[kw] + return cls(**kwargs) + @classmethod def make_like( cls, diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 914397b9b6..2e18465322 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -14,6 +14,7 @@ from transformer_engine_torch import DType as TE_DType from transformer_engine.common.recipe import Float8BlockScaling, Recipe from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from ..constants import canonicalize_te_dtype from ..quantized_tensor import QuantizedTensor, Quantizer from ._quantization_helpers import _IdentityFunc from ..utils import devices_match, round_up_to_nearest_multiple @@ -36,6 +37,10 @@ class Float8BlockQuantizer(Quantizer): force_pow_2_scales: bool block_scaling_dim: int + _storage_cls = Float8BlockwiseQTensorStorage + _INIT_META_ATTRS = ("amax_epsilon", "force_pow_2_scales", "block_scaling_dim") + _POST_INIT_META_ATTRS = ("block_len",) + def __init__( self, fp8_dtype: TE_DType, @@ -47,7 +52,7 @@ def __init__( block_scaling_dim: int = 2, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) - self.dtype = fp8_dtype + self.dtype = canonicalize_te_dtype(fp8_dtype) self.block_len = 128 self.force_pow_2_scales = force_pow_2_scales self.amax_epsilon = amax_epsilon @@ -244,7 +249,21 @@ def make_empty( **tensor_kwargs, ) - # Construct FP8 tensor + is_2d_scaled = self.block_scaling_dim == 2 + + # See ``Float8Quantizer.make_empty`` for the rationale. + if self.internal: + return Float8BlockwiseQTensorStorage( + rowwise_data, + rowwise_scale_inv, + columnwise_data, + columnwise_scale_inv, + self.dtype, + self, + is_2d_scaled, + fake_dtype=dtype, + ) + return Float8BlockwiseQTensor( shape=shape, dtype=dtype, @@ -254,7 +273,7 @@ def make_empty( columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, quantizer=self, - is_2D_scaled=self.block_scaling_dim == 2, + is_2D_scaled=is_2d_scaled, requires_grad=requires_grad, ) @@ -266,6 +285,12 @@ def calibrate(self, tensor: torch.Tensor) -> None: def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return Float8BlockScaling + def _storage_scalars(self) -> dict: + return { + "fp8_dtype": self.dtype, + "is_2D_scaled": self.block_scaling_dim == 2, + } + class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor): """Tensor class with FP8 data quantized via NxN blocks or 1xN blocks. diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index ed6091c85b..0758d2ae9d 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -4,7 +4,7 @@ """Tensor class with FP8 data""" from __future__ import annotations -from typing import Any, Optional, Tuple, Iterable, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import warnings import torch from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState @@ -18,12 +18,66 @@ ) from ..utils import canonicalize_process_group, devices_match from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func -from ..quantized_tensor import QuantizedTensor, Quantizer +from ..quantized_tensor import ( + QuantizedTensor, + Quantizer, +) from ._quantization_helpers import _IdentityFunc -from ..constants import dist_group_type +from ..constants import canonicalize_te_dtype, dist_group_type +from ..fp8_dtype import FP8DType, to_tex aten = torch.ops.aten + +def _float8_make_fake_empty( + quantizer: "Quantizer", + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +) -> "Float8Tensor": + """Dynamo-safe ``Float8Tensor`` allocation shared by the FP8 quantizers. + + Mirrors the inner-tensor layout of ``make_empty`` (rowwise ``_data`` / + columnwise ``_transpose`` / ``_scale_inv``) but assembles the wrapper + through :meth:`QuantizedTensor.__tensor_unflatten__` -- which takes a + snapshot-free ``meta`` dict (``quantizer_snapshot=None``, ``FP8DType``) + rather than the live ``Quantizer`` / ``tex.DType`` constructor args that + Dynamo cannot proxy inside a traced frame. + """ + from ..dynamo import _contiguous_stride # pylint: disable=import-outside-toplevel + + if device is None: + device = torch.device("cuda") + shape = list(shape) + + alloc: Dict[str, torch.Tensor] = {} + if quantizer.rowwise_usage: + alloc["_data"] = torch.empty(shape, dtype=torch.uint8, device=device) + if quantizer.columnwise_usage: + transpose_shape = [shape[-1]] + list(shape[:-1]) + alloc["_transpose"] = torch.empty(transpose_shape, dtype=torch.uint8, device=device) + alloc["_scale_inv"] = torch.empty(1, dtype=torch.float32, device=device) + + inner_names, meta = quantizer.create_metadata(fake_dtype=dtype) + inner_dict = {name: alloc[name] for name in inner_names} + out = Float8Tensor.__tensor_unflatten__( + inner_dict, meta, tuple(shape), _contiguous_stride(shape) + ) + # Stamp the reassembly plan so the dynamo reassembly helper + # (:func:`transformer_engine.pytorch.dynamo._template_reassemble`) + # can rebuild this subclass from the op's flat ``Tensor[]`` payload + # by reading an attribute (Dynamo-safe) rather than calling + # ``__tensor_flatten__`` in-trace. + out._te_compile_unflatten_plan = (tuple(inner_names), meta) + # Stash the live quantizer on the template so the reassembly helper can + # restore it on the real output (``__tensor_unflatten__`` rebuilds with + # ``quantizer=None`` because the snapshot can't carry a live + # ``ProcessGroup`` / scale-amax tensors through Dynamo guards). + out._te_compile_quantizer = quantizer + return out + + _ops_to_preserve_subclass_in_fsdp2 = { torch.ops.aten.empty_like.default, torch.ops.aten.new_zeros.default, @@ -56,6 +110,9 @@ class Float8Quantizer(Quantizer): """FP8 datatype""" dtype: TE_DType + _storage_cls = Float8TensorStorage + _INIT_TENSOR_ATTRS = ("scale", "amax") + def __init__( self, scale: torch.Tensor, @@ -68,7 +125,7 @@ def __init__( super().__init__(rowwise=rowwise, columnwise=columnwise) self.scale = scale self.amax = amax - self.dtype = fp8_dtype + self.dtype = canonicalize_te_dtype(fp8_dtype) def copy(self) -> Float8Quantizer: """Create shallow copy""" @@ -142,12 +199,29 @@ def make_empty( pin_memory=pin_memory, ) - # Construct FP8 tensor + scale_inv = torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory) + + # Honor ``internal``: tex.quantize() returns a bare + # Float8TensorStorage when the quantizer is marked internal + # (lower CPU overhead, no autograd-aware subclass) and so should + # make_empty in order to stay shape/type-equivalent on every + # path that touches it (eager fast-path, fake-impl under + # torch.compile, etc.). + if self.internal: + return Float8TensorStorage( + data=data, + fp8_scale_inv=scale_inv, + fp8_dtype=self.dtype, + fake_dtype=dtype, + data_transpose=data_transpose, + quantizer=self, + ) + return Float8Tensor( shape=shape, dtype=dtype, data=data, - fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory), + fp8_scale_inv=scale_inv, fp8_dtype=self.dtype, requires_grad=requires_grad, data_transpose=data_transpose, @@ -155,6 +229,24 @@ def make_empty( device=device, ) + def make_fake_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + ) -> Float8Tensor: + """Dynamo-safe analogue of :meth:`make_empty`. + + Builds the :class:`Float8Tensor` via + :meth:`QuantizedTensor.__tensor_unflatten__` (snapshot-free meta, + :class:`FP8DType`) instead of the live-quantizer constructor, so it + traces under ``torch.compile(fullgraph=True)`` -- where + :meth:`make_empty` trips on ``UserDefinedObjectVariable(Quantizer)`` + / ``UserDefinedObjectVariable(DType)``. + """ + return _float8_make_fake_empty(self, shape, dtype=dtype, device=device) + def calibrate(self, tensor: torch.Tensor) -> None: amin, amax = tensor.aminmax() self.amax.copy_(torch.max(-amin, amax)) @@ -223,6 +315,9 @@ def supports_only_rowwise_all_gather(self) -> bool: """ return True + def _storage_scalars(self) -> Dict[str, Any]: + return {"fp8_dtype": self.dtype} + class Float8CurrentScalingQuantizer(Quantizer): """Builder class for FP8 tensors with per-tensor current scaling @@ -255,6 +350,12 @@ class Float8CurrentScalingQuantizer(Quantizer): force_pow_2_scales: bool amax_epsilon: float + _storage_cls = Float8TensorStorage + _INIT_META_ATTRS = ("with_amax_reduction", "force_pow_2_scales", "amax_epsilon") + _PG_ATTR = "amax_reduction_group" + _PG_INIT_KWARG = "amax_reduction_group" + _FIXED_INIT_KWARGS = {"device": torch.device("cuda")} + def __init__( self, fp8_dtype: TE_DType, @@ -279,7 +380,7 @@ def __init__( stacklevel=2, ) del device, use_existing_amax, scale, amax # Kept for backward compatibility - self.dtype = fp8_dtype + self.dtype = canonicalize_te_dtype(fp8_dtype) self.with_amax_reduction = with_amax_reduction self.amax_reduction_group = amax_reduction_group self.force_pow_2_scales = force_pow_2_scales @@ -366,12 +467,23 @@ def make_empty( device=device, pin_memory=pin_memory, ) - # Construct FP8 tensor + scale_inv = torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory) + + if self.internal: + return Float8TensorStorage( + data=data, + fp8_scale_inv=scale_inv, + fp8_dtype=self.dtype, + fake_dtype=dtype, + data_transpose=data_transpose, + quantizer=self, + ) + return Float8Tensor( shape=shape, dtype=dtype, data=data, - fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory), + fp8_scale_inv=scale_inv, fp8_dtype=self.dtype, requires_grad=requires_grad, data_transpose=data_transpose, @@ -379,6 +491,17 @@ def make_empty( device=device, ) + def make_fake_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + ) -> Float8Tensor: + """Dynamo-safe analogue of :meth:`make_empty` (see + :func:`_float8_make_fake_empty`).""" + return _float8_make_fake_empty(self, shape, dtype=dtype, device=device) + def calibrate(self, tensor: torch.Tensor) -> None: # current scaling don't need to calibrate return @@ -461,6 +584,9 @@ def supports_only_rowwise_all_gather(self) -> bool: """ return True + def _storage_scalars(self) -> Dict[str, Any]: + return {"fp8_dtype": self.dtype} + class Float8Tensor(Float8TensorStorage, QuantizedTensor): """Experimental tensor class with FP8 data @@ -494,14 +620,36 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor): """ def __repr__(self, *, tensor_contents=None): + # ``__repr__`` is on hot diagnostic paths (Dynamo's + # ``Dynamo failed to run FX node`` formatter, autograd + # anomaly mode, FX node printers, ...) and must never raise. + # In particular, dequantising a fake/functional tensor here + # would access ``data_ptr()`` and replace the real failure + # with a misleading data-pointer error. + try: + shape = tuple(self.shape) + except BaseException: # pylint: disable=broad-except + shape = "" return ( "Float8Tensor(" f"fp8_dtype={self._fp8_dtype}, " - f"scale_inv={self._scale_inv.item()}, " - f"data={self.dequantize()}" + f"shape={shape}" ")" ) + @classmethod + def _flatten_meta_overrides(cls, meta: dict) -> dict: + """Bridge :class:`FP8DType` (carried by the subclass output spec + via :meth:`Quantizer.create_metadata`) back to the native + ``tex.DType`` accepted by pybind-bound TE kernels. The eager + :meth:`__tensor_flatten__` path stores ``tex.DType`` directly and + is a no-op here. + """ + fp8_dtype = meta.get("fp8_dtype") + if isinstance(fp8_dtype, FP8DType): + meta = {**meta, "fp8_dtype": to_tex(fp8_dtype)} + return meta + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """ Construct plain PyTorch tensor from Float8Tensor diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 5cab519c79..5a17e69a73 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -6,7 +6,7 @@ from __future__ import annotations from collections.abc import Iterable import math -from typing import Optional, Tuple, Union, Any +from typing import Any, Dict, Optional, Tuple, Union import warnings import torch @@ -15,7 +15,7 @@ from transformer_engine_torch import DType as TE_DType from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe -from ..constants import MXFP8_BLOCK_SCALING_SIZE +from ..constants import MXFP8_BLOCK_SCALING_SIZE, canonicalize_te_dtype from ..utils import devices_match, round_up_to_nearest_multiple from .storage.mxfp8_tensor_storage import MXFP8TensorStorage, _FromMXFP8Func from ..quantized_tensor import QuantizedTensor, Quantizer @@ -35,6 +35,8 @@ class MXFP8Quantizer(Quantizer): dtype: TE_DType + _storage_cls = MXFP8TensorStorage + def __init__( self, fp8_dtype: TE_DType, @@ -43,7 +45,7 @@ def __init__( columnwise: bool = True, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) - self.dtype = fp8_dtype + self.dtype = canonicalize_te_dtype(fp8_dtype) def copy(self) -> MXFP8Quantizer: """Create shallow copy""" @@ -146,7 +148,19 @@ def make_empty( pin_memory=pin_memory, ) - # Construct FP8 tensor + # See ``Float8Quantizer.make_empty`` for the rationale. + if self.internal: + return MXFP8TensorStorage( + data, + scale_inv, + columnwise_data, + columnwise_scale_inv, + self.dtype, + self, + self.optimize_for_gemm, + fake_dtype=dtype, + ) + return MXFP8Tensor( shape=shape, dtype=dtype, @@ -243,6 +257,12 @@ def onnx_dequantize(self, tensor: Union[MXFP8TensorStorage, MXFP8Tensor]) -> tor def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return MXFP8BlockScaling + def _storage_scalars(self) -> Dict[str, Any]: + return { + "fp8_dtype": self.dtype, + "with_gemm_swizzled_scales": self.optimize_for_gemm, + } + class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): """Experimental tensor class with FP8 data diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 285a7f030a..e8469b06d9 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -7,7 +7,7 @@ from collections.abc import Iterable import math import warnings -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import functools import torch @@ -15,7 +15,7 @@ from transformer_engine_torch import DType as TE_DType from transformer_engine.common.recipe import NVFP4BlockScaling, Recipe -from ..constants import NVFP4_BLOCK_SCALING_SIZE, dist_group_type +from ..constants import NVFP4_BLOCK_SCALING_SIZE, canonicalize_te_dtype, dist_group_type from ..utils import ( canonicalize_process_group, devices_match, @@ -135,6 +135,21 @@ class NVFP4Quantizer(Quantizer): rht_matrix_random_sign_mask_t: int rht_matrix: torch.Tensor + _storage_cls = NVFP4TensorStorage + _DTYPE_INIT_KWARG = "fp4_dtype" + _INIT_META_ATTRS = ( + "with_amax_reduction", + "with_rht", + "with_post_rht_amax", + "with_2d_quantization", + "stochastic_rounding", + "row_scaled_nvfp4", + ) + _POST_INIT_META_ATTRS = ("rht_matrix_random_sign_mask_t",) + _POST_INIT_TENSOR_ATTRS = ("rht_matrix",) + _PG_ATTR = "amax_reduction_group" + _PG_INIT_KWARG = "amax_reduction_group" + def __init__( self, fp4_dtype: TE_DType = tex.DType.kFloat4E2M1, @@ -150,7 +165,7 @@ def __init__( with_random_sign_mask: bool = True, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) - self.dtype = fp4_dtype + self.dtype = canonicalize_te_dtype(fp4_dtype) self.with_rht = with_rht self.with_post_rht_amax = with_post_rht_amax self.with_amax_reduction = with_amax_reduction @@ -373,7 +388,22 @@ def make_empty( 1, dtype=torch.float32, device=device, pin_memory=pin_memory ) - # Construct FP8 tensor + # See ``Float8Quantizer.make_empty`` for the rationale. + if self.internal: + return NVFP4TensorStorage( + data, + scale_inv, + columnwise_data, + columnwise_scale_inv, + amax_rowwise, + amax_columnwise, + self.dtype, + self, + False, + fake_dtype=dtype, + row_scaled_nvfp4=self.row_scaled_nvfp4, + ) + return NVFP4Tensor( shape=shape, dtype=dtype, @@ -400,6 +430,13 @@ def _canonicalized_amax_reduction_group(self) -> dist_group_type: def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return NVFP4BlockScaling + def _storage_scalars(self) -> Dict[str, Any]: + return { + "fp4_dtype": self.dtype, + "with_gemm_swizzled_scales": self.optimize_for_gemm, + "row_scaled_nvfp4": self.row_scaled_nvfp4, + } + class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor): """Quantized tensor class with FP4 data diff --git a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py index ca3913762f..84d00c7930 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py @@ -36,6 +36,32 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage): _columnwise_scale_inv: Optional[torch.Tensor] _is_2D_scaled: bool + # Declarative schema consumed by the generic + # :meth:`QuantizedTensorStorage._torch_compile_flatten` / + # :meth:`_torch_compile_do_unflatten` implementations in the base. + _FLATTEN_TENSOR_ATTRS = ( + "_rowwise_data", + "_rowwise_scale_inv", + "_columnwise_data", + "_columnwise_scale_inv", + ) + _FLATTEN_TENSOR_USAGE = { + "_rowwise_data": "rowwise", + "_rowwise_scale_inv": "rowwise", + "_columnwise_data": "columnwise", + "_columnwise_scale_inv": "columnwise", + } + _FLATTEN_META_ATTRS = ("_fp8_dtype", "_dtype", "_is_2D_scaled") + _FLATTEN_CTOR_KWARG = { + "_rowwise_data": "rowwise_data", + "_rowwise_scale_inv": "rowwise_scale_inv", + "_columnwise_data": "columnwise_data", + "_columnwise_scale_inv": "columnwise_scale_inv", + "_fp8_dtype": "fp8_dtype", + "_dtype": "fake_dtype", + "_is_2D_scaled": "is_2D_scaled", + } + def __new__( cls, rowwise_data: Optional[torch.Tensor], diff --git a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py index de7f8f58e2..4cd7162e39 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py @@ -76,6 +76,24 @@ class Float8TensorStorage(QuantizedTensorStorage): _transpose: Optional[torch.Tensor] _transpose_invalid: bool + # Declarative schema consumed by the generic + # :meth:`QuantizedTensorStorage._torch_compile_flatten` / + # :meth:`_torch_compile_do_unflatten` implementations in the base. + _FLATTEN_TENSOR_ATTRS = ("_data", "_transpose", "_scale_inv") + _FLATTEN_TENSOR_USAGE = { + "_data": "rowwise", + "_transpose": "columnwise", + "_scale_inv": "always", + } + _FLATTEN_META_ATTRS = ("_fp8_dtype", "_dtype") + _FLATTEN_CTOR_KWARG = { + "_data": "data", + "_transpose": "data_transpose", + "_scale_inv": "fp8_scale_inv", + "_fp8_dtype": "fp8_dtype", + "_dtype": "fake_dtype", + } + def __new__( cls, *args, @@ -162,6 +180,19 @@ def restore_from_saved( self._data = tensors[0] self._transpose = tensors[1] self._scale_inv = tensors[2] + # Re-derive ``_transpose_invalid`` from the restored buffer: + # the saved transpose, if present, was valid at save time + # (``prepare_for_saving`` never resets this flag, and forward + # producers don't save stale transposes). Tying the flag to + # ``self._transpose`` here makes restoration independent of + # whichever shell carried the storage across the trace + # boundary -- in particular ``torch.compile``'s save/restore + # round-trip, which builds a fresh wrapper shell for backward + # whose pre-restore ``_transpose_invalid`` would otherwise + # come from :meth:`Float8TensorStorage.__new__` (``True`` + # whenever it sees ``data_transpose=None``) and trip + # :meth:`update_usage` downstream. + self._transpose_invalid = self._transpose is None return tensors[3:] def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True): @@ -215,14 +246,34 @@ def view(self, shape: torch.Size): ) def __repr__(self): + # Must never raise: this runs from Inductor error formatters, + # FX node dumps, Dynamo guards, etc. Crucially we must also + # avoid any tensor->scalar materialization (``.item()``, + # ``.tolist()``, ``dequantize()``): under fake-tensor mode they + # allocate fresh unbacked symbols which then leak out of the + # current op as "unreturned outputs" and crash the compile. + # Stick to shape/dtype summaries. + scale_shape = list(getattr(self._scale_inv, "shape", ())) + if self._data is None: + data_repr = "" + else: + data_shape = list(getattr(self._data, "shape", ())) + data_repr = f"" return ( "Float8TensorStorage(" f"fp8_dtype={self._fp8_dtype}, " - f"scale_inv={self._scale_inv.item()}, " - f"data={self.dequantize()}" + f"scale_inv=, " + f"data={data_repr}" ")" ) + # ``__new__`` re-derives ``_transpose_invalid`` from the restored + # ``_transpose`` buffer, so the flag is deliberately not round-tripped + # through ``_FLATTEN_META_ATTRS``: a producer that ships a transpose + # through the trace had it valid, and trusting a stale ``True`` from + # a Dynamo-embedded meta constant would trip :meth:`update_usage`'s + # ``not has_data_transpose`` guard in backward. + def _create_transpose(self): """Update FP8 transpose cache""" data = self._data diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index 842f42838b..d3f19a3b1c 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -79,6 +79,32 @@ class MXFP8TensorStorage(QuantizedTensorStorage): # GEMM _with_gemm_swizzled_scales: bool + # Declarative schema consumed by the generic + # :meth:`QuantizedTensorStorage._torch_compile_flatten` / + # :meth:`_torch_compile_do_unflatten` implementations in the base. + _FLATTEN_TENSOR_ATTRS = ( + "_rowwise_data", + "_rowwise_scale_inv", + "_columnwise_data", + "_columnwise_scale_inv", + ) + _FLATTEN_TENSOR_USAGE = { + "_rowwise_data": "rowwise", + "_rowwise_scale_inv": "rowwise", + "_columnwise_data": "columnwise", + "_columnwise_scale_inv": "columnwise", + } + _FLATTEN_META_ATTRS = ("_fp8_dtype", "_dtype", "_with_gemm_swizzled_scales") + _FLATTEN_CTOR_KWARG = { + "_rowwise_data": "rowwise_data", + "_rowwise_scale_inv": "rowwise_scale_inv", + "_columnwise_data": "columnwise_data", + "_columnwise_scale_inv": "columnwise_scale_inv", + "_fp8_dtype": "fp8_dtype", + "_dtype": "fake_dtype", + "_with_gemm_swizzled_scales": "with_gemm_swizzled_scales", + } + def __new__( cls, rowwise_data: Optional[torch.Tensor], diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index e51acb71e5..f8d79ccf5e 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -90,7 +90,7 @@ class NVFP4TensorStorage(QuantizedTensorStorage): # column-scaled FP4 data) _amax_columnwise: torch.Tensor - # Builder class for casting to MXFP8 + # Builder class for casting to NVFP4 _quantizer: Optional[Quantizer] # FP4 data type _fp4_dtype: TE_DType @@ -100,6 +100,44 @@ class NVFP4TensorStorage(QuantizedTensorStorage): # Whether this NVFP4 tensor uses row-scaled amax metadata _row_scaled_nvfp4: bool + # Declarative schema consumed by the generic + # :meth:`QuantizedTensorStorage._torch_compile_flatten` / + # :meth:`_torch_compile_do_unflatten` implementations in the base. + _FLATTEN_TENSOR_ATTRS = ( + "_rowwise_data", + "_rowwise_scale_inv", + "_columnwise_data", + "_columnwise_scale_inv", + "_amax_rowwise", + "_amax_columnwise", + ) + _FLATTEN_TENSOR_USAGE = { + "_rowwise_data": "rowwise", + "_rowwise_scale_inv": "rowwise", + "_columnwise_data": "columnwise", + "_columnwise_scale_inv": "columnwise", + "_amax_rowwise": "rowwise", + "_amax_columnwise": "columnwise", + } + _FLATTEN_META_ATTRS = ( + "_fp4_dtype", + "_dtype", + "_with_gemm_swizzled_scales", + "_row_scaled_nvfp4", + ) + _FLATTEN_CTOR_KWARG = { + "_rowwise_data": "rowwise_data", + "_rowwise_scale_inv": "rowwise_scale_inv", + "_columnwise_data": "columnwise_data", + "_columnwise_scale_inv": "columnwise_scale_inv", + "_amax_rowwise": "amax_rowwise", + "_amax_columnwise": "amax_columnwise", + "_fp4_dtype": "fp4_dtype", + "_dtype": "fake_dtype", + "_with_gemm_swizzled_scales": "with_gemm_swizzled_scales", + "_row_scaled_nvfp4": "row_scaled_nvfp4", + } + def __new__( cls, rowwise_data: Optional[torch.Tensor], @@ -216,6 +254,10 @@ def restore_from_saved( self._amax_columnwise = tensors[5] return tensors[6:] + # ``_torch_compile_flatten`` / ``_torch_compile_do_unflatten`` are + # the generic implementations on :class:`QuantizedTensorStorage`, + # driven by the ``_FLATTEN_*`` declarations above. + def get_data_tensors(self): """Get this Tensor's data.""" return self._rowwise_data, self._columnwise_data