
CPU Overhead Optimizations #2559

Open
vthumbe1503 wants to merge 80 commits into NVIDIA:main from vthumbe1503:cpu_fp8_optimizations

Conversation


@vthumbe1503 vthumbe1503 commented Jan 5, 2026

Description

CPU overhead optimizations

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

Python Optimizations

  • TE pybind-exported enums like tex.FP8FwdTensors.GEMM1_INPUT were cast to int in each and every forward pass. The integer values are now cached in a constants file and used instead.
  • Getting a tensor's device in the helper function went through the expensive tensor.device lookup even for quantized tensors. The device is now declared as a property of QuantizedTensor, so it doesn't go through a PyObject lookup.
  • The shape and is_cuda attributes are defined as QuantizedTensor properties (they are cheap enough to compute in Python), which avoids the expensive PyObject lookup.
  • requires_grad and dtype are defined as properties of the base QuantizedTensor class. The values are cached when they are set, avoiding the expensive PyObject lookup. The setter must still go through pybind/C++; for instance, the torch autograd engine in C++ needs to be aware of requires_grad changes.
  • The dtype of our custom QuantizedTensor can change when we go through x.data = new_tensor, so the cached dtype is kept in sync by defining appropriate _get_data and _set_data methods for the data property of QuantizedTensor.
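The caching patterns above can be sketched in a few lines. This is a minimal illustration only; the enum values and the QuantizedTensorSketch class are stand-ins, not the actual TransformerEngine API:

```python
from enum import IntEnum

# Stand-in for the pybind-exported enum (tex.FP8FwdTensors in the PR).
class FP8FwdTensors(IntEnum):
    GEMM1_INPUT = 0
    GEMM1_WEIGHT = 1

# Cache the int values once at import time in a constants namespace,
# instead of paying an enum-to-int cast on every forward pass.
class FP8FwdTensorIdx:
    GEMM1_INPUT = int(FP8FwdTensors.GEMM1_INPUT)
    GEMM1_WEIGHT = int(FP8FwdTensors.GEMM1_WEIGHT)

# Cache cheap-to-derive attributes on the tensor subclass so reads hit a
# plain instance attribute rather than a slow PyObject attribute lookup.
class QuantizedTensorSketch:
    def __init__(self, data, dtype):
        self._data = data
        self._dtype = dtype  # cached at set time

    @property
    def dtype(self):
        return self._dtype

    @property
    def shape(self):
        return tuple(self._data.shape) if hasattr(self._data, "shape") else ()
```

The key point is that the costly conversion or lookup happens once, at construction or assignment, and every subsequent read is a plain attribute access.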

C++ Optimizations

  • Caching symbol lookups in libcuda.so for driver calls like cuCtxGetCurrent, so the symbol is not looked up on each and every forward/backward call.
  • Caching the nvte_is_non_tn_fp8_gemm_supported() function call.
  • Faster Python object calls (avoiding __cxa_demangle) when constructing QuantizedTensor classes from C++.
  • Reduced Python work in QuantizedTensor object creation (calculating strides from the shape and getting the current CUDA device can be done in C++ instead of in the Python constructor).

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

vthumbe1503 and others added 2 commits January 5, 2026 18:11
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
vthumbe1503 and others added 4 commits January 6, 2026 12:34
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 marked this pull request as ready for review January 7, 2026 17:22
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch


greptile-apps bot commented Jan 7, 2026

Greptile Summary

This PR implements comprehensive CPU overhead optimizations across Python and C++ layers of TransformerEngine.

Python optimizations:

  • Caches pybind-exported enum integer values in FP8FwdTensorIdx/FP8BwdTensorIdx to avoid repeated enum-to-int conversions
  • Adds cached dtype, requires_grad, shape, is_cuda, and device properties on QuantizedTensor subclasses to avoid expensive PyObject lookups
  • Includes defensive fallbacks with hasattr() checks for alternate tensor construction paths

C++ optimizations:

  • Replaces pybind11 keyword argument syntax with direct C API calls using py::dict/py::tuple (properly using RAII wrappers to avoid memory leaks)
  • Implements symbol caching for CUDA driver functions with thread-safe mutex synchronization
  • Caches nvte_is_non_tn_fp8_gemm_supported() results to avoid redundant calls
  • Adds stride_from_shape() helper to compute strides in C++ instead of calling Python

Key improvements from previous review rounds:

  • Memory leak issues with PyTuple_New() have been addressed by using pybind11's RAII wrappers
  • Thread safety issue in cuda_driver.h symbol caching has been fixed with proper mutex protection
  • Defensive error handling added for edge cases where both _data and _transpose are None

Issue found:

  • Minor cache staleness bug in Float8Tensor._set_data() when copying between tensors with different dtypes (see inline comment)

Confidence Score: 4/5

  • This PR is safe to merge with minor risk from the dtype cache staleness bug
  • The optimizations are well-implemented with most previous review concerns addressed (memory leaks, thread safety, error handling). The caching strategies are sound and defensive checks are in place. Score reduced by 1 point for the minor dtype cache staleness bug in Float8Tensor._set_data() which could cause incorrect behavior when copying tensors with different dtypes.
  • transformer_engine/pytorch/tensor/float8_tensor.py needs the dtype cache update fix in _set_data method

Important Files Changed

Filename Overview
transformer_engine/pytorch/quantized_tensor.py Adds property caching for dtype and requires_grad to avoid expensive PyObject lookups. Includes defensive fallbacks with hasattr() checks for alternate construction paths.
transformer_engine/pytorch/csrc/quantizer.cpp Replaces pybind11 keyword argument syntax with direct C API calls using py::dict/py::tuple for performance. Adds stride_from_shape() helper to compute strides in C++ instead of Python. Caches nvte_is_non_tn_fp8_gemm_supported() result.
transformer_engine/common/util/cuda_driver.h Implements symbol caching with proper mutex synchronization to avoid repeated get_symbol() lookups for CUDA driver functions. Thread-safe implementation protects both reads and writes to symbol_cache.
transformer_engine/common/gemm/cublaslt_gemm.cu Caches nvte_is_non_tn_fp8_gemm_supported() result to avoid redundant calls during GEMM configuration for both A and B matrices. Clean optimization with proper scoping.
transformer_engine/pytorch/constants.py Adds FP8FwdTensorIdx and FP8BwdTensorIdx namespaces caching pybinded enum integer values to avoid repeated enum-to-int casts in forward/backward passes. Excellent optimization with no functional changes.
transformer_engine/pytorch/tensor/float8_tensor.py Adds cached shape and is_cuda properties to avoid expensive PyObject lookups. Includes proper error handling when both _data and _transpose are None. Minor issue: _dtype cache not updated in Float8Tensor-to-Float8Tensor copy path.

Last reviewed commit: 73e4d1d


@greptile-apps greptile-apps bot left a comment


24 files reviewed, 3 comments


Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch


@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR implements CPU-side performance optimizations for FP8 operations by caching frequently accessed attributes and reducing redundant function calls. The optimizations target expensive PyObject attribute lookups on custom tensor types and repeated C++ function calls.

Key Changes:

  • Caches requires_grad, dtype, shape, and is_cuda attribute accesses to avoid expensive PyObject lookups on custom tensors
  • Reorders attribute checks in get_tensor_device() to prioritize internal quantized tensor attributes
  • Makes num_devices static in nvte_is_non_tn_fp8_gemm_supported() to cache device count
  • Stores GEMM support check results in local variables to avoid redundant function calls

Critical Issues Found:

  • Variable redeclaration error in cublaslt_gemm.cu (line 224) will prevent compilation
  • Logic bug in linear.py (line 484) changes FP8 state management from OR logic to AND logic, breaking functionality when bias is None or doesn't require grad

Confidence Score: 0/5

  • This PR cannot be merged due to compilation error and critical logic bug
  • Two critical issues prevent merging: (1) C++ compilation will fail due to variable redeclaration at line 224 of cublaslt_gemm.cu, and (2) logic bug at line 484 of linear.py breaks FP8 state management by requiring all three tensors to have requires_grad=True instead of any one of them
  • Pay close attention to transformer_engine/common/gemm/cublaslt_gemm.cu (compilation error) and transformer_engine/pytorch/module/linear.py (logic bug)

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/common/gemm/cublaslt_gemm.cu 1/5 Caches function call result to reduce overhead, but contains variable redeclaration error that will cause compilation failure
transformer_engine/common/transformer_engine.cpp 5/5 Makes num_devices static to avoid redundant calls to cuda::num_devices() - valid optimization
transformer_engine/pytorch/module/linear.py 0/5 Caches requires_grad checks for performance, but contains critical logic bug at line 484 that changes FP8 state management behavior

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant Linear as Linear Module
    participant Quantizer as Quantizer/QuantizedTensor
    participant GEMM as GEMM Operations
    participant CPP as C++ Extensions

    Note over Linear,CPP: Performance Optimization Flow
    
    User->>Linear: forward(input, weight, bias)
    
    Note over Linear: Cache requires_grad checks
    Linear->>Linear: inp_requires_grad = inp.requires_grad<br/>weight_requires_grad = weight.requires_grad<br/>bias_requires_grad = bias.requires_grad
    
    Linear->>Quantizer: Check if quantized tensor
    alt QuantizedTensor
        Note over Quantizer: Use cached dtype property
        Quantizer->>Quantizer: return self._dtype
        Note over Quantizer: Use cached shape/is_cuda
        Quantizer->>Quantizer: return self._data.shape
    else Regular Tensor
        Quantizer->>Linear: Standard attribute access
    end
    
    Linear->>CPP: get_tensor_device(tensor)
    Note over CPP: Reordered attribute checks
    CPP->>CPP: Check _rowwise_data first<br/>Check _columnwise_data<br/>Check device last
    CPP-->>Linear: device_index
    
    Linear->>GEMM: Configure GEMM parameters
    Note over GEMM: Cache nvte_is_non_tn_fp8_gemm_supported
    GEMM->>CPP: nvte_is_non_tn_fp8_gemm_supported()
    Note over CPP: Static num_devices cached
    CPP-->>GEMM: support_flag
    GEMM->>GEMM: Store in local variable
    
    GEMM->>GEMM: Execute optimized GEMM
    GEMM-->>Linear: output
    
    Note over Linear: FP8 State Management
    alt FP8 enabled and requires_grad check
        Linear->>Linear: Update FP8 tensors<br/>based on cached flags
    end
    
    Linear-->>User: output


greptile-apps bot commented Jan 7, 2026

Additional Comments (2)

transformer_engine/common/gemm/cublaslt_gemm.cu
variable redeclared in same scope - already declared at line 132

    // int is_nvte_non_tn_fp8_gemm_supported already declared at line 132

transformer_engine/pytorch/module/linear.py
logic change from the original requires_grad(inp, weight, bias), which returns True if ANY tensor requires grad. The new code requires ALL THREE to be True, breaking FP8 state management when bias is None or doesn't require grad

            if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

@vthumbe1503
Collaborator Author

/te-ci L1 pytorch


@greptile-apps greptile-apps bot left a comment


8 files reviewed, 8 comments


ret.A_scale_inv = A.scale_inv.dptr;
ret.lda = is_A_transposed ? k : m;
if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();

Variable is_nvte_non_tn_fp8_gemm_supported is redeclared in the same scope - it was already declared at line 224 for the B matrix configuration. This causes a compilation error or shadowing issue.

The variable should only be declared once at the beginning of the function scope and reused for both A and B matrix configurations.

Suggested change
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
// Move this declaration before line 125 (Configure A matrix)
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();

Comment on lines 165 to 166
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
PyTuple_New(0), kwargs);

Memory leak: PyTuple_New(0) creates a new Python tuple object with a reference count of 1, but this reference is never decremented. The tuple is passed directly to PyObject_Call which doesn't steal the reference, so it must be explicitly released.

This same issue occurs at lines 186, 398, 417, 697, 718, 1039, 1058, 1346, and 1367.

Suggested change
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
PyTuple_New(0), kwargs);
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
args, kwargs);
Py_DECREF(args);
Py_DECREF(kwargs);

PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());

Invalid parameter: The Float8TensorStorage.__new__ method does not accept a device parameter. Looking at the Python class definition in float8_tensor_storage.py, the __new__ signature only expects: data, fp8_scale_inv, fp8_dtype, data_transpose, and quantizer.

This will cause a runtime error when trying to create Float8TensorStorage instances. This issue also occurs at lines 184, 396, 415, and in similar patterns for MXFP8 and other quantizers.

Suggested change
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());
// Remove this line - device parameter is not accepted
// PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());

PyTypeObject *NVFP4TensorPythonClass = nullptr;
PyTypeObject *NVFP4TensorStoragePythonClass = nullptr;
PyTypeObject *NVFP4QuantizerClass = nullptr;
bool is_extension_initialized = false;

Race condition: The is_extension_initialized flag is not thread-safe. If init_extension() is called concurrently from multiple threads, both threads could pass the check at line 105 before either sets the flag at line 110, causing the initialization code to run multiple times.

The original implementation using null-checks on the individual class pointers was safer because Python's module import system provides thread safety. Consider using std::once_flag and std::call_once for proper thread-safe initialization:

Suggested change
bool is_extension_initialized = false;
bool is_extension_initialized = false;
std::once_flag extension_init_flag;

And in init_extension():

void init_extension() {
  std::call_once(extension_init_flag, []() {
    init_float8_extension();
    init_mxfp8_extension();
    init_float8blockwise_extension();
    init_nvfp4_extensions();
  });
}

Comment on lines 964 to 966
static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);

Potential edge case: The num_devices is captured once as a static variable, so if the number of devices changes at runtime (e.g., with CUDA MPS or device enumeration changes), the cache and flags vectors could be undersized, leading to out-of-bounds access at line 968.

While dynamic device changes are rare, consider adding a bounds check:

Suggested change
static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
int device_id = transformer_engine::cuda::current_device();
NVTE_CHECK(device_id < num_devices, "Device ID exceeds cached device count");


@greptile-apps greptile-apps bot left a comment


4 files reviewed, 4 comments


ret.B_scale_inv = B.scale_inv.dptr;
ret.ldb = is_B_transposed ? n : k;
if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();

Variable is_nvte_non_tn_fp8_gemm_supported is redeclared in the same function scope (already declared at line 132). In C++, this is a compilation error. The variable should be declared once at the function scope level before both the A matrix and B matrix configuration blocks.

Suggested change
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
// Move declaration to line 104 (after GemmParam ret;) to avoid redeclaration

PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
PyTuple_New(0), kwargs);

Memory leak: PyTuple_New(0) creates a new tuple with reference count 1, but PyObject_Call does not steal the reference. The tuple must be stored and explicitly DECREF'd after the call. This pattern is repeated 12 times throughout this file (lines 166, 186, 398, 417, 697, 718, 1039, 1058, 1346, 1367).

Suggested change
PyTuple_New(0), kwargs);
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
args, kwargs);
Py_DECREF(args);

Collaborator Author


Done

vthumbe1503 and others added 3 commits January 11, 2026 19:12
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

@greptile-apps greptile-apps bot left a comment


13 files reviewed, 13 comments


ret.B_scale_inv = B.scale_inv.dptr;
ret.ldb = is_B_transposed ? n : k;
if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();

[P0] Variable is_nvte_non_tn_fp8_gemm_supported is already declared at line 132. Declaring it again here causes a compilation error. Remove this declaration and reuse the variable from line 132.

Suggested change
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
// Reuse is_nvte_non_tn_fp8_gemm_supported from line 132

}
py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();

py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();

[P0] Use-after-free bug: scale_inv_py is created from scale_inv before scale_inv is initialized. When scale_inv parameter is empty (lines 148-150), this dereferences an empty std::optional, causing undefined behavior. Move this line after the scale_inv initialization.

Suggested change
py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();
// Move scale_inv_py creation after initialization (line 150)

Collaborator Author


Done

PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),

[P0] Memory leak: PyTuple_New(0) creates a new Python object with refcount 1, but it's never decremented. This leaks memory on every tensor creation. Add Py_DECREF for the tuple:

Suggested change
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);
Py_DECREF(args);

Collaborator Author


Done

PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr());
PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass),

[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).

Collaborator Author


Done

PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(MXFP8TensorStoragePythonClass),

[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).

Collaborator Author


Done

PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(NVFP4TensorStoragePythonClass),

[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).

Collaborator Author


Done

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

@greptile-apps greptile-apps bot left a comment


11 files reviewed, 2 comments


vthumbe1503 and others added 3 commits February 24, 2026 06:59
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 changed the title CPU Optimizations for FP8 CPU Optimizations Feb 24, 2026
vthumbe1503 and others added 4 commits February 23, 2026 23:03
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

@greptile-apps greptile-apps bot left a comment


19 files reviewed, no comments


@vthumbe1503 vthumbe1503 changed the title CPU Optimizations CPU Overhead Optimizations Feb 24, 2026
@vthumbe1503
Collaborator Author

/te-ci pytorch

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci pytorch

@vthumbe1503
Collaborator Author

/te-ci pytorch


@greptile-apps greptile-apps bot left a comment


23 files reviewed, no comments


@vthumbe1503
Collaborator Author

/te-ci pytorch


@greptile-apps greptile-apps bot left a comment


23 files reviewed, no comments


Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci pytorch

pre-commit-ci bot and others added 2 commits February 27, 2026 17:51
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci pytorch


greptile-apps bot commented Feb 27, 2026

Additional Comments (1)

transformer_engine/pytorch/tensor/float8_tensor.py, line 1019
cached _dtype not updated when copying Float8Tensor with different dtype

when copying from one Float8Tensor to another with different dtype (line 1004 condition), the code creates a dummy tensor with the new dtype and sets it using super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor) (line 1019). this bypasses QuantizedTensor._set_data() which updates the cached _dtype attribute

result: cached _dtype becomes stale and won't match the actual tensor's dtype

add after line 1019:

self._dtype = tensor.dtype
