Conversation
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

for more information, see https://pre-commit.ci

…ormerEngine into cpu_fp8_optimizations

/te-ci L1 pytorch
Greptile Summary

This PR implements comprehensive CPU overhead optimizations across the Python and C++ layers of TransformerEngine.

Python optimizations:
C++ optimizations:
Key improvements from previous review rounds:
Issue found:

Confidence Score: 4/5

Important Files Changed

Last reviewed commit: 73e4d1d
Greptile Overview
Greptile Summary
This PR implements CPU-side performance optimizations for FP8 operations by caching frequently accessed attributes and reducing redundant function calls. The optimizations target expensive PyObject attribute lookups on custom tensor types and repeated C++ function calls.
Key Changes:

- Caches `requires_grad`, `dtype`, `shape`, and `is_cuda` attribute accesses to avoid expensive PyObject lookups on custom tensors
- Reorders attribute checks in `get_tensor_device()` to prioritize internal quantized tensor attributes
- Makes `num_devices` static in `nvte_is_non_tn_fp8_gemm_supported()` to cache the device count
- Stores GEMM support check results in local variables to avoid redundant function calls
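As a rough illustration of why these caches pay off: attribute access on objects with Python-level properties or `__getattr__` hooks goes through the interpreter on every read, so hot paths read them once into locals. The sketch below is standalone and not TransformerEngine code — `QuantizedTensorLike` and the `configure_gemm_*` helpers are invented names.

```python
# Illustrative sketch (not TransformerEngine code): attribute lookups on
# objects with custom properties can dominate hot loops, so the PR-style
# fix is to read them once into locals and reuse the cached values.

class QuantizedTensorLike:
    """Stand-in for a custom tensor whose attributes are Python properties."""

    def __init__(self, dtype, shape):
        self._dtype = dtype
        self._shape = shape
        self.lookups = 0  # count property accesses to show the saving

    @property
    def dtype(self):
        self.lookups += 1
        return self._dtype

    @property
    def shape(self):
        self.lookups += 1
        return self._shape


def configure_gemm_naive(t, n_calls):
    # Re-reads the properties on every iteration: n_calls * 2 lookups.
    for _ in range(n_calls):
        _ = (t.dtype, t.shape)


def configure_gemm_cached(t, n_calls):
    # Reads each property once, then reuses the locals: 2 lookups total.
    dtype, shape = t.dtype, t.shape
    for _ in range(n_calls):
        _ = (dtype, shape)


naive = QuantizedTensorLike("float8", (128, 128))
configure_gemm_naive(naive, 100)
cached = QuantizedTensorLike("float8", (128, 128))
configure_gemm_cached(cached, 100)
print(naive.lookups, cached.lookups)  # 200 vs 2
```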
Critical Issues Found:

- Variable redeclaration error in `cublaslt_gemm.cu` (line 224) will prevent compilation
- Logic bug in `linear.py` (line 484) changes FP8 state management from OR logic to AND logic, breaking functionality when bias is None or doesn't require grad

Confidence Score: 0/5

- This PR cannot be merged due to a compilation error and a critical logic bug
- Two critical issues prevent merging: (1) C++ compilation will fail due to a variable redeclaration at line 224 of `cublaslt_gemm.cu`, and (2) a logic bug at line 484 of `linear.py` breaks FP8 state management by requiring all three tensors to have `requires_grad=True` instead of any one of them
- Pay close attention to `transformer_engine/common/gemm/cublaslt_gemm.cu` (compilation error) and `transformer_engine/pytorch/module/linear.py` (logic bug)
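The OR-vs-AND distinction flagged above can be reduced to a small predicate sketch. This is a hypothetical illustration mirroring the review comment, not the actual `linear.py` code; the function and parameter names are invented.

```python
# Hypothetical sketch of the FP8 state-management predicate; the parameter
# names (inp/weight/bias requires_grad flags) mirror the review comment.

def backward_needed_or(inp_rg, weight_rg, bias_rg):
    # Correct: any tensor needing grad means backward state must be kept.
    return inp_rg or weight_rg or bias_rg

def backward_needed_and(inp_rg, weight_rg, bias_rg):
    # Buggy AND version: a bias that is None or doesn't require grad
    # (bias_rg False) suppresses state updates for input and weight too.
    return inp_rg and weight_rg and bias_rg

# Typical case flagged in the review: bias is None, so bias_rg is False.
print(backward_needed_or(True, True, False))   # True  - state kept
print(backward_needed_and(True, True, False))  # False - state wrongly dropped
```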
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/gemm/cublaslt_gemm.cu | 1/5 | Caches function call result to reduce overhead, but contains variable redeclaration error that will cause compilation failure |
| transformer_engine/common/transformer_engine.cpp | 5/5 | Makes num_devices static to avoid redundant calls to cuda::num_devices() - valid optimization |
| transformer_engine/pytorch/module/linear.py | 0/5 | Caches requires_grad checks for performance, but contains critical logic bug at line 484 that changes FP8 state management behavior |
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User Code
    participant Linear as Linear Module
    participant Quantizer as Quantizer/QuantizedTensor
    participant GEMM as GEMM Operations
    participant CPP as C++ Extensions
    Note over Linear,CPP: Performance Optimization Flow
    User->>Linear: forward(input, weight, bias)
    Note over Linear: Cache requires_grad checks
    Linear->>Linear: inp_requires_grad = inp.requires_grad<br/>weight_requires_grad = weight.requires_grad<br/>bias_requires_grad = bias.requires_grad
    Linear->>Quantizer: Check if quantized tensor
    alt QuantizedTensor
        Note over Quantizer: Use cached dtype property
        Quantizer->>Quantizer: return self._dtype
        Note over Quantizer: Use cached shape/is_cuda
        Quantizer->>Quantizer: return self._data.shape
    else Regular Tensor
        Quantizer->>Linear: Standard attribute access
    end
    Linear->>CPP: get_tensor_device(tensor)
    Note over CPP: Reordered attribute checks
    CPP->>CPP: Check _rowwise_data first<br/>Check _columnwise_data<br/>Check device last
    CPP-->>Linear: device_index
    Linear->>GEMM: Configure GEMM parameters
    Note over GEMM: Cache nvte_is_non_tn_fp8_gemm_supported
    GEMM->>CPP: nvte_is_non_tn_fp8_gemm_supported()
    Note over CPP: Static num_devices cached
    CPP-->>GEMM: support_flag
    GEMM->>GEMM: Store in local variable
    GEMM->>GEMM: Execute optimized GEMM
    GEMM-->>Linear: output
    Note over Linear: FP8 State Management
    alt FP8 enabled and requires_grad check
        Linear->>Linear: Update FP8 tensors<br/>based on cached flags
    end
    Linear-->>User: output
```
Additional Comments (2)
/te-ci L1 pytorch
```cpp
ret.A_scale_inv = A.scale_inv.dptr;
ret.lda = is_A_transposed ? k : m;
if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
```
Variable is_nvte_non_tn_fp8_gemm_supported is redeclared in the same scope - it was already declared at line 224 for the B matrix configuration. This causes a compilation error or shadowing issue.
The variable should only be declared once at the beginning of the function scope and reused for both A and B matrix configurations.
Suggested change:

```cpp
// Move this declaration before line 125 (Configure A matrix)
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
```
```cpp
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
                                 PyTuple_New(0), kwargs);
```
Memory leak: PyTuple_New(0) creates a new Python tuple object with a reference count of 1, but this reference is never decremented. The tuple is passed directly to PyObject_Call which doesn't steal the reference, so it must be explicitly released.
This same issue occurs at lines 186, 398, 417, 697, 718, 1039, 1058, 1346, and 1367.
Suggested change:

```cpp
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
                                 args, kwargs);
Py_DECREF(args);
Py_DECREF(kwargs);
```
```cpp
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());
```
Invalid parameter: The Float8TensorStorage.__new__ method does not accept a device parameter. Looking at the Python class definition in float8_tensor_storage.py, the __new__ signature only expects: data, fp8_scale_inv, fp8_dtype, data_transpose, and quantizer.
This will cause a runtime error when trying to create Float8TensorStorage instances. This issue also occurs at lines 184, 396, 415, and in similar patterns for MXFP8 and other quantizers.
Suggested change:

```cpp
// Remove this line - device parameter is not accepted
// PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());
```
```cpp
PyTypeObject *NVFP4TensorPythonClass = nullptr;
PyTypeObject *NVFP4TensorStoragePythonClass = nullptr;
PyTypeObject *NVFP4QuantizerClass = nullptr;
bool is_extension_initialized = false;
```
Race condition: The is_extension_initialized flag is not thread-safe. If init_extension() is called concurrently from multiple threads, both threads could pass the check at line 105 before either sets the flag at line 110, causing the initialization code to run multiple times.
The original implementation using null-checks on the individual class pointers was safer because Python's module import system provides thread safety. Consider using std::once_flag and std::call_once for proper thread-safe initialization:
Suggested change:

```cpp
bool is_extension_initialized = false;
std::once_flag extension_init_flag;
```
And in init_extension():
```cpp
void init_extension() {
  std::call_once(extension_init_flag, []() {
    init_float8_extension();
    init_mxfp8_extension();
    init_float8blockwise_extension();
    init_nvfp4_extensions();
  });
}
```

```cpp
static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
```
Potential edge case: The num_devices is captured once as a static variable, so if the number of devices changes at runtime (e.g., with CUDA MPS or device enumeration changes), the cache and flags vectors could be undersized, leading to out-of-bounds access at line 968.
While dynamic device changes are rare, consider adding a bounds check:
Suggested change:

```cpp
static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
int device_id = transformer_engine::cuda::current_device();
NVTE_CHECK(device_id < num_devices, "Device ID exceeds cached device count");
```
```cpp
ret.B_scale_inv = B.scale_inv.dptr;
ret.ldb = is_B_transposed ? n : k;
if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
```
Variable is_nvte_non_tn_fp8_gemm_supported is redeclared in the same function scope (already declared at line 132). In C++, this is a compilation error. The variable should be declared once at the function scope level before both the A matrix and B matrix configuration blocks.
Suggested change:

```cpp
// Move declaration to line 104 (after GemmParam ret;) to avoid redeclaration
```
```cpp
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());
// ...
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
                                 PyTuple_New(0), kwargs);
```
Memory leak: PyTuple_New(0) creates a new tuple with reference count 1, but PyObject_Call does not steal the reference. The tuple must be stored and explicitly DECREF'd after the call. This pattern is repeated 12 times throughout this file (lines 166, 186, 398, 417, 697, 718, 1039, 1058, 1346, 1367).
Suggested change:

```cpp
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
                                 args, kwargs);
Py_DECREF(args);
```
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
```cpp
ret.B_scale_inv = B.scale_inv.dptr;
ret.ldb = is_B_transposed ? n : k;
if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
```
[P0] Variable is_nvte_non_tn_fp8_gemm_supported is already declared at line 132. Declaring it again here causes a compilation error. Remove this declaration and reuse the variable from line 132.
Suggested change:

```cpp
// Reuse is_nvte_non_tn_fp8_gemm_supported from line 132
```
```cpp
}
py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();
```
[P0] Use-after-free bug: scale_inv_py is created from scale_inv before scale_inv is initialized. When scale_inv parameter is empty (lines 148-150), this dereferences an empty std::optional, causing undefined behavior. Move this line after the scale_inv initialization.
Suggested change:

```cpp
// Move scale_inv_py creation after initialization (line 150)
```
```cpp
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());
// ...
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
```
[P0] Memory leak: PyTuple_New(0) creates a new Python object with refcount 1, but it's never decremented. This leaks memory on every tensor creation. Add Py_DECREF for the tuple:
Suggested change:

```cpp
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);
Py_DECREF(args);
```
```cpp
PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr());
PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr());
// ...
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass),
```
[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).
```cpp
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
// ...
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(MXFP8TensorStoragePythonClass),
```
[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).
```cpp
PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
// ...
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(NVFP4TensorStoragePythonClass),
```
[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).
/te-ci pytorch
Additional Comments (1)
When copying from one tensor, the cached result should also be updated: add `self._dtype = tensor.dtype` after line 1019.
Description
CPU overhead optimizations
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Python Optimizations
C++ Optimizations
Checklist: