diff --git a/tests/pytorch/attention/test_kv_cache.py b/tests/pytorch/attention/test_kv_cache.py index c662252f9e..56a6261236 100644 --- a/tests/pytorch/attention/test_kv_cache.py +++ b/tests/pytorch/attention/test_kv_cache.py @@ -381,8 +381,13 @@ def get_tols(config, module, backend, dtype): torch.bfloat16: (1.2e-1, 1e-1), } else: + # head_dim > 128 in fp16 is the worst case for accumulated rounding, and the + # full-sequence vs incremental-KV-cache paths use different kernels/mask types. + # On sm80 with older cuDNN the agreement grazes 1e-2 on a single element, so the + # fp16 tolerance is widened slightly. Tolerances were originally calibrated on + # Hopper/Blackwell + newer cuDNN. tols = { - torch.half: (1e-2, 1e-2), + torch.half: (1.5e-2, 1.5e-2), torch.bfloat16: (8e-2, 7e-2), } if module == "DotProductAttention":