diff --git a/.ci/scripts/test_model_e2e.sh b/.ci/scripts/test_model_e2e.sh index 5cee37b19cf..9643c286d6d 100755 --- a/.ci/scripts/test_model_e2e.sh +++ b/.ci/scripts/test_model_e2e.sh @@ -354,7 +354,9 @@ EOF fi ;; qwen3_5_moe) - RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt 'What is the capital of France?' --max_new_tokens 128 --temperature 0" + RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt 'What is the capital of France?' --max_new_tokens 128 --temperature 0 --cuda_graph" + # CUDA graph capture requires cudaMallocAsync backend for stream-ordered allocations + export PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync ;; voxtral_realtime) RUNNER_ARGS="--model_path ${MODEL_DIR}/model.pte --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --preprocessor_path ${MODEL_DIR}/$PREPROCESSOR --audio_path ${MODEL_DIR}/$AUDIO_FILE --temperature 0" diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp index eb0a07b8d8f..f98c862f69a 100644 --- a/backends/cuda/runtime/cuda_backend.cpp +++ b/backends/cuda/runtime/cuda_backend.cpp @@ -68,6 +68,7 @@ using executorch::runtime::Span; using executorch::runtime::etensor::Tensor; // SlimTensor type aliases +using cuda::CudaGraphPhase; using slim::CPU_DEVICE; using slim::DEFAULT_CUDA_DEVICE; using slim::DeviceTraits; @@ -80,6 +81,8 @@ namespace { constexpr char kSkipCopyOutputToCpuForMethod[] = "skip_copy_output_to_cpu_for_method"; constexpr char kUseSharedCudaStream[] = "use_shared_cuda_stream"; +constexpr char kEnableCudaGraphForMethod[] = "enable_cuda_graph_for_method"; +constexpr int kCudaGraphWarmupSteps = 3; constexpr char kShareKvCacheAcrossMethods[] = "share_kv_cache_across_methods"; } // anonymous namespace @@ -147,6 +150,20 @@ class ET_EXPERIMENTAL CudaBackend final return method_in_csv(method_name, skip_copy_method_); } + void set_cuda_graph_method( + const std::array& raw) { + std::lock_guard guard(cuda_graph_method_mutex_); + 
cuda_graph_method_ = std::string(raw.data()); + } + + bool should_use_cuda_graph_for_method(const std::string& method_name) const { + if (method_name.empty()) { + return false; + } + std::lock_guard<std::mutex> guard(cuda_graph_method_mutex_); + return method_in_csv(method_name, cuda_graph_method_); + } + + // Create the shared CUDA stream. Called when use_shared_cuda_stream option // is set to true. The presence of shared_cuda_stream_ indicates shared mode. void create_shared_cuda_stream() { @@ -265,6 +282,17 @@ class ET_EXPERIMENTAL CudaBackend final ET_LOG(Error, "Option %s must be a boolean.", kUseSharedCudaStream); return Error::InvalidArgument; } + } else if (std::strcmp(option.key, kEnableCudaGraphForMethod) == 0) { + if (auto* val = std::get_if>( + &option.value)) { + set_cuda_graph_method(*val); + } else { + ET_LOG( + Error, + "Option %s must be a method name string.", + kEnableCudaGraphForMethod); + return Error::InvalidArgument; + } + } } return Error::Ok; @@ -532,6 +560,17 @@ class ET_EXPERIMENTAL CudaBackend final method_name.c_str()); } + // Initialize CUDA graph state if enabled for this method. 
+ if (should_use_cuda_graph_for_method(method_name)) { + handle->cuda_graph_state.phase = CudaGraphPhase::Warmup; + handle->cuda_graph_state.warmup_remaining = kCudaGraphWarmupSteps; + ET_LOG( + Info, + "CUDA graph enabled for method '%s' (warmup=%d)", + method_name.c_str(), + kCudaGraphWarmupSteps); + } + + return (DelegateHandle*)handle; // Return the handle post-processing } @@ -558,6 +597,68 @@ class ET_EXPERIMENTAL CudaBackend final n_outputs, args.size()) + // --------------------------------------------------------------- + // CUDA graph REPLAY path — skip all tensor setup and just replay + // --------------------------------------------------------------- + if (handle->cuda_graph_state.phase == CudaGraphPhase::Replay) { + Result<cudaStream_t> csr = getCurrentCUDAStream(0); + ET_CHECK_OK_OR_RETURN_ERROR(csr.error()); + cudaStream_t cs = csr.get(); + + // Copy new input data into static input buffers + for (size_t i = 0; i < n_inputs; i++) { + auto* cpu_tensor = &(args[i]->toTensor()); + ET_CHECK_OR_RETURN_ERROR( + cpu_tensor->nbytes() == + handle->cuda_graph_state.static_input_nbytes[i], + InvalidArgument, + "CUDA graph replay: input %zu size mismatch (expected %zu, got %zu)", + i, + handle->cuda_graph_state.static_input_nbytes[i], + cpu_tensor->nbytes()); + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpyAsync( + handle->cuda_graph_state.static_input_ptrs[i], + cpu_tensor->const_data_ptr(), + handle->cuda_graph_state.static_input_nbytes[i], + cudaMemcpyHostToDevice, + cs)); + } + + // Replay the captured graph + cudaError_t gerr = + cudaGraphLaunch(handle->cuda_graph_state.graph_exec, cs); + ET_CHECK_OR_RETURN_ERROR( + gerr == cudaSuccess, + Internal, + "cudaGraphLaunch failed: %s", + cudaGetErrorString(gerr)); + + // Copy outputs back to CPU + const bool copy_outputs = + !should_skip_copy_for_method(handle->method_name); + if (copy_outputs) { + for (size_t i = 0; i < n_outputs; i++) { + auto* cpu_out = &(args[i + n_inputs]->toTensor()); 
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpyAsync( + cpu_out->mutable_data_ptr(), + handle->cuda_graph_state.static_output_ptrs[i], + handle->cuda_graph_state.static_output_nbytes[i], + cudaMemcpyDeviceToHost, + cs)); + } + cudaStreamSynchronize(cs); + } + + return Error::Ok; + } + + // --------------------------------------------------------------- + // Normal path (also used for WARMUP and CAPTURE phases) + // --------------------------------------------------------------- + bool is_capture_step = + (handle->cuda_graph_state.phase == CudaGraphPhase::Warmup && + handle->cuda_graph_state.warmup_remaining == 0); + // NOTE: ExecuTorch tensors may be on CPU or GPU due to the skip-copy // optimization. We need to create GPU copies for CUDA kernel execution // using SlimTensor. @@ -568,6 +669,33 @@ class ET_EXPERIMENTAL CudaBackend final for (size_t i = 0; i < n_inputs; i++) { auto* cpu_tensor = &(args[i]->toTensor()); + // CAPTURE step: allocate persistent static GPU buffers + if (is_capture_step) { + size_t nbytes = cpu_tensor->nbytes(); + + void* static_ptr = nullptr; + cudaError_t merr = cudaMalloc(&static_ptr, nbytes); + ET_CHECK_OR_RETURN_ERROR( + merr == cudaSuccess, + Internal, + "cudaMalloc for static input %zu failed: %s", + i, + cudaGetErrorString(merr)); + + cudaMemcpy( + static_ptr, + cpu_tensor->const_data_ptr(), + nbytes, + cudaMemcpyHostToDevice); + + handle->cuda_graph_state.static_input_ptrs.push_back(static_ptr); + handle->cuda_graph_state.static_input_nbytes.push_back(nbytes); + + gpu_inputs[i] = make_slimtensor_from_blob_with_etensor_metadata( + static_ptr, cpu_tensor); + continue; + } + // Check if input data is already on GPU (skip-copy optimization for // inputs) This can happen when the caller has pre-staged data on GPU cudaPointerAttributes attributes{}; @@ -576,19 +704,8 @@ class ET_EXPERIMENTAL CudaBackend final cudaError_t err = cudaPointerGetAttributes(&attributes, data_ptr); if (err == cudaSuccess && attributes.type == cudaMemoryTypeDevice) { 
// Data is already on GPU - wrap it directly without copy - auto sizes = cpu_tensor->sizes(); - auto strides = cpu_tensor->strides(); - std::vector sizes_vec(sizes.begin(), sizes.end()); - std::vector strides_vec(strides.begin(), strides.end()); - - gpu_inputs[i] = new SlimTensor(slim::from_blob( - const_cast(data_ptr), - slim::makeArrayRef(sizes_vec), - slim::makeArrayRef(strides_vec), - static_cast(cpu_tensor->scalar_type()), - DEFAULT_CUDA_DEVICE, - 0 // storage_offset - )); + gpu_inputs[i] = make_slimtensor_from_blob_with_etensor_metadata( + const_cast(data_ptr), cpu_tensor); continue; } @@ -640,8 +757,25 @@ class ET_EXPERIMENTAL CudaBackend final // NOTE: run() steals input handles (RAII wraps them at the start of // run_impl) and may replace output handles with its own. Result cuda_stream_ret = getCurrentCUDAStream(0); - cudaStream_t cuda_stream = cuda_stream_ret.get(); ET_CHECK_OK_OR_RETURN_ERROR(cuda_stream_ret.error()); + cudaStream_t cuda_stream = cuda_stream_ret.get(); + + if (is_capture_step) { + // ----- CUDA graph CAPTURE ----- + ET_LOG( + Info, + "CUDA graph: beginning stream capture for '%s'", + handle->method_name.c_str()); + + cudaError_t cerr = + cudaStreamBeginCapture(cuda_stream, cudaStreamCaptureModeRelaxed); + ET_CHECK_OR_RETURN_ERROR( + cerr == cudaSuccess, + Internal, + "cudaStreamBeginCapture failed: %s", + cudaGetErrorString(cerr)); + } + AOTIRuntimeError error = handle->run( handle->container_handle, reinterpret_cast(gpu_inputs.data()), @@ -667,6 +801,87 @@ class ET_EXPERIMENTAL CudaBackend final "AOTInductorModelContainerRun failed with error code %d", error); + if (is_capture_step) { + // End capture → instantiate graph + cudaError_t gerr = + cudaStreamEndCapture(cuda_stream, &handle->cuda_graph_state.graph); + ET_CHECK_OR_RETURN_ERROR( + gerr == cudaSuccess, + Internal, + "cudaStreamEndCapture failed: %s", + cudaGetErrorString(gerr)); + + gerr = cudaGraphInstantiate( + &handle->cuda_graph_state.graph_exec, + 
handle->cuda_graph_state.graph, + cudaGraphInstantiateFlagAutoFreeOnLaunch); + ET_CHECK_OR_RETURN_ERROR( + gerr == cudaSuccess, + Internal, + "cudaGraphInstantiate failed: %s", + cudaGetErrorString(gerr)); + + // Record static output pointers (stable under graph replay) + for (size_t i = 0; i < n_outputs; i++) { + SlimTensor* out = gpu_outputs[i]; + handle->cuda_graph_state.static_output_ptrs.push_back(out->data_ptr()); + handle->cuda_graph_state.static_output_nbytes.push_back(out->nbytes()); + } + + handle->cuda_graph_state.phase = CudaGraphPhase::Replay; + ET_LOG( + Info, + "CUDA graph: captured and instantiated for '%s'", + handle->method_name.c_str()); + + // Replay once to actually produce output (capture doesn't execute) + gerr = cudaGraphLaunch(handle->cuda_graph_state.graph_exec, cuda_stream); + ET_CHECK_OR_RETURN_ERROR( + gerr == cudaSuccess, + Internal, + "cudaGraphLaunch (first replay) failed: %s", + cudaGetErrorString(gerr)); + + // Copy capture-step outputs to CPU + const bool copy_outputs = + !should_skip_copy_for_method(handle->method_name); + if (copy_outputs) { + for (size_t i = 0; i < n_outputs; i++) { + auto* cpu_out = &(args[i + n_inputs]->toTensor()); + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpyAsync( + cpu_out->mutable_data_ptr(), + handle->cuda_graph_state.static_output_ptrs[i], + handle->cuda_graph_state.static_output_nbytes[i], + cudaMemcpyDeviceToHost, + cuda_stream)); + // Don't delete — static buffers are owned by the handle + gpu_outputs[i] = nullptr; + } + cudaStreamSynchronize(cuda_stream); + } else { + // Even when skipping copy, null out gpu_outputs to prevent + // the ScopeGuard from deleting static output buffers. 
+ for (size_t i = 0; i < n_outputs; i++) { + gpu_outputs[i] = nullptr; + } + } + + return Error::Ok; + } + + // ----- Normal / WARMUP execution continues here ----- + + // Decrement warmup counter if in warmup phase + if (handle->cuda_graph_state.phase == CudaGraphPhase::Warmup && + handle->cuda_graph_state.warmup_remaining > 0) { + handle->cuda_graph_state.warmup_remaining--; + ET_LOG( + Info, + "CUDA graph warmup: %d steps remaining for '%s'", + handle->cuda_graph_state.warmup_remaining, + handle->method_name.c_str()); + } + const bool copy_outputs = !should_skip_copy_for_method(handle->method_name); if (copy_outputs) { @@ -761,6 +976,9 @@ class ET_EXPERIMENTAL CudaBackend final mutable std::mutex skip_copy_method_mutex_; std::string skip_copy_method_; + mutable std::mutex cuda_graph_method_mutex_; + std::string cuda_graph_method_; + // Shared CUDA stream for all methods. When set (non-null), all methods use // the same stream to ensure proper ordering (critical for skip-copy // optimization). Created when use_shared_cuda_stream option is set to true. diff --git a/backends/cuda/runtime/cuda_delegate_handle.h b/backends/cuda/runtime/cuda_delegate_handle.h index 02d3356379f..87845fbc312 100644 --- a/backends/cuda/runtime/cuda_delegate_handle.h +++ b/backends/cuda/runtime/cuda_delegate_handle.h @@ -11,6 +11,7 @@ #include #include #include +#include namespace executorch { namespace backends { @@ -38,6 +39,93 @@ inline std::shared_ptr create_cuda_stream() { return std::shared_ptr( new cudaStream_t(stream), CudaStreamDeleter()); } + +enum class CudaGraphPhase { + Disabled = 0, + Warmup = 1, + Replay = 2, +}; + +// All CUDA graph related state grouped into a single struct. +struct CudaGraphState { + CudaGraphPhase phase = CudaGraphPhase::Disabled; + int warmup_remaining = 0; + + // Captured graph and executable instance + cudaGraph_t graph = nullptr; + cudaGraphExec_t graph_exec = nullptr; + + // Static input/output GPU buffers pinned during capture. 
+ // These hold the tensor metadata; the underlying data pointers are fixed + // addresses that CUDA graph replay will write to / read from. + std::vector<void*> static_input_ptrs; + std::vector<void*> static_output_ptrs; + std::vector<size_t> static_input_nbytes; + std::vector<size_t> static_output_nbytes; + + CudaGraphState() = default; + + ~CudaGraphState() { + if (graph_exec) { + cudaGraphExecDestroy(graph_exec); + } + if (graph) { + cudaGraphDestroy(graph); + } + // Only free input buffers — output buffers are owned by the AOTI runtime + // (allocated during graph capture via the caching allocator). + for (auto* ptr : static_input_ptrs) { + if (ptr) + cudaFree(ptr); + } + } + + // Non-copyable: prevent double-free of CUDA resources + CudaGraphState(const CudaGraphState&) = delete; + CudaGraphState& operator=(const CudaGraphState&) = delete; + + // Movable + CudaGraphState(CudaGraphState&& other) noexcept + : phase(other.phase), + warmup_remaining(other.warmup_remaining), + graph(other.graph), + graph_exec(other.graph_exec), + static_input_ptrs(std::move(other.static_input_ptrs)), + static_output_ptrs(std::move(other.static_output_ptrs)), + static_input_nbytes(std::move(other.static_input_nbytes)), + static_output_nbytes(std::move(other.static_output_nbytes)) { + other.graph = nullptr; + other.graph_exec = nullptr; + } + + CudaGraphState& operator=(CudaGraphState&& other) noexcept { + if (this != &other) { + // Clean up existing resources + if (graph_exec) + cudaGraphExecDestroy(graph_exec); + if (graph) + cudaGraphDestroy(graph); + for (auto* ptr : static_input_ptrs) { + if (ptr) + cudaFree(ptr); + } + + phase = other.phase; + warmup_remaining = other.warmup_remaining; + graph = other.graph; + graph_exec = other.graph_exec; + static_input_ptrs = std::move(other.static_input_ptrs); + static_output_ptrs = std::move(other.static_output_ptrs); + static_input_nbytes = std::move(other.static_input_nbytes); + static_output_nbytes = std::move(other.static_output_nbytes); + + other.graph = nullptr; 
+ other.graph_exec = nullptr; + } + return *this; + } +}; + // CUDA-specific delegate handle that extends AOTIDelegateHandle. // This consolidates CUDA stream management into a single location. struct CudaDelegateHandle : public aoti::AOTIDelegateHandle { @@ -58,6 +146,9 @@ struct CudaDelegateHandle : public aoti::AOTIDelegateHandle { bool has_cuda_stream() const { return cuda_stream != nullptr && *cuda_stream != nullptr; } + + // CUDA graph state (warmup, capture, replay, static buffers) + CudaGraphState cuda_graph_state; }; } // namespace cuda diff --git a/backends/cuda/runtime/utils.h b/backends/cuda/runtime/utils.h index e6a340f9a49..e72ad278e9c 100644 --- a/backends/cuda/runtime/utils.h +++ b/backends/cuda/runtime/utils.h @@ -17,6 +17,8 @@ #include #include #include +#include +#include namespace executorch::backends::cuda { @@ -314,4 +316,42 @@ inline void delete_slimtensor_vector( tensors.clear(); } +/** + * Creates a non-owning SlimTensor that wraps an external data pointer, using + * the shape, strides and dtype of the given ETensor as metadata. + * + * Common helper for the CUDA backend, where we frequently need to create a + * SlimTensor view of memory (e.g., a GPU buffer) that mirrors the layout of a + * CPU/GPU ETensor. The returned tensor does NOT own the data; the caller must + * keep it alive for the SlimTensor's lifetime. + * + * @param data_ptr Pointer to memory the SlimTensor should reference. + * @param etensor ETensor whose sizes/strides/dtype define the SlimTensor's + * metadata. + * @param device Device where data_ptr resides (defaults to the default + * CUDA device). + * @return A heap-allocated SlimTensor; the caller takes ownership. 
+ */ +inline executorch::backends::aoti::slim::SlimTensor* +make_slimtensor_from_blob_with_etensor_metadata( + void* data_ptr, + const executorch::runtime::etensor::Tensor* etensor, + const executorch::backends::aoti::slim::c10::Device& device = + executorch::backends::aoti::slim::DEFAULT_CUDA_DEVICE) { + auto sizes = etensor->sizes(); + auto strides = etensor->strides(); + std::vector sizes_vec(sizes.begin(), sizes.end()); + std::vector strides_vec(strides.begin(), strides.end()); + + return new executorch::backends::aoti::slim::SlimTensor( + executorch::backends::aoti::slim::from_blob( + data_ptr, + executorch::backends::aoti::slim::makeArrayRef(sizes_vec), + executorch::backends::aoti::slim::makeArrayRef(strides_vec), + static_cast( + etensor->scalar_type()), + device, + /*storage_offset=*/0)); +} + } // namespace executorch::backends::cuda diff --git a/examples/models/qwen3_5_moe/main.cpp b/examples/models/qwen3_5_moe/main.cpp index 7f4e60596be..e7cd83dddc2 100644 --- a/examples/models/qwen3_5_moe/main.cpp +++ b/examples/models/qwen3_5_moe/main.cpp @@ -13,6 +13,8 @@ #include #include #include +#include +#include #include #include @@ -28,6 +30,7 @@ DEFINE_string(tokenizer_path, "", "HuggingFace tokenizer.json path."); DEFINE_string(prompt, "Hello", "Prompt text."); DEFINE_double(temperature, 0.8, "Sampling temperature (0 = greedy)."); DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate."); +DEFINE_bool(cuda_graph, false, "Enable CUDA graph for decode method."); namespace llm = ::executorch::extension::llm; using ::executorch::extension::from_blob; @@ -84,29 +87,25 @@ int main(int argc, char** argv) { } auto metadata = metadata_result.get(); + // Set CUDA graph option if requested (must be before load_method) + if (FLAGS_cuda_graph) { + executorch::runtime::BackendOptions<2> cuda_opts; + cuda_opts.set_option("enable_cuda_graph_for_method", "decode"); + executorch::runtime::set_option("CudaBackend", cuda_opts.view()); + printf("CUDA graph enabled for 
decode method\n"); + } + printf("Loading methods...\n"); - // Try loading both methods; fall back to single "forward" method - bool dual_method = true; - std::string prefill_method = "prefill"; auto err = module->load_method("prefill"); if (err != Error::Ok) { - // Try "forward" for single-method export - err = module->load_method("forward"); - if (err != Error::Ok) { - ET_LOG(Error, "Failed to load prefill/forward method"); - return 1; - } - prefill_method = "forward"; - dual_method = false; - printf("Using single-method mode (forward)\n"); + ET_LOG(Error, "Failed to load prefill method"); + return 1; } - if (dual_method) { - err = module->load_method("decode"); - if (err != Error::Ok) { - ET_LOG(Error, "Failed to load decode method"); - return 1; - } + err = module->load_method("decode"); + if (err != Error::Ok) { + ET_LOG(Error, "Failed to load decode method"); + return 1; } // Get EOS ids @@ -149,7 +148,7 @@ int main(int argc, char** argv) { prefill_inputs.push_back(tokens_tensor); prefill_inputs.push_back(pos_tensor); - auto prefill_result = module->execute(prefill_method, prefill_inputs); + auto prefill_result = module->execute("prefill", prefill_inputs); if (prefill_result.error() != Error::Ok) { ET_LOG(Error, "Prefill failed"); return 1; @@ -176,11 +175,6 @@ int main(int argc, char** argv) { // decode method, which may run on a different CUDA stream. cudaDeviceSynchronize(); - if (!dual_method) { - printf("Single-method mode: skipping decode\n"); - return 0; - } - // --------------------------------------------------------------- // Decode — generate tokens one at a time // ---------------------------------------------------------------