IAttention FP8 #4209
Conversation
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/attention.py 2026-04-23 16:47:17.496428+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/attention.py 2026-04-23 16:47:37.867662+00:00
@@ -29,11 +29,12 @@
attention_layer: trt.IAttention,
) -> bool:
"""Set FP8 softmax normalization quantization on the IAttention layer if the current
node was annotated with a softmax FP8 scale by the fp8_attention_softmax lowering pass.
- Returns True if FP8 normalization was configured (caller must set decomposable=False)."""
+ Returns True if FP8 normalization was configured (caller must set decomposable=False).
+ """
if ctx.current_node is None:
return False
scale_val = ctx.current_node.meta.get("_fp8_softmax_scale")
if scale_val is None:
return False
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py 2026-04-23 16:47:17.538801+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py 2026-04-23 16:47:42.928700+00:00
@@ -580,34 +580,34 @@
# FP8 Q/K/V inputs (exponent_bits=4): SDPA node must be annotated with 1/448.
gm_fp8 = _build_sdpa_input_quant_graph(exponent_bits=4)
annotate_fp8_sdpa(gm_fp8, settings)
sdpa_nodes = [n for n in gm_fp8.graph.nodes if n.target in _SDPA_TARGETS]
assert sdpa_nodes, "No SDPA node found in graph"
- assert all("_fp8_softmax_scale" in n.meta for n in sdpa_nodes), (
- "annotate_fp8_sdpa did not annotate SDPA when Q/K/V inputs are FP8"
- )
+ assert all(
+ "_fp8_softmax_scale" in n.meta for n in sdpa_nodes
+ ), "annotate_fp8_sdpa did not annotate SDPA when Q/K/V inputs are FP8"
expected_scale = 1.0 / 448.0
for n in sdpa_nodes:
- assert abs(n.meta["_fp8_softmax_scale"] - expected_scale) < 1e-12, (
- f"Wrong softmax scale: {n.meta['_fp8_softmax_scale']}"
- )
+ assert (
+ abs(n.meta["_fp8_softmax_scale"] - expected_scale) < 1e-12
+ ), f"Wrong softmax scale: {n.meta['_fp8_softmax_scale']}"
# INT8 Q/K/V inputs (exponent_bits=0): SDPA node must NOT be annotated.
gm_int8 = _build_sdpa_input_quant_graph(exponent_bits=0)
annotate_fp8_sdpa(gm_int8, settings)
sdpa_int8 = [n for n in gm_int8.graph.nodes if n.target in _SDPA_TARGETS]
- assert all("_fp8_softmax_scale" not in n.meta for n in sdpa_int8), (
- "annotate_fp8_sdpa incorrectly annotated SDPA when Q/K/V are INT8"
- )
+ assert all(
+ "_fp8_softmax_scale" not in n.meta for n in sdpa_int8
+ ), "annotate_fp8_sdpa incorrectly annotated SDPA when Q/K/V are INT8"
# Only Q and K are FP8-quantized, V is raw: SDPA must NOT be annotated.
gm_partial = _build_sdpa_input_quant_graph(exponent_bits=4, quantize_v=False)
annotate_fp8_sdpa(gm_partial, settings)
sdpa_partial = [n for n in gm_partial.graph.nodes if n.target in _SDPA_TARGETS]
- assert all("_fp8_softmax_scale" not in n.meta for n in sdpa_partial), (
- "annotate_fp8_sdpa incorrectly annotated SDPA when V input is not FP8"
- )
+ assert all(
+ "_fp8_softmax_scale" not in n.meta for n in sdpa_partial
+ ), "annotate_fp8_sdpa incorrectly annotated SDPA when V input is not FP8"
@unittest.skipIf(
torch.cuda.get_device_capability() < (8, 9),
"FP8 quantization requires compute capability 8.9 or later",
@@ -649,19 +649,13 @@
"""Mirror of what a modelopt FP8 MHA PyTorch export will look like:
tensorrt.quantize_op on Q, K, V feeding F.scaled_dot_product_attention."""
def __init__(self, amax_val: float = 6.0):
super().__init__()
- self.register_buffer(
- "amax_q", torch.tensor(amax_val, dtype=torch.float32)
- )
- self.register_buffer(
- "amax_k", torch.tensor(amax_val, dtype=torch.float32)
- )
- self.register_buffer(
- "amax_v", torch.tensor(amax_val, dtype=torch.float32)
- )
+ self.register_buffer("amax_q", torch.tensor(amax_val, dtype=torch.float32))
+ self.register_buffer("amax_k", torch.tensor(amax_val, dtype=torch.float32))
+ self.register_buffer("amax_v", torch.tensor(amax_val, dtype=torch.float32))
def forward(self, q, k, v):
q_fp8 = torch.ops.tensorrt.quantize_op(q, self.amax_q, 8, 4, False, False)
k_fp8 = torch.ops.tensorrt.quantize_op(k, self.amax_k, 8, 4, False, False)
v_fp8 = torch.ops.tensorrt.quantize_op(v, self.amax_v, 8, 4, False, False)
@@ -690,12 +684,11 @@
engine_json = json.loads(
inspector.get_engine_information(trt.LayerInformationFormat.JSON)
)
layers = engine_json.get("Layers", [])
layer_names = [
- layer if isinstance(layer, str) else layer.get("Name", "")
- for layer in layers
+ layer if isinstance(layer, str) else layer.get("Name", "") for layer in layers
]
assert any("mha" in name.lower() for name in layer_names), (
f"No fused MHA kernel found in compiled engine. Expected a layer "
f"containing 'mha' (e.g. _gemm_mha_v2) — TRT fuses FP8 Q/K/V + "
f"normalization_quantize_to_type into a single MHA kernel. "
@@ -714,8 +707,8 @@
trt_out = compiled(q, k, v)
cos = torch.nn.functional.cosine_similarity(
ref_out.flatten().float().unsqueeze(0),
trt_out.flatten().float().unsqueeze(0),
).item()
- assert cos > 0.99, (
- f"FP8 MHA output deviates from PyTorch reference: cosine_similarity={cos}"
- )
+ assert (
+ cos > 0.99
+ ), f"FP8 MHA output deviates from PyTorch reference: cosine_similarity={cos}"| ctx, | ||
| torch.tensor(scale_val, dtype=torch.float32), | ||
| name + "_softmax_fp8_scale", | ||
| dtype=torch.float32, |
dtype needs to match the pre-quant Q/K/V dtype; otherwise TRT compilation will fail on some platforms.
Do you know where we can fetch this info?
7f0d61c I pulled the attention layer's output tensor's dtype.
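For reference, a minimal sketch of that fix (the helper name and dtype map are my assumptions, not the exact code in 7f0d61c): pull the dtype off the attention layer's output tensor and build the scale constant with it instead of hard-coding float32.

import tensorrt as trt
import torch

# Map TRT dtypes back to torch dtypes (subset; float32 as fallback).
_TRT_TO_TORCH = {
    trt.DataType.FLOAT: torch.float32,
    trt.DataType.HALF: torch.float16,
    trt.DataType.BF16: torch.bfloat16,
}

def softmax_scale_tensor(attention_layer: trt.IAttention, scale_val: float) -> torch.Tensor:
    # The layer's output tensor carries the pre-quant Q/K/V dtype, so create
    # the FP8 softmax scale constant with that same dtype.
    out_dtype = _TRT_TO_TORCH.get(attention_layer.get_output(0).dtype, torch.float32)
    return torch.tensor(scale_val, dtype=out_dtype)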
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/attention.py 2026-04-23 17:31:58.988167+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/attention.py 2026-04-23 17:32:17.137589+00:00
@@ -29,11 +29,12 @@
attention_layer: trt.IAttention,
) -> bool:
"""Set FP8 softmax normalization quantization on the IAttention layer if the current
node was annotated with a softmax FP8 scale by the fp8_attention_softmax lowering pass.
- Returns True if FP8 normalization was configured (caller must set decomposable=False)."""
+ Returns True if FP8 normalization was configured (caller must set decomposable=False).
+ """
if ctx.current_node is None:
return False
scale_val = ctx.current_node.meta.get("_fp8_softmax_scale")
if scale_val is None:
return False
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/insert_fp8_softmax_qdq.py 2026-04-23 17:31:58.992065+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/insert_fp8_softmax_qdq.py 2026-04-23 17:32:18.027942+00:00
@@ -115,11 +115,13 @@
if attn_src.op != "call_function" or attn_src.target not in _MATMUL_TARGETS:
continue
if len(attn_src.args) < 2:
continue
q_source, k_source = attn_src.args[0], attn_src.args[1]
- if not (_source_is_fp8_quantize(q_source) and _source_is_fp8_quantize(k_source)):
+ if not (
+ _source_is_fp8_quantize(q_source) and _source_is_fp8_quantize(k_source)
+ ):
continue
# Register a per-insertion amax buffer (1.0).
amax_name = f"_fp8_softmax_qdq_amax_{amax_buffer_idx}"
amax_buffer_idx += 1
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py 2026-04-23 17:31:59.020247+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py 2026-04-23 17:32:21.643502+00:00
@@ -580,34 +580,34 @@
# FP8 Q/K/V inputs (exponent_bits=4): SDPA node must be annotated with 1/448.
gm_fp8 = _build_sdpa_input_quant_graph(exponent_bits=4)
annotate_fp8_sdpa(gm_fp8, settings)
sdpa_nodes = [n for n in gm_fp8.graph.nodes if n.target in _SDPA_TARGETS]
assert sdpa_nodes, "No SDPA node found in graph"
- assert all("_fp8_softmax_scale" in n.meta for n in sdpa_nodes), (
- "annotate_fp8_sdpa did not annotate SDPA when Q/K/V inputs are FP8"
- )
+ assert all(
+ "_fp8_softmax_scale" in n.meta for n in sdpa_nodes
+ ), "annotate_fp8_sdpa did not annotate SDPA when Q/K/V inputs are FP8"
expected_scale = 1.0 / 448.0
for n in sdpa_nodes:
- assert abs(n.meta["_fp8_softmax_scale"] - expected_scale) < 1e-12, (
- f"Wrong softmax scale: {n.meta['_fp8_softmax_scale']}"
- )
+ assert (
+ abs(n.meta["_fp8_softmax_scale"] - expected_scale) < 1e-12
+ ), f"Wrong softmax scale: {n.meta['_fp8_softmax_scale']}"
# INT8 Q/K/V inputs (exponent_bits=0): SDPA node must NOT be annotated.
gm_int8 = _build_sdpa_input_quant_graph(exponent_bits=0)
annotate_fp8_sdpa(gm_int8, settings)
sdpa_int8 = [n for n in gm_int8.graph.nodes if n.target in _SDPA_TARGETS]
- assert all("_fp8_softmax_scale" not in n.meta for n in sdpa_int8), (
- "annotate_fp8_sdpa incorrectly annotated SDPA when Q/K/V are INT8"
- )
+ assert all(
+ "_fp8_softmax_scale" not in n.meta for n in sdpa_int8
+ ), "annotate_fp8_sdpa incorrectly annotated SDPA when Q/K/V are INT8"
# Only Q and K are FP8-quantized, V is raw: SDPA must NOT be annotated.
gm_partial = _build_sdpa_input_quant_graph(exponent_bits=4, quantize_v=False)
annotate_fp8_sdpa(gm_partial, settings)
sdpa_partial = [n for n in gm_partial.graph.nodes if n.target in _SDPA_TARGETS]
- assert all("_fp8_softmax_scale" not in n.meta for n in sdpa_partial), (
- "annotate_fp8_sdpa incorrectly annotated SDPA when V input is not FP8"
- )
+ assert all(
+ "_fp8_softmax_scale" not in n.meta for n in sdpa_partial
+ ), "annotate_fp8_sdpa incorrectly annotated SDPA when V input is not FP8"
@unittest.skipIf(
torch.cuda.get_device_capability() < (8, 9),
"FP8 quantization requires compute capability 8.9 or later",
@@ -649,19 +649,13 @@
"""Mirror of what a modelopt FP8 MHA PyTorch export will look like:
tensorrt.quantize_op on Q, K, V feeding F.scaled_dot_product_attention."""
def __init__(self, amax_val: float = 6.0):
super().__init__()
- self.register_buffer(
- "amax_q", torch.tensor(amax_val, dtype=torch.float32)
- )
- self.register_buffer(
- "amax_k", torch.tensor(amax_val, dtype=torch.float32)
- )
- self.register_buffer(
- "amax_v", torch.tensor(amax_val, dtype=torch.float32)
- )
+ self.register_buffer("amax_q", torch.tensor(amax_val, dtype=torch.float32))
+ self.register_buffer("amax_k", torch.tensor(amax_val, dtype=torch.float32))
+ self.register_buffer("amax_v", torch.tensor(amax_val, dtype=torch.float32))
def forward(self, q, k, v):
q_fp8 = torch.ops.tensorrt.quantize_op(q, self.amax_q, 8, 4, False, False)
k_fp8 = torch.ops.tensorrt.quantize_op(k, self.amax_k, 8, 4, False, False)
v_fp8 = torch.ops.tensorrt.quantize_op(v, self.amax_v, 8, 4, False, False)
@@ -690,12 +684,11 @@
engine_json = json.loads(
inspector.get_engine_information(trt.LayerInformationFormat.JSON)
)
layers = engine_json.get("Layers", [])
layer_names = [
- layer if isinstance(layer, str) else layer.get("Name", "")
- for layer in layers
+ layer if isinstance(layer, str) else layer.get("Name", "") for layer in layers
]
assert any("mha" in name.lower() for name in layer_names), (
f"No fused MHA kernel found in compiled engine. Expected a layer "
f"containing 'mha' (e.g. _gemm_mha_v2) — TRT fuses FP8 Q/K/V + "
f"normalization_quantize_to_type into a single MHA kernel. "
@@ -714,13 +707,13 @@
trt_out = compiled(q, k, v)
cos = torch.nn.functional.cosine_similarity(
ref_out.flatten().float().unsqueeze(0),
trt_out.flatten().float().unsqueeze(0),
).item()
- assert cos > 0.99, (
- f"FP8 MHA output deviates from PyTorch reference: cosine_similarity={cos}"
- )
+ assert (
+ cos > 0.99
+ ), f"FP8 MHA output deviates from PyTorch reference: cosine_similarity={cos}"
@unittest.skipIf(
torch.cuda.get_device_capability() < (8, 9),
"FP8 quantization requires compute capability 8.9 or later",
@@ -749,19 +742,13 @@
torch.manual_seed(0)
class FP8MHAModel(torch.nn.Module):
def __init__(self, amax_val: float = 6.0):
super().__init__()
- self.register_buffer(
- "amax_q", torch.tensor(amax_val, dtype=torch.float32)
- )
- self.register_buffer(
- "amax_k", torch.tensor(amax_val, dtype=torch.float32)
- )
- self.register_buffer(
- "amax_v", torch.tensor(amax_val, dtype=torch.float32)
- )
+ self.register_buffer("amax_q", torch.tensor(amax_val, dtype=torch.float32))
+ self.register_buffer("amax_k", torch.tensor(amax_val, dtype=torch.float32))
+ self.register_buffer("amax_v", torch.tensor(amax_val, dtype=torch.float32))
def forward(self, q, k, v):
q_fp8 = torch.ops.tensorrt.quantize_op(q, self.amax_q, 8, 4, False, False)
k_fp8 = torch.ops.tensorrt.quantize_op(k, self.amax_k, 8, 4, False, False)
v_fp8 = torch.ops.tensorrt.quantize_op(v, self.amax_v, 8, 4, False, False)
@@ -791,12 +778,11 @@
engine_json = json.loads(
inspector.get_engine_information(trt.LayerInformationFormat.JSON)
)
layers = engine_json.get("Layers", [])
layer_names = [
- layer if isinstance(layer, str) else layer.get("Name", "")
- for layer in layers
+ layer if isinstance(layer, str) else layer.get("Name", "") for layer in layers
]
assert any("mha" in name.lower() for name in layer_names), (
f"No fused MHA kernel found on decomposed path. Expected a layer "
f"containing 'mha' (e.g. _gemm_mha_v2) — TRT fuses FP8 Q/K/V + "
f"softmax-output Q/DQ into _gemm_mha_v2 on Method 2 path. "
@@ -816,8 +802,8 @@
trt_out = compiled(q, k, v)
cos = torch.nn.functional.cosine_similarity(
ref_out.flatten().float().unsqueeze(0),
trt_out.flatten().float().unsqueeze(0),
).item()
- assert cos > 0.99, (
- f"Decomposed FP8 MHA output deviates from PyTorch reference: cos={cos}"
- )
+ assert (
+ cos > 0.99
+ ), f"Decomposed FP8 MHA output deviates from PyTorch reference: cos={cos}"…s to annotate target nodes
fe7f268 to
211e690
Compare
zewenli98
left a comment
I think this approach works for both the decomposed and non-decomposed attention paths. Ideally, though, this work should be done on the ModelOpt side.
current_node needs to be documented.
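For instance, the added attribute could be described roughly like this (an illustrative docstring sketch, not committed text):

class ConversionContext:
    """Context shared across converter calls (sketch).

    Attributes:
        current_node: The FX node currently being converted. Set by the
            interpreter before dispatching each converter, and used to read
            per-node annotations such as ``_fp8_softmax_scale``. May be
            ``None`` outside an active conversion call.
    """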
_MATMUL_TARGETS = {
    torch.ops.aten.matmul.default,
    torch.ops.aten.bmm.default,
}
We have these converter targets:
@dynamo_tensorrt_converter(torch.ops.aten.matmul, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.matmul.default, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.dot.default, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.mm.default, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.mv.default, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.bmm.default, supports_dynamic_shapes=True)
Are any of the other targets needed here as well?
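If they are, the set could be widened along these lines (a sketch of the reviewer's suggestion, not merged code; the note on dot/mv is my reading of why they were likely omitted):

import torch

# Sketch: match more matmul ops than just matmul.default / bmm.default.
_MATMUL_TARGETS = {
    torch.ops.aten.matmul,          # OpOverloadPacket form, as registered above
    torch.ops.aten.matmul.default,
    torch.ops.aten.mm.default,
    torch.ops.aten.bmm.default,
    # aten.dot.default / aten.mv.default yield 0-D / 1-D results, so they
    # cannot be the batched Q @ K^T matmul of an attention block.
}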
Description
Prototype support for FP8 normalization scale in IAttention layer
Structurally I think this approach is reasonable: extract the relevant graph-level info and bake it into metadata that gets consumed in the converter. I don't think we necessarily need the softmax QDQ pass if ModelOpt inserts this for us.
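A condensed sketch of that flow (simplified: annotate_fp8_sdpa is a stand-in for the actual lowering pass, which also takes settings and verifies that Q, K and V all come from FP8 quantize_op nodes before annotating; only the metadata round-trip is shown):

import torch
from torch.fx import GraphModule

FP8_E4M3_MAX = 448.0  # matches the 1/448 expected_scale in the tests above

def annotate_fp8_sdpa(gm: GraphModule) -> GraphModule:
    # Lowering pass: stamp graph-level info onto the SDPA node's metadata.
    for node in gm.graph.nodes:
        if node.target == torch.ops.aten.scaled_dot_product_attention.default:
            node.meta["_fp8_softmax_scale"] = 1.0 / FP8_E4M3_MAX
    return gm

def read_softmax_scale(ctx):
    # Converter side: mirror of the meta.get() lookup in attention.py.
    if ctx.current_node is None:
        return None
    return ctx.current_node.meta.get("_fp8_softmax_scale")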
cc: @nvyihengz, @yizhuoz004
Fixes #4200, #4167