[Draft] Newton-Schulz via cuSOLVERMp#2706
[Draft] Newton-Schulz via cuSOLVERMp#2706vcherepanov-nv wants to merge 29 commits intoNVIDIA:mainfrom
Conversation
Add a new distributed Newton-Schulz inverse square root API to Transformer Engine's common C library. This wraps the cusolverMpNewtonSchulz library function, following the same pattern as the existing cuBLASMp integration for comm_gemm. New files: - newton_schulz.h: Public C API header with context management and computation functions - newton_schulz/newton_schulz.cpp: Implementation with RAII wrappers for cuSolverMp handles Build integration: - New NVTE_WITH_CUSOLVERMP CMake option and CUSOLVERMP_HOME env var - NVTE_CHECK_CUSOLVERMP error checking macro in logging.h - Conditional compilation guarded by NVTE_WITH_CUSOLVERMP Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Add PyTorch-level bindings for the cuSolverMp Newton-Schulz inverse square root API introduced in the previous commit. New files: - pytorch/csrc/extensions/newton_schulz.cpp: C++ extension wrapping the C API with PyTorch tensor support - pytorch/newton_schulz.py: Python wrapper that extracts NCCL communicator from torch.distributed ProcessGroup - tests/pytorch/distributed/test_newton_schulz.py: pytest launcher - tests/pytorch/distributed/run_newton_schulz.py: distributed test worker with reference implementation for numerical validation Modified files: - pytorch/csrc/extensions.h: Function declarations - pytorch/csrc/extensions/pybind.cpp: pybind11 registrations - pytorch/__init__.py: Public API export Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Fix API mismatches discovered during compilation: - cusolverMpCreate takes (handle*, deviceId, stream), not (handle*, stream) - cusolverMpCreateDeviceGrid takes handle as first arg with different parameter order - Use cusolverMpGridMapping_t (not cusolverMpGridLayout_t) and CUSOLVERMP_GRID_MAPPING_COL_MAJOR - cusolverMpCreateMatrixDesc has different parameter order: (desc*, grid, dtype, M, N, MB, NB, RSRC, CSRC, LLD) - cusolverMpNewtonSchulzDescriptorCreate takes only (nsDesc*) with no iteration/coefficient args - No cusolverMpStreamSet exists; create handle per-call with user stream - cusolverMpNewtonSchulz requires computeType and info parameters - Switch from generic template RAII to explicit deleter structs Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
…build Add NVTE_WITH_CUSOLVERMP compiler define and cusolverMp include/library paths to the PyTorch C++ extension build, following the same pattern as NVTE_UB_WITH_MPI and NVTE_ENABLE_NVSHMEM. Without this, the #ifdef NVTE_WITH_CUSOLVERMP guards in the PyTorch extension code would never be active since the define was only set as PRIVATE in the CMake build for the common library. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Two fixes: - Use ProcessGroupNCCL._comm_ptr() to extract the raw NCCL communicator pointer instead of the non-existent get_nccl_comm() method - Pass global matrix dimensions (m, n) from Python to C++ instead of using local tensor dimensions, which would produce incorrect ScaLAPACK block sizes in the distributed computation Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
cuSolverMp handle and grid creation are expensive operations. Move them from per-call creation in nvte_newton_schulz into the NVTECusolverMpCtx, which is their natural home — the context exists to encapsulate the grid. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
cuSolverMp cannot work with the default CUDA stream. Create a dedicated stream inside nvte_cusolvermp_ctx_create and remove the stream parameter from both C API functions since the context now owns its stream. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
The internal dedicated stream was reading the input tensor before the caller's stream had finished producing it, resulting in all-zero output. Add event-based synchronisation: the internal stream waits for the caller's input to be ready, and the caller's stream waits for the output to be written. Replaces the blocking cudaStreamSynchronize. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
cuSolverMp is asynchronous and uses the host workspace during multi-GPU execution. The event-based output sync did not block the host, so the local workspace_host vector was destroyed while the GPU was still reading from it. Restore cudaStreamSynchronize to ensure the host workspace remains valid for the full duration of the operation. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Avoid creating and destroying a cudaEvent_t on every nvte_newton_schulz call by making it a persistent member of NVTECusolverMpCtx, matching the existing pattern for the stream. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Replace single event with in_ready and out_ready events. After the cuSolverMp call, record out_ready on the internal stream and make the caller's stream wait on it, ensuring the output tensor is ready before the caller uses it. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Replace reference-comparison test with a direct arithmetic check: if X is the inverse square root of A, then X @ A @ X must equal the identity matrix. This is more robust and removes the need for a separate reference implementation. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR integrates cuSolverMp to provide distributed Newton-Schulz matrix orthogonalization. The implementation adds ~650 lines across build configuration, C++ bindings, Python API, and tests. Key changes:
Major concerns from prior review rounds:
Additional issue found:
Confidence Score: 2/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant User as Python User
participant PyAPI as newton_schulz.py
participant PyExt as C++ Extension
participant Common as newton_schulz.cpp
participant cuSolver as cuSolverMp Library
participant NCCL as NCCL Communicator
User->>PyAPI: newton_schulz(x, group, iterations, coeffs)
PyAPI->>PyAPI: Extract NCCL comm from ProcessGroup
PyAPI->>PyAPI: Calculate global dims (m, n)
PyAPI->>PyExt: cusolvermp_ctx_create(comm, nranks, rank)
PyExt->>Common: nvte_cusolvermp_ctx_create()
Common->>Common: Create CUDA stream & events
Common->>cuSolver: cusolverMpCreate()
Common->>cuSolver: cusolverMpCreateDeviceGrid()
Common-->>PyExt: Return context pointer
PyExt-->>PyAPI: Return context handle
PyAPI->>PyExt: newton_schulz(ctx, m, n, x, iters, coeffs)
PyExt->>Common: nvte_newton_schulz()
Common->>Common: Stream synchronization (events)
Common->>cuSolver: cusolverMpNewtonSchulz_bufferSize()
cuSolver-->>Common: Workspace size
Common->>Common: Allocate/grow workspace (cudaMalloc)
Common->>cuSolver: cusolverMpNewtonSchulz()
cuSolver->>NCCL: Distributed matrix operations
NCCL-->>cuSolver: Sync results
cuSolver-->>Common: Modified matrix (in-place)
Common->>Common: Stream synchronization (events)
Common-->>PyExt: Success
PyExt-->>PyAPI: Success
PyAPI->>PyExt: cusolvermp_ctx_destroy(ctx)
PyExt->>Common: nvte_cusolvermp_ctx_destroy()
Common->>Common: Free workspace
Common->>cuSolver: Destroy grid & handle
Common->>Common: Destroy stream & events
Common-->>PyExt: Done
PyExt-->>PyAPI: Done
PyAPI-->>User: Modified tensor x
Last reviewed commit: d3740fb |
| # Check: if X = A^{-1/2}, then X @ A @ X should be the identity matrix | ||
| if rank == 0: | ||
| XXT = X @ X.t() | ||
| I = torch.eye(N, device=XXT.device, dtype=XXT.dtype) | ||
| max_diff = (XXT - I).abs().max().item() | ||
| print(f"Max |X @ X.t() - I|: {max_diff:.6e}", flush=True) |
There was a problem hiding this comment.
verification doesn't match the comment - if X = A^{-1/2}, the check should be X @ A @ X ≈ I, not X @ X.t() ≈ I. The current check verifies X is orthogonal, not that X is the inverse square root of A. Note that A_orig is created on line 76 but never used.
| # Check: if X = A^{-1/2}, then X @ A @ X should be the identity matrix | |
| if rank == 0: | |
| XXT = X @ X.t() | |
| I = torch.eye(N, device=XXT.device, dtype=XXT.dtype) | |
| max_diff = (XXT - I).abs().max().item() | |
| print(f"Max |X @ X.t() - I|: {max_diff:.6e}", flush=True) | |
| # Check: if X = A^{-1/2}, then X @ A @ X should be the identity matrix | |
| XAX = X @ A_orig @ X | |
| I = torch.eye(N, device=XAX.device, dtype=XAX.dtype) | |
| max_diff = (XAX - I).abs().max().item() | |
| print(f"Max |X @ A @ X - I|: {max_diff:.6e}", flush=True) | |
| if torch.allclose(XAX, I, atol=args.atol, rtol=args.rtol): |
| nccl_backend = group._get_backend(torch.device("cuda")) | ||
| return nccl_backend._comm_ptr() |
There was a problem hiding this comment.
uses private PyTorch APIs (_get_backend, _comm_ptr) that may change in future versions
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
| quintic_coefficients = [ | ||
| 4.0848, | ||
| -6.8946, | ||
| 2.9270, | ||
| 3.9505, | ||
| -6.3029, | ||
| 2.6377, | ||
| 3.7418, | ||
| -5.5913, | ||
| 2.3037, | ||
| 2.8769, | ||
| -3.1427, | ||
| 1.2046, | ||
| 2.8366, | ||
| -3.0525, | ||
| 1.2012, | ||
| ] | ||
| coefficients = ( | ||
| quintic_coefficients if args.num_iterations == 5 else [1.5, -0.5, 0.0] * args.num_iterations | ||
| ) |
There was a problem hiding this comment.
coefficients mismatch with API defaults - test uses 15 coefficients for 5 iterations, but newton_schulz.py defaults to 5 coefficients. This inconsistency means default API behavior isn't tested.
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
for more information, see https://pre-commit.ci
| * \brief Functions for distributed Newton-Schulz inverse square root. | ||
| * | ||
| * This API is a TE-native binding to the cuSolverMp library. | ||
| * It computes an iterative Newton-Schulz inverse square root | ||
| * approximation on a distributed matrix. |
There was a problem hiding this comment.
Documentation claims this computes "inverse square root" but the test validates orthogonality (X @ X.t() ≈ I), and commit dd1dd0b states "it approximates orthogonal matrix, not inverse square root". If this computes the polar decomposition (orthogonal factor), the documentation should be updated to reflect that. Inverse square root would satisfy X @ A @ X ≈ I, which is different from orthogonality.
| m = x.size(0) * nranks # rows are distributed across ranks | ||
| n = x.size(1) |
There was a problem hiding this comment.
Assumes rows are evenly distributed (m = x.size(0) * nranks) but doesn't validate this. If matrix size isn't divisible by nranks, the computed global size m will be incorrect, leading to wrong results from cuSOLVERMp. Consider adding validation:
| m = x.size(0) * nranks # rows are distributed across ranks | |
| n = x.size(1) | |
| # Global matrix dimensions | |
| # Rows must be evenly distributed across ranks | |
| local_rows = x.size(0) | |
| m = local_rows * nranks | |
| n = x.size(1) |
Then add a validation check that all ranks have the same local_rows via dist.all_reduce.
| num_iterations: int = 5, | ||
| coefficients: Optional[List[float]] = None, | ||
| ) -> None: | ||
| """Compute Newton-Schulz inverse square root in-place on a distributed matrix. |
There was a problem hiding this comment.
Docstring says "inverse square root" but test checks orthogonality. Update to match actual behavior (see comment on header file).
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
| void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor x, | ||
| int64_t num_iterations, const float* coefficients, int64_t num_coefficients, | ||
| cudaStream_t caller_stream) { |
There was a problem hiding this comment.
num_coefficients parameter is unused
The num_coefficients parameter is accepted but never referenced in the function body. Neither cusolverMpNewtonSchulz_bufferSize (line 166) nor cusolverMpNewtonSchulz (line 183) receive this value. If cuSolverMp infers the count from num_iterations internally, then num_coefficients is dead code that should be removed from the API. If cuSolverMp actually needs it, then it should be passed to the cuSolverMp calls — otherwise the library may read out of bounds on the coefficients array.
| from transformer_engine.pytorch import optimizers | ||
| from transformer_engine.pytorch.export import onnx_export | ||
| from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy | ||
| from transformer_engine.pytorch.newton_schulz import newton_schulz |
There was a problem hiding this comment.
Unconditional import of optional feature
newton_schulz is unconditionally imported and exported as part of the public API, even when TE is built without NVTE_WITH_CUSOLVERMP. While the function itself raises a runtime error when called, this exposes the symbol to all users and makes it appear as a supported feature in auto-complete and docs. Consider guarding this import behind a check (similar to how other optional features are handled), or at minimum adding a note in the docstring that the function requires NVTE_WITH_CUSOLVERMP=1 at build time.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| // Allocate/grow device workspace | ||
| if (ctx->workspace_size < wrksp_size_device) { | ||
| if (ctx->workspace) { | ||
| NVTE_CHECK_CUDA(cudaFree(ctx->workspace)); | ||
| } | ||
| NVTE_CHECK_CUDA(cudaMalloc(&ctx->workspace, wrksp_size_device)); | ||
| ctx->workspace_size = wrksp_size_device; | ||
| } |
There was a problem hiding this comment.
Synchronous cudaFree/cudaMalloc on hot path
cudaFree followed by cudaMalloc inside nvte_newton_schulz will synchronize with the device each time the workspace needs to grow. Since the context is recreated on every call from newton_schulz.py (line 82-86 creates + destroys ctx each invocation), the workspace will never be reused across calls — the grow-only caching here is ineffective. Consider either:
- Allowing callers to keep the context alive across calls, or
- Using
cudaMallocAsync/cudaFreeAsynconctx->streamto avoid synchronous stalls.
| ctx_ptr = tex.cusolvermp_ctx_create(nccl_comm_ptr, nranks, rank) | ||
| try: | ||
| tex.newton_schulz(ctx_ptr, m, n, x, num_iterations, coefficients) | ||
| finally: | ||
| tex.cusolvermp_ctx_destroy(ctx_ptr) |
There was a problem hiding this comment.
Context created/destroyed per call wastes resources
A new NVTECusolverMpCtx is created and destroyed on every invocation of newton_schulz. Context creation involves cudaStreamCreate, two cudaEventCreate calls, cusolverMpCreate, and cusolverMpCreateDeviceGrid — all of which are heavyweight operations. And since the context is destroyed afterward, the grow-only workspace caching in the C++ layer (lines 170-177 of newton_schulz.cpp) is never actually reused.
Consider caching the context (e.g., in a module-level dict keyed by (nccl_comm_ptr, nranks, rank)) and reusing it across calls, or exposing the context lifecycle to callers so they can amortize the cost when calling newton_schulz repeatedly in a training loop.
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
for more information, see https://pre-commit.ci
| assert ( | ||
| len(coefficients) == num_iterations * 3 | ||
| ), f"Unexpected number of coefficients: {len(coefficients)} for {num_iterations} iterations" |
There was a problem hiding this comment.
use ValueError instead of assert for validation - assert can be disabled with Python's -O flag
| assert ( | |
| len(coefficients) == num_iterations * 3 | |
| ), f"Unexpected number of coefficients: {len(coefficients)} for {num_iterations} iterations" | |
| if len(coefficients) != num_iterations * 3: | |
| raise ValueError( | |
| f"Unexpected number of coefficients: {len(coefficients)} for {num_iterations} iterations" | |
| ) |
| if x.dim() != 2: | ||
| raise ValueError(f"Expected 2D tensor, got {x.dim()}D") | ||
| if not x.is_cuda: | ||
| raise ValueError("Input tensor must be on CUDA device") |
There was a problem hiding this comment.
missing contiguity check - C++ code uses data_ptr() which requires contiguous memory. Non-contiguous tensors will cause incorrect results.
| if x.dim() != 2: | |
| raise ValueError(f"Expected 2D tensor, got {x.dim()}D") | |
| if not x.is_cuda: | |
| raise ValueError("Input tensor must be on CUDA device") | |
| if x.dim() != 2: | |
| raise ValueError(f"Expected 2D tensor, got {x.dim()}D") | |
| if not x.is_cuda: | |
| raise ValueError("Input tensor must be on CUDA device") | |
| if not x.is_contiguous(): | |
| raise ValueError("Input tensor must be contiguous") |
Instead of requiring NVTE_WITH_CUSOLVERMP env var to be set for both the common library and PyTorch extension builds, inspect the already-built libtransformer_engine.so for exported symbols. This is more robust for incremental builds and CI environments where the env var may not be propagated to the extension build step. The PyTorch extension only calls nvte_* C API functions, so it does not need cusolverMp headers or libraries — only the compile definition. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
When NVTE_WITH_CUSOLVERMP is not defined, omit the Newton-Schulz functions entirely from the pybind module instead of registering stubs that throw runtime errors. The Python wrapper checks for the attribute at call time and raises a clear error message. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
for more information, see https://pre-commit.ci
| if x.dim() != 2: | ||
| raise ValueError(f"Expected 2D tensor, got {x.dim()}D") | ||
| if not x.is_cuda: | ||
| raise ValueError("Input tensor must be on CUDA device") |
There was a problem hiding this comment.
Missing dtype validation - docstring on line 36 states tensor must be float32 or bfloat16, but this isn't enforced. Passing unsupported dtypes leads to confusing errors from cuSolverMp.
| if x.dim() != 2: | |
| raise ValueError(f"Expected 2D tensor, got {x.dim()}D") | |
| if not x.is_cuda: | |
| raise ValueError("Input tensor must be on CUDA device") | |
| if x.dim() != 2: | |
| raise ValueError(f"Expected 2D tensor, got {x.dim()}D") | |
| if not x.is_cuda: | |
| raise ValueError("Input tensor must be on CUDA device") | |
| if x.dtype not in (torch.float32, torch.bfloat16): | |
| raise ValueError(f"Input tensor must be float32 or bfloat16, got {x.dtype}") |
Description
Adds an API to call Newton-Schulz method on a distributed tensor.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: