diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index a524d5c8de..c08bf5eda7 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -473,7 +473,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): dst.quantize_(src) else: if isinstance(src, QuantizedTensor): - src = src.dequantize() + dtype = dst.dtype + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + dtype = torch.float32 + src = src.dequantize(dtype=dtype) dst.copy_(src) return None