From 100061d2aff834108e04e6d30248fcd033689eb8 Mon Sep 17 00:00:00 2001 From: juncaipeng <13006307475@163.com> Date: Fri, 15 May 2026 15:02:10 +0000 Subject: [PATCH 1/6] PD send cache via storage & Refine swap_cache_layout op --- custom_ops/gpu_ops/swap_cache_layout.cu | 280 ++++++++++++++---- examples/cache_storage/run_03b_pd_storage.sh | 2 +- fastdeploy/cache_manager/cache_messager.py | 18 ++ .../cache_manager/prefix_cache_manager.py | 177 +++++++++++ fastdeploy/engine/common_engine.py | 26 ++ .../engine/sched/resource_manager_v1.py | 8 +- fastdeploy/envs.py | 4 + fastdeploy/output/token_processor.py | 13 + 8 files changed, 468 insertions(+), 60 deletions(-) diff --git a/custom_ops/gpu_ops/swap_cache_layout.cu b/custom_ops/gpu_ops/swap_cache_layout.cu index 62adccb2d04..8fd497f2362 100644 --- a/custom_ops/gpu_ops/swap_cache_layout.cu +++ b/custom_ops/gpu_ops/swap_cache_layout.cu @@ -15,74 +15,239 @@ #include "helper.h" #include "paddle/extension.h" -// #define SWAP_DEBUG +// D2H: Each thread block handles ALL layers for one swap block. +// This produces perfectly contiguous host writes (1 block × all layers), +// maximizing write-combining efficiency. +template +__global__ void swap_d2h_kernel(T** __restrict__ layer_ptrs, + T* __restrict__ cpu_buffer, + const int64_t* __restrict__ gpu_block_ids, + int n_blocks, + int layer_num, + int64_t block_stride) { + int block_idx = blockIdx.x; + if (block_idx >= n_blocks) return; + + int64_t gpu_block = gpu_block_ids[block_idx]; + int64_t num_vec_per_layer = (block_stride * sizeof(T)) / sizeof(float4); + + T* dst_base = cpu_buffer + (int64_t)block_idx * layer_num * block_stride; + + for (int layer_idx = 0; layer_idx < layer_num; layer_idx++) { + const T* src = layer_ptrs[layer_idx] + gpu_block * block_stride; + float4* dst4 = + reinterpret_cast(dst_base + layer_idx * block_stride); + const float4* src4 = reinterpret_cast(src); + + for (int64_t i = threadIdx.x; i < num_vec_per_layer; i += blockDim.x) { + dst4[i] = src4[i]; + } + } +} + +// H2D: scatter from contiguous staging buffer to scattered GPU layer tensors +template +__global__ void scatter_blocks_kernel(T** __restrict__ layer_ptrs, + const T* __restrict__ staging, + const int64_t* __restrict__ gpu_block_ids, + int n_blocks, + int layer_num, + int64_t block_stride) { + int pair_idx = blockIdx.x; + int block_idx = pair_idx / layer_num; + int layer_idx = pair_idx % layer_num; + + if (block_idx >= n_blocks) return; + + int64_t gpu_block = gpu_block_ids[block_idx]; + const T* src = staging + (int64_t)block_idx * layer_num * block_stride + + layer_idx * block_stride; + T* dst = layer_ptrs[layer_idx] + gpu_block * block_stride; + + int64_t num_vec = (block_stride * sizeof(T)) / sizeof(float4); + const float4* src4 = reinterpret_cast(src); + float4* dst4 = reinterpret_cast(dst); + + for (int64_t i = threadIdx.x; i < num_vec; i += blockDim.x) { + dst4[i] = src4[i]; + } +} + +static void* g_staging_buffer = nullptr; +static size_t g_staging_buffer_size = 0; +static void* g_device_block_ids = nullptr; +static size_t g_device_block_ids_size = 0; +static void* g_device_layer_ptrs = nullptr; +static size_t g_device_layer_ptrs_size = 0; + +static void ensure_staging_buffer(size_t required_size) { + if (g_staging_buffer_size < required_size) { + if (g_staging_buffer) cudaFree(g_staging_buffer); + cudaError_t err = cudaMalloc(&g_staging_buffer, required_size); + PADDLE_ENFORCE_EQ( + err, + cudaSuccess, + phi::errors::External("cudaMalloc staging buffer failed: %s", + cudaGetErrorString(err))); + g_staging_buffer_size = required_size; + } +} + +static void ensure_device_block_ids(size_t required_size) { + if (g_device_block_ids_size < required_size) { + if (g_device_block_ids) cudaFree(g_device_block_ids); + cudaError_t err = cudaMalloc(&g_device_block_ids, required_size); + PADDLE_ENFORCE_EQ( + err, + cudaSuccess, + phi::errors::External("cudaMalloc device block_ids failed: %s", + cudaGetErrorString(err))); + g_device_block_ids_size = required_size; + } +} + +static void ensure_device_layer_ptrs(size_t required_size) { + if (g_device_layer_ptrs_size < required_size) { + if (g_device_layer_ptrs) cudaFree(g_device_layer_ptrs); + cudaError_t err = cudaMalloc(&g_device_layer_ptrs, required_size); + PADDLE_ENFORCE_EQ( + err, + cudaSuccess, + phi::errors::External("cudaMalloc device layer_ptrs failed: %s", + cudaGetErrorString(err))); + g_device_layer_ptrs_size = required_size; + } +} + +static bool is_cpu_block_ids_sequential( + const std::vector& cpu_block_ids) { + if (cpu_block_ids.empty()) return true; + int64_t start = cpu_block_ids[0]; + for (size_t i = 1; i < cpu_block_ids.size(); i++) { + if (cpu_block_ids[i] != start + static_cast(i)) return false; + } + return true; +} template -void SwapCacheImpLayout( - const std::vector& cache_gpu_tensors, // gpu - const int64_t& cache_cpu_pointer, // cpu - const std::vector& cache_shape, - const std::vector& gpu_block_ids, - const std::vector& cpu_block_ids, - int mode) { - /* - mode is 0: gpu to cpu; 1: cpu to gpu - - cache layout: layer_num * [block_num, head_num, block_size, head_dim] - scale layout: layer_num * [block_num, head_num, block_size] - cache buffer layout: [block_num, layer_num, head_num, block_size, head_dim] - scale buffer layout: [block_num, layer_num, head_num, block_size] - */ +void SwapCacheImpLayout(const std::vector& cache_gpu_tensors, + const int64_t& cache_cpu_pointer, + const std::vector& cache_shape, + const std::vector& gpu_block_ids, + const std::vector& cpu_block_ids, + int mode) { typedef PDTraits traits_; typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; const int64_t layer_number = cache_gpu_tensors.size(); int64_t cache_block_stride = 1; - for (int i = 1; i < cache_shape.size(); i++) { + for (size_t i = 1; i < cache_shape.size(); i++) { cache_block_stride *= cache_shape[i]; } + const int n_blocks = gpu_block_ids.size(); + if (n_blocks == 0) return; + auto stream = cache_gpu_tensors[0].stream(); - const cudaMemcpyKind copy_kind = - (mode == 0) ? cudaMemcpyDeviceToHost : cudaMemcpyHostToDevice; - - for (int layer_idx = 0; layer_idx < cache_gpu_tensors.size(); layer_idx++) { - const paddle::Tensor& cache_gpu = cache_gpu_tensors[layer_idx]; - data_t* cache_gpu_ptr = const_cast(cache_gpu.data()); - auto* cache_cpu_ptr = reinterpret_cast(cache_cpu_pointer); - - for (int block_idx = 0; block_idx < gpu_block_ids.size(); block_idx++) { - auto cur_gpu_block_id = gpu_block_ids[block_idx]; - auto cur_cpu_block_id = cpu_block_ids[block_idx]; - auto* cache_gpu_ptr_now = - cache_gpu_ptr + cur_gpu_block_id * cache_block_stride; - auto* cache_cpu_ptr_now = - cache_cpu_ptr + cur_cpu_block_id * cache_block_stride * layer_number + - layer_idx * cache_block_stride; - - cudaError_t status = cudaMemcpyAsync( - (copy_kind == cudaMemcpyDeviceToHost) ? cache_cpu_ptr_now - : cache_gpu_ptr_now, - (copy_kind == cudaMemcpyDeviceToHost) ? cache_gpu_ptr_now - : cache_cpu_ptr_now, - cache_block_stride * sizeof(DataType_), - copy_kind, - stream); + const size_t block_bytes = cache_block_stride * sizeof(DataType_); + const size_t total_bytes = (size_t)n_blocks * layer_number * block_bytes; + + bool use_optimized = is_cpu_block_ids_sequential(cpu_block_ids); + + if (use_optimized) { + ensure_device_block_ids(n_blocks * sizeof(int64_t)); + ensure_device_layer_ptrs(layer_number * sizeof(DataType_*)); + cudaMemcpyAsync(g_device_block_ids, + gpu_block_ids.data(), + n_blocks * sizeof(int64_t), + cudaMemcpyHostToDevice, + stream); + + std::vector h_layer_ptrs(layer_number); + for (int64_t i = 0; i < layer_number; i++) { + h_layer_ptrs[i] = reinterpret_cast( + const_cast(cache_gpu_tensors[i].data())); + } + cudaMemcpyAsync(g_device_layer_ptrs, + h_layer_ptrs.data(), + layer_number * sizeof(DataType_*), + cudaMemcpyHostToDevice, + stream); + + int64_t cpu_start_block = cpu_block_ids[0]; + auto* cache_cpu_base = reinterpret_cast(cache_cpu_pointer) + + cpu_start_block * layer_number * cache_block_stride; + + int grid_size = n_blocks * layer_number; + + if (mode == 0) { + // GPU→CPU: direct kernel write to pinned host memory + // Multi-layer kernel: each block handles all layers for one swap block + swap_d2h_kernel<<>>( + reinterpret_cast(g_device_layer_ptrs), + cache_cpu_base, + reinterpret_cast(g_device_block_ids), + n_blocks, + layer_number, + cache_block_stride); + } else { + // CPU→GPU: DMA memcpy to staging then scatter kernel + ensure_staging_buffer(total_bytes); + + cudaError_t status = cudaMemcpyAsync(g_staging_buffer, + cache_cpu_base, + total_bytes, + cudaMemcpyHostToDevice, + stream); PADDLE_ENFORCE_EQ(status, cudaSuccess, - phi::errors::External("cudaMemcpyAsync failed: %s", + phi::errors::External("cudaMemcpyAsync H2D failed: %s", cudaGetErrorString(status))); -#ifdef SWAP_DEBUG - cudaStreamSynchronize(stream); - std::cout << "mode:" << mode << ", layer_idx:" << layer_idx - << ", block_idx:" << block_idx << ", cache_cpu_ptr_now data:" - << static_cast(*cache_cpu_ptr_now) << std::endl; -#endif + scatter_blocks_kernel<<>>( + reinterpret_cast(g_device_layer_ptrs), + reinterpret_cast(g_staging_buffer), + reinterpret_cast(g_device_block_ids), + n_blocks, + layer_number, + cache_block_stride); + } + } else { + const cudaMemcpyKind copy_kind = + (mode == 0) ? cudaMemcpyDeviceToHost : cudaMemcpyHostToDevice; + for (int64_t layer_idx = 0; layer_idx < layer_number; layer_idx++) { + const paddle::Tensor& cache_gpu = cache_gpu_tensors[layer_idx]; + data_t* cache_gpu_ptr = const_cast(cache_gpu.data()); + auto* cache_cpu_ptr = reinterpret_cast(cache_cpu_pointer); + + for (int block_idx = 0; block_idx < n_blocks; block_idx++) { + auto cur_gpu_block_id = gpu_block_ids[block_idx]; + auto cur_cpu_block_id = cpu_block_ids[block_idx]; + auto* cache_gpu_ptr_now = + cache_gpu_ptr + cur_gpu_block_id * cache_block_stride; + auto* cache_cpu_ptr_now = + cache_cpu_ptr + + cur_cpu_block_id * cache_block_stride * layer_number + + layer_idx * cache_block_stride; + + cudaError_t status = cudaMemcpyAsync( + (copy_kind == cudaMemcpyDeviceToHost) ? cache_cpu_ptr_now + : cache_gpu_ptr_now, + (copy_kind == cudaMemcpyDeviceToHost) ? cache_gpu_ptr_now + : cache_cpu_ptr_now, + block_bytes, + copy_kind, + stream); + PADDLE_ENFORCE_EQ(status, + cudaSuccess, + phi::errors::External("cudaMemcpyAsync failed: %s", + cudaGetErrorString(status))); + } } } + cudaError_t sync_status = cudaStreamSynchronize(stream); PADDLE_ENFORCE_EQ(sync_status, cudaSuccess, @@ -90,15 +255,14 @@ void SwapCacheImpLayout( cudaGetErrorString(sync_status))); } -void SwapCacheLayout( - const std::vector& cache_gpu_tensors, // gpu - const int64_t& cache_cpu_ptrs, // cpu memory pointer - const std::vector& cache_shape, - const std::vector& gpu_block_ids, - const std::vector& cpu_block_ids, - int rank, - int mode) { - cudaSetDevice(rank); // used for distributed launch +void SwapCacheLayout(const std::vector& cache_gpu_tensors, + const int64_t& cache_cpu_ptrs, + const std::vector& cache_shape, + const std::vector& gpu_block_ids, + const std::vector& cpu_block_ids, + int rank, + int mode) { + cudaSetDevice(rank); assert(cache_gpu_tensors.size() > 0); switch (cache_gpu_tensors[0].dtype()) { case paddle::DataType::BFLOAT16: diff --git a/examples/cache_storage/run_03b_pd_storage.sh b/examples/cache_storage/run_03b_pd_storage.sh index 5577a0ebf27..996927acf1f 100644 --- a/examples/cache_storage/run_03b_pd_storage.sh +++ b/examples/cache_storage/run_03b_pd_storage.sh @@ -18,7 +18,7 @@ metadata_port=15002 export MOONCAKE_MASTER_SERVER_ADDR="${master_ip}:${master_port}" export MOONCAKE_METADATA_SERVER="http://${master_ip}:${metadata_port}/metadata" -export MOONCAKE_GLOBAL_SEGMENT_SIZE="50000000000" +export MOONCAKE_GLOBAL_SEGMENT_SIZE="200000000000" # export MOONCAKE_PROTOCOL="tcp" export MOONCAKE_PROTOCOL="rdma" # export MOONCAKE_RDMA_DEVICES="mlx5_0" diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index 08c8dea003a..44cec96699f 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -682,6 +682,24 @@ def prefill_layerwise_send_cache_thread(self): try: batch_engine_signals = self.cache_prefilled_engine_ids_queue.get() self.engine_worker_queue.begin_send_cache_barrier.wait() + + # Storage pool mode: skip RDMA/IPC transfer, immediately notify completion + if envs.FD_PD_TRANSFER_VIA_STORAGE: + with self.engine_cache_task_thread_lock: + for engine_idx, _ in batch_engine_signals: + self._maybe_wait_for_cache_task(engine_idx) + task = self.idx_cache_task_dict[engine_idx] + task["status"] = "finished" + logger.info( + f"[PD Storage] Skip RDMA transfer, mark as finished, " f"req_id: {task['request_id']}" + ) + self.engine_worker_queue.finish_send_cache_barrier.wait() + self.engine_worker_queue.put_finished_req([[task["request_id"], task["status"]]]) + self.engine_cache_tasks[task["current_id"]] = dict() + del self.cache_info[task["request_id"]] + del self.idx_cache_task_dict[task["current_id"]] + continue + block_start_end_list = [] current_prefilled_token_num_list = [] for engine_index, current_step_prefilled_token_num in batch_engine_signals: diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index c41a6109029..5560e331048 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -1257,6 +1257,7 @@ def write_cache_to_storage_decode(self, request: Request): # Incremental logic is handled by CacheTransferManager.write_back_storage_task() req_id = request.request_id logger.info(f"[D instance] start write cache to storage, req_id: {req_id}, block num: {len(keys)}") + trace_print(LoggingEventName.WRITE_CACHE_TO_STORAGE_START, request.request_id, getattr(request, "user", "")) write_storage_task = WriteStorageTask( task_id=req_id, @@ -1269,6 +1270,182 @@ def write_cache_to_storage_decode(self, request: Request): self.issue_write_back_storage_task(write_storage_task, is_sync=True) cost_time = time.time() - tic logger.info(f"[D instance] finish write cache to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s") + trace_print(LoggingEventName.WRITE_CACHE_TO_STORAGE_END, request.request_id, getattr(request, "user", "")) + + def write_all_cache_to_storage(self, request: Request, include_output=True): + """ + Write ALL token cache (including last incomplete block) to storage. + Used in PD storage-pool mode where P writes to storage instead of RDMA to D, + and D writes back all cache (including output tokens) on request completion. + + Unlike write_cache_to_storage_decode which skips incomplete blocks, this method + writes the last incomplete block by padding it to block_size in the storage key + computation (using a ':partial:N' suffix on the key). + + The actual GPU block is still full-sized, so swap_cache_layout works normally. + + Args: + request: The request object. + include_output: If True, include output_token_ids in the write (used by D). + If False, only write prompt_token_ids (used by P). + """ + if self.kvcache_storage_backend is None: + return + + # 1. Get complete token_ids + token_ids = request.prompt_token_ids + if isinstance(token_ids, np.ndarray): + token_ids = token_ids.tolist() + else: + token_ids = list(token_ids) + + input_token_ids = token_ids + request.output_token_ids if include_output else token_ids + + # 2. Calculate cache keys using chained hash, including last partial block + keys = [] + prefix_block_key = [] + block_size = self.config.cache_config.block_size + mm_idx = 0 + + for i in range(0, len(input_token_ids), block_size): + block_token_ids = input_token_ids[i : i + block_size] + actual_token_num = len(block_token_ids) + + if actual_token_num < block_size: + # Last incomplete block: compute key with actual tokens + partial marker + key = get_hash_str(block_token_ids, prefix_block_key) + key = f"{key}:partial:{actual_token_num}" + keys.append(key) + else: + # Full block: compute key normally (same as write_cache_to_storage_decode) + mm_idx, extra_keys = self.get_block_hash_extra_keys( + request=request, + start_idx=i, + end_idx=i + block_size, + mm_idx=mm_idx, + ) + prefix_block_key.extend(extra_keys) + key = get_hash_str(block_token_ids, prefix_block_key) + keys.append(key) + + prefix_block_key = [key] + + if not keys: + return + + # 3. Get corresponding gpu_block_ids + gpu_block_ids = request.block_tables[: len(keys)] + + # 4. Construct WriteStorageTask and send + req_id = request.request_id + logger.info( + f"[PD Storage] start write all cache to storage, req_id: {req_id}, " + f"block num: {len(keys)}, total_tokens: {len(input_token_ids)}" + ) + trace_print(LoggingEventName.WRITE_CACHE_TO_STORAGE_START, request.request_id, getattr(request, "user", "")) + + write_storage_task = WriteStorageTask( + task_id=req_id, + keys=keys, + token_ids=input_token_ids if self.kvcache_storage_backend == "attention_store" else None, + gpu_block_ids=gpu_block_ids, + ) + + tic = time.time() + self.issue_write_back_storage_task(write_storage_task, is_sync=True) + cost_time = time.time() - tic + logger.info(f"[PD Storage] finish write all cache to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s") + trace_print(LoggingEventName.WRITE_CACHE_TO_STORAGE_END, request.request_id, getattr(request, "user", "")) + + def read_cache_from_storage_for_pd(self, request: Request): + """ + PD storage-pool mode: D instance reads cache from storage that P wrote. + + This is different from request_match_blocks() storage read: + - Called on D instance after receiving first_token notification from P + - Reads ALL blocks (including last partial block) that P wrote to storage + - Target gpu_block_ids are D's pre-allocated blocks + + Returns: + list: gpu_block_ids if all blocks fetched successfully, + empty list if any block failed to fetch (caller should abort this request). + """ + if self.kvcache_storage_backend is None: + return [] + + # 1. Get token_ids (same as what P prefilled) + token_ids = request.prompt_token_ids + if isinstance(token_ids, np.ndarray): + token_ids = token_ids.tolist() + else: + token_ids = list(token_ids) + + input_token_ids = token_ids + + # 2. Calculate cache keys using same algorithm as write_all_cache_to_storage + keys = [] + prefix_block_key = [] + block_size = self.config.cache_config.block_size + mm_idx = 0 + + for i in range(0, len(input_token_ids), block_size): + block_token_ids = input_token_ids[i : i + block_size] + actual_token_num = len(block_token_ids) + + if actual_token_num < block_size: + key = get_hash_str(block_token_ids, prefix_block_key) + key = f"{key}:partial:{actual_token_num}" + keys.append(key) + else: + mm_idx, extra_keys = self.get_block_hash_extra_keys( + request=request, + start_idx=i, + end_idx=i + block_size, + mm_idx=mm_idx, + ) + prefix_block_key.extend(extra_keys) + key = get_hash_str(block_token_ids, prefix_block_key) + keys.append(key) + + prefix_block_key = [key] + + if not keys: + return [] + + # 3. gpu_block_ids = D's pre-allocated block_tables + gpu_block_ids = request.block_tables[: len(keys)] + + # 4. Issue ReadStorageTask + req_id = request.request_id + logger.info( + f"[PD Storage] D start read cache from storage, req_id: {req_id}, " + f"block num: {len(keys)}, total_tokens: {len(input_token_ids)}" + ) + + read_task = ReadStorageTask( + task_id=req_id, + keys=keys, + token_ids=input_token_ids if self.kvcache_storage_backend == "attention_store" else None, + gpu_block_ids=gpu_block_ids, + start_read_block_idx=0, + ) + + tic = time.time() + storage_block_ids = self.issue_prefetch_storage_task(read_task, is_sync=True) + cost_time = time.time() - tic + + if len(storage_block_ids) != len(keys): + logger.error( + f"[PD Storage] D failed to read all blocks from storage, req_id: {req_id}, " + f"matched blocks: {len(storage_block_ids)}/{len(keys)}, cost_time: {cost_time:.6f}s" + ) + return [] + else: + logger.info( + f"[PD Storage] D finish reading the cache of all blocks from storage, req_id: {req_id}, " + f"matched blocks: {len(storage_block_ids)}/{len(keys)}, cost_time: {cost_time:.6f}s" + ) + return storage_block_ids def issue_write_back_storage_task(self, task: WriteStorageTask, is_sync=True): if self.kvcache_storage_backend is None: diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index f553a4f8ee5..ec995364a3e 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -1861,6 +1861,32 @@ def _process_prefilled_requests(): self.token_processor.tokens_counter[request_id] = 1 if envs.FD_ENABLE_INTERNAL_ADAPTER: # first token sent by D instance self.scheduler.put_results([req_output]) + + # Storage pool mode: D reads cache from storage before adding to running queue + if envs.FD_PD_TRANSFER_VIA_STORAGE: + request = self.resource_manager.requests[request_id] + self.llm_logger.info(f"[PD Storage] D reading cache from storage, request_id: {request_id}") + storage_block_ids = self.resource_manager.cache_manager.read_cache_from_storage_for_pd(request) + if not storage_block_ids: + self.llm_logger.error( + f"[PD Storage] D failed to read cache from storage, " f"request_id: {request_id}" + ) + self.resource_manager.pre_recycle_resource(request_id) + if request_id in self.token_processor.tokens_counter: + del self.token_processor.tokens_counter[request_id] + req_output.error_code = 502 + req_output.error_msg = ( + f"PD Storage Error: D failed to read all blocks from storage, " + f"request_id: {request_id}" + ) + req_output.finished = True + self.scheduler.put_results([req_output]) + continue + self.llm_logger.info( + f"[PD Storage] D successfully read cache from storage, " + f"request_id: {request_id}, blocks: {len(storage_block_ids)}" + ) + self.resource_manager.add_prefilled_request(req_output) self.llm_logger.info(f"D has successfully added prefilled request, {request_id}") diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index de89ab3adca..46389ccc742 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -869,6 +869,7 @@ def get_enough_request(request, scheduled_reqs): # First, schedule the RUNNING requests. req_index = 0 num_decoding_req_nums = 0 + while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] need_block_num = self.need_block_num_signal.value[request.idx] @@ -1673,7 +1674,12 @@ def finish_requests(self, request_ids: Union[str, Iterable[str]]): # Do not block the main thread here # Write cache to storage if kvcache_storage_backend is enabled for req in need_postprocess_reqs: - if self.config.scheduler_config.splitwise_role == "decode": + if envs.FD_PD_TRANSFER_VIA_STORAGE: + # Storage pool mode: P already writes cache in token_processor before notifying D, + # only D needs to write here (including output tokens generated during decode) + if self.config.scheduler_config.splitwise_role == "decode": + self.cache_manager.write_all_cache_to_storage(req) + elif self.config.scheduler_config.splitwise_role == "decode": # D instance uses simplified write method (does not rely on Radix Tree) self.cache_manager.write_cache_to_storage_decode(req) else: diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 509f9a768d9..6ec9aecb777 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -252,6 +252,10 @@ def _validate_split_kv_size(value: int) -> int: "FD_DETERMINISTIC_LOG_MODE": lambda: bool(int(os.getenv("FD_DETERMINISTIC_LOG_MODE", "0"))), # Whether to use PD REORDER, can set 0 or 1 "FD_PD_REORDER": lambda: int(os.getenv("FD_PD_REORDER", "0")), + # PD disaggregation cache transfer mode: + # 0 (default): Direct transfer mode, P writes cache to D's GPU via RDMA/IPC + # 1: Storage pool mode, P writes cache to global storage pool, D reads from storage pool + "FD_PD_TRANSFER_VIA_STORAGE": lambda: int(os.getenv("FD_PD_TRANSFER_VIA_STORAGE", "0")), # Whether to enable KV cache lock, enforcing mutual exclusion between # PrefixCacheManager and Worker when accessing GPU KV cache. # Under certain DP+EP configurations, concurrent access (even read-only) diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 2a8328b28fe..a29a4d9e2fc 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -653,6 +653,19 @@ def _recycle_resources(self, task_id, index, task, result=None, is_prefill=False f"wait for sending cache, request_id: {task_id}, cost seconds: {time.time()-start_time:.5f}" ) trace_print(LoggingEventName.CHECK_CACHE_TRANSFER_END, task_id, getattr(task, "user", "")) + + # Storage pool mode: write all cache to storage before sending first token to D + if envs.FD_PD_TRANSFER_VIA_STORAGE and result.error_code == 200: + llm_logger.info( + f"[PD Storage] P writing cache to storage before send_first_token, " + f"request_id: {task_id}" + ) + self.resource_manager.cache_manager.write_all_cache_to_storage(task, include_output=False) + llm_logger.info( + f"[PD Storage] P finished writing cache to storage, " + f"request_id: {task_id}, cost: {time.time()-start_time:.5f}s" + ) + result.metrics.send_request_output_to_decode_time = time.time() self.split_connector.send_first_token(task.disaggregate_info, [result]) if envs.ENABLE_V1_KVCACHE_SCHEDULER: From 8104540dd1cf855eaafd921984b8dd3aa1006536 Mon Sep 17 00:00:00 2001 From: juncaipeng <13006307475@163.com> Date: Sun, 17 May 2026 05:51:42 +0000 Subject: [PATCH 2/6] skip messager --- examples/cache_storage/run_03b_pd_storage.sh | 2 +- fastdeploy/cache_manager/cache_messager.py | 2 +- .../engine/common_engine_prepare_mixin.py | 4 +- .../layers/attention/append_attn_backend.py | 10 +- .../layers/attention/dsa_attention_backend.py | 11 +- .../layers/attention/flash_attn_backend.py | 10 +- .../attention/flash_mask_attn_backend.py | 11 +- .../layers/attention/mla_attention_backend.py | 15 ++- fastdeploy/output/token_processor.py | 106 ++++++++++-------- 9 files changed, 105 insertions(+), 66 deletions(-) diff --git a/examples/cache_storage/run_03b_pd_storage.sh b/examples/cache_storage/run_03b_pd_storage.sh index 996927acf1f..c940fe9a8ef 100644 --- a/examples/cache_storage/run_03b_pd_storage.sh +++ b/examples/cache_storage/run_03b_pd_storage.sh @@ -18,7 +18,7 @@ metadata_port=15002 export MOONCAKE_MASTER_SERVER_ADDR="${master_ip}:${master_port}" export MOONCAKE_METADATA_SERVER="http://${master_ip}:${metadata_port}/metadata" -export MOONCAKE_GLOBAL_SEGMENT_SIZE="200000000000" +export MOONCAKE_GLOBAL_SEGMENT_SIZE="50000000000" # 50GB # export MOONCAKE_PROTOCOL="tcp" export MOONCAKE_PROTOCOL="rdma" # export MOONCAKE_RDMA_DEVICES="mlx5_0" diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index 44cec96699f..cf74c2ec36b 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -693,11 +693,11 @@ def prefill_layerwise_send_cache_thread(self): logger.info( f"[PD Storage] Skip RDMA transfer, mark as finished, " f"req_id: {task['request_id']}" ) - self.engine_worker_queue.finish_send_cache_barrier.wait() self.engine_worker_queue.put_finished_req([[task["request_id"], task["status"]]]) self.engine_cache_tasks[task["current_id"]] = dict() del self.cache_info[task["request_id"]] del self.idx_cache_task_dict[task["current_id"]] + self.engine_worker_queue.finish_send_cache_barrier.wait() continue block_start_end_list = [] diff --git a/fastdeploy/engine/common_engine_prepare_mixin.py b/fastdeploy/engine/common_engine_prepare_mixin.py index 71327025458..92dffe4778d 100644 --- a/fastdeploy/engine/common_engine_prepare_mixin.py +++ b/fastdeploy/engine/common_engine_prepare_mixin.py @@ -226,8 +226,8 @@ def _fetch_request_prefill(self) -> bool: tasks.remove(tmp_task) self.resource_manager.pre_recycle_resource(tmp_task.request_id) - # Send cache info to messager - if tasks: + # Send cache info to messager (skip in storage pool mode - messager is bypassed) + if tasks and not envs.FD_PD_TRANSFER_VIA_STORAGE: self.split_connector.send_cache_info_to_messager(tasks, 0) # Fetch requests and add them to the scheduling queue diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index 15b657c249d..ad3dc59c28b 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -210,7 +210,11 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): # pd_disaggregation metadata.kv_signal_data_list = [None] * self.num_layers if self.pd_disaggregation_mode == "per_chunk": - if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run: + if ( + not self.keep_pd_step_flag + and not forward_meta.is_dummy_or_profile_run + and not envs.FD_PD_TRANSFER_VIA_STORAGE + ): init_kv_signal_per_query( forward_meta.seq_lens_encoder, forward_meta.seq_lens_this_time, @@ -218,7 +222,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): self.rank, self.num_layers + self.num_layers_draft_model, ) - elif self.pd_disaggregation_mode == "per_query": + elif self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_metadata = open_shm_and_get_meta_signal( self.rank, int(self.device_id), self.keep_pd_step_flag ) @@ -306,7 +310,7 @@ def forward_mixed( # 64 is gpt-oss assert forward_meta.rotary_embs.shape[4] in [128, 32, 64] - if self.pd_disaggregation_mode == "per_query": + if self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise( metadata.kv_signal_metadata, layer.layer_id + self.start_layer_index, diff --git a/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py b/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py index acb73f5420a..88f28467569 100644 --- a/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py @@ -28,6 +28,7 @@ if current_platform.is_cuda(): paddle.enable_compat(scope={"flash_mla"}) +from fastdeploy import envs from fastdeploy.model_executor.layers.attention.ops import ( get_block_shape_and_split_kv_block, init_kv_signal_per_query, @@ -243,7 +244,11 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): # pd_disaggregation metadata.kv_signal_data_list = [None] * self.num_layers if self.pd_disaggregation_mode == "per_chunk": - if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run: + if ( + not self.keep_pd_step_flag + and not forward_meta.is_dummy_or_profile_run + and not envs.FD_PD_TRANSFER_VIA_STORAGE + ): init_kv_signal_per_query( forward_meta.seq_lens_encoder, forward_meta.seq_lens_this_time, @@ -251,7 +256,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): self.rank, self.num_layers + self.num_layers_draft_model, ) - elif self.pd_disaggregation_mode == "per_query": + elif self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_metadata = open_shm_and_get_meta_signal( self.rank, int(self.device_id), self.keep_pd_step_flag ) @@ -304,7 +309,7 @@ def forward_mixed( # speculate_decoder = self.speculative_method is not None # speculate_max_tokens = self.speculate_max_draft_token_num - if self.pd_disaggregation_mode == "per_query": + if self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise( metadata.kv_signal_metadata, layer.layer_id + self.start_layer_index, diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index b203fdbb221..f3935505fa8 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -299,7 +299,11 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): # pd_disaggregation metadata.kv_signal_data_list = [None] * self.num_layers if self.pd_disaggregation_mode == "per_chunk": - if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run: + if ( + not self.keep_pd_step_flag + and not forward_meta.is_dummy_or_profile_run + and not envs.FD_PD_TRANSFER_VIA_STORAGE + ): init_kv_signal_per_query( forward_meta.seq_lens_encoder, forward_meta.seq_lens_this_time, @@ -307,7 +311,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): self.rank, self.num_layers + self.num_layers_draft_model, ) - elif self.pd_disaggregation_mode == "per_query": + elif self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_metadata = open_shm_and_get_meta_signal( self.rank, int(self.device_id), self.keep_pd_step_flag ) @@ -334,7 +338,7 @@ def forward_mixed( ): metadata = self.attention_metadata - if self.pd_disaggregation_mode == "per_query": + if self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise( metadata.kv_signal_metadata, layer.layer_id + self.start_layer_index, diff --git a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py index 5b3c5ecdd3a..4c663c1a702 100644 --- a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py @@ -22,6 +22,7 @@ import paddle +from fastdeploy import envs from fastdeploy.config import FDConfig from fastdeploy.model_executor.layers.attention.attention import Attention from fastdeploy.model_executor.layers.attention.base_attention_backend import ( @@ -146,7 +147,11 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): # metadata only save pd_disaggregation info. metadata.kv_signal_data_list = [None] * self.num_layers if self.pd_disaggregation_mode == "per_chunk": - if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run: + if ( + not self.keep_pd_step_flag + and not forward_meta.is_dummy_or_profile_run + and not envs.FD_PD_TRANSFER_VIA_STORAGE + ): init_kv_signal_per_query( forward_meta.seq_lens_encoder, forward_meta.seq_lens_this_time, @@ -154,7 +159,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): self.rank, self.num_layers + self.num_layers_draft_model, ) - elif self.pd_disaggregation_mode == "per_query": + elif self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_metadata = open_shm_and_get_meta_signal( self.rank, int(self.device_id), self.keep_pd_step_flag ) @@ -197,7 +202,7 @@ def forward_mixed( cache_k_scales = getattr(layer, "cache_k_scale", None) cache_v_scales = getattr(layer, "cache_v_scale", None) - if self.pd_disaggregation_mode == "per_query": + if self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise( metadata.kv_signal_metadata, layer.layer_id + self.start_layer_index, diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py index 932c13decc3..c2a0c9e6a2d 100644 --- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -34,6 +34,7 @@ logger.debug(f"flash_attention_v3_varlen not available: {e}") flash_attention_v3_varlen = None +from fastdeploy import envs from fastdeploy.model_executor.layers.attention.ops import ( get_block_shape_and_split_kv_block, init_kv_signal_per_query, @@ -358,7 +359,11 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): # pd_disaggregation metadata.kv_signal_data_list = [None] * self.num_layers if self.pd_disaggregation_mode == "per_chunk": - if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run: + if ( + not self.keep_pd_step_flag + and not forward_meta.is_dummy_or_profile_run + and not envs.FD_PD_TRANSFER_VIA_STORAGE + ): init_kv_signal_per_query( forward_meta.seq_lens_encoder, forward_meta.seq_lens_this_time, @@ -366,7 +371,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): self.rank, self.num_layers + self.num_layers_draft_model, ) - elif self.pd_disaggregation_mode == "per_query": + elif self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_metadata = open_shm_and_get_meta_signal( self.rank, int(self.device_id), self.keep_pd_step_flag ) @@ -405,7 +410,7 @@ def forward_extend( """ metadata = self.attention_metadata - if self.pd_disaggregation_mode == "per_query": + if self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise( metadata.kv_signal_metadata, layer.layer_id + self.start_layer_index, @@ -459,7 +464,7 @@ def forward_decode( """ metadata = self.attention_metadata - if self.pd_disaggregation_mode == "per_query": + if self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise( metadata.kv_signal_metadata, layer.layer_id + self.start_layer_index, @@ -549,7 +554,7 @@ def forward_mixed( speculate_decoder = self.speculative_method is not None speculate_max_tokens = self.speculate_max_draft_token_num - if self.pd_disaggregation_mode == "per_query": + if self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise( metadata.kv_signal_metadata, layer.layer_id + self.start_layer_index, diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index a29a4d9e2fc..c545492c1ee 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -631,57 +631,73 @@ def _recycle_resources(self, task_id, index, task, result=None, is_prefill=False recycle resources """ if is_prefill: - start_time = time.time() - result.metrics.wait_for_sending_cache_time = time.time() - trace_print(LoggingEventName.CHECK_CACHE_TRANSFER_START, task_id, getattr(task, "user", "")) - - while True: - finished_task_ids = self.engine_worker_queue.get_finished_req() - if len(finished_task_ids) > 0: - for finished_task_id in finished_task_ids: - llm_logger.info(f"finished_task_id: {finished_task_id}") - self.prefill_result_status[finished_task_id[0]] = finished_task_id[1] - if task_id in self.prefill_result_status: - if self.prefill_result_status[task_id] != "finished": - result.error_code = 501 - result.error_msg = ( - f"PD Error: prefill failed to send cache to decode, " - f"{task_id}, {self.prefill_result_status[task_id]}" - ) - self.prefill_result_status.pop(task_id) + if envs.FD_PD_TRANSFER_VIA_STORAGE: + # Storage pool mode: bypass CacheMessager entirely. + # At this point, all transformer layers are complete and KV cache is in GPU memory. + # Directly write cache to storage and send first token to D. + result.metrics.wait_for_sending_cache_time = time.time() + trace_print(LoggingEventName.CHECK_CACHE_TRANSFER_START, task_id, getattr(task, "user", "")) + if result.error_code == 200: + write_cache_start_time = time.time() + llm_logger.info(f"[PD Storage] P writing cache to storage (direct), request_id: {task_id}") + self.resource_manager.cache_manager.write_all_cache_to_storage(task, include_output=False) llm_logger.info( - f"wait for sending cache, request_id: {task_id}, cost seconds: {time.time()-start_time:.5f}" + f"[PD Storage] P finished writing cache to storage (direct), " + f"request_id: {task_id}, cost: {time.time()-write_cache_start_time:.5f}s" ) - trace_print(LoggingEventName.CHECK_CACHE_TRANSFER_END, task_id, getattr(task, "user", "")) - - # Storage pool mode: write all cache to storage before sending first token to D - if envs.FD_PD_TRANSFER_VIA_STORAGE and result.error_code == 200: - llm_logger.info( - f"[PD Storage] P writing cache to storage before send_first_token, " - f"request_id: {task_id}" - ) - self.resource_manager.cache_manager.write_all_cache_to_storage(task, include_output=False) + trace_print(LoggingEventName.CHECK_CACHE_TRANSFER_END, task_id, getattr(task, "user", "")) + result.metrics.send_request_output_to_decode_time = time.time() + self.split_connector.send_first_token(task.disaggregate_info, [result]) + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + self.resource_manager.finish_requests_async(task_id) + else: + self.resource_manager.stop_flags[index] = True + self.resource_manager.tasks_list[index] = None + self.resource_manager._recycle_block_tables(task) + if task_id in self.resource_manager.req_dict: + del self.resource_manager.req_dict[task_id] + else: + # RDMA/IPC mode: poll CacheMessager for transfer completion + start_time = time.time() + result.metrics.wait_for_sending_cache_time = time.time() + trace_print(LoggingEventName.CHECK_CACHE_TRANSFER_START, task_id, getattr(task, "user", "")) + + while True: + finished_task_ids = self.engine_worker_queue.get_finished_req() + if len(finished_task_ids) > 0: + for finished_task_id in finished_task_ids: + llm_logger.info(f"finished_task_id: {finished_task_id}") + self.prefill_result_status[finished_task_id[0]] = finished_task_id[1] + if task_id in self.prefill_result_status: + if self.prefill_result_status[task_id] != "finished": + result.error_code = 501 + result.error_msg = ( + f"PD Error: prefill failed to send cache to decode, " + f"{task_id}, {self.prefill_result_status[task_id]}" + ) + self.prefill_result_status.pop(task_id) llm_logger.info( - f"[PD Storage] P finished writing cache to storage, " - f"request_id: {task_id}, cost: {time.time()-start_time:.5f}s" + f"wait for sending cache, request_id: {task_id}, " + f"cost seconds: {time.time()-start_time:.5f}" ) + trace_print(LoggingEventName.CHECK_CACHE_TRANSFER_END, task_id, getattr(task, "user", "")) - result.metrics.send_request_output_to_decode_time = time.time() - self.split_connector.send_first_token(task.disaggregate_info, [result]) - if envs.ENABLE_V1_KVCACHE_SCHEDULER: - self.resource_manager.finish_requests_async(task_id) + result.metrics.send_request_output_to_decode_time = time.time() + self.split_connector.send_first_token(task.disaggregate_info, [result]) + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + self.resource_manager.finish_requests_async(task_id) + else: + self.resource_manager.stop_flags[index] = True + self.resource_manager.tasks_list[index] = None + self.resource_manager._recycle_block_tables(task) + if task_id in self.resource_manager.req_dict: + del self.resource_manager.req_dict[task_id] + break else: - self.resource_manager.stop_flags[index] = True - self.resource_manager.tasks_list[index] = None - self.resource_manager._recycle_block_tables(task) - if task_id in self.resource_manager.req_dict: - del self.resource_manager.req_dict[task_id] - break - else: - # TODO: Refine checking sending cache and do not keep waiting - if time.time() - start_time > 30: - llm_logger.warning(f"wait for sending cache, {task_id}") - time.sleep(0.002) + # TODO: Refine checking sending cache and do not keep waiting + if time.time() - start_time > 30: + llm_logger.warning(f"wait for sending cache, {task_id}") + time.sleep(0.002) else: if envs.ENABLE_V1_KVCACHE_SCHEDULER: self.resource_manager.finish_requests_async(task_id) From c2ea4e2c400636c7518e7e937b4fb6c5259ac2bc Mon Sep 17 00:00:00 2001 From: juncaipeng <13006307475@163.com> Date: Mon, 18 May 2026 02:51:41 +0000 Subject: [PATCH 3/6] up --- fastdeploy/cache_manager/cache_messager.py | 17 ----------------- fastdeploy/engine/args_utils.py | 4 ++++ fastdeploy/splitwise/splitwise_connector.py | 2 +- 3 files changed, 5 insertions(+), 18 deletions(-) diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index cf74c2ec36b..e0d625cf7a1 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -683,23 +683,6 @@ def prefill_layerwise_send_cache_thread(self): batch_engine_signals = self.cache_prefilled_engine_ids_queue.get() self.engine_worker_queue.begin_send_cache_barrier.wait() - # Storage pool mode: skip RDMA/IPC transfer, immediately notify completion - if envs.FD_PD_TRANSFER_VIA_STORAGE: - with self.engine_cache_task_thread_lock: - for engine_idx, _ in batch_engine_signals: - self._maybe_wait_for_cache_task(engine_idx) - task = self.idx_cache_task_dict[engine_idx] - task["status"] = "finished" - logger.info( - f"[PD Storage] Skip RDMA transfer, mark as finished, " f"req_id: {task['request_id']}" - ) - self.engine_worker_queue.put_finished_req([[task["request_id"], task["status"]]]) - self.engine_cache_tasks[task["current_id"]] = dict() - del self.cache_info[task["request_id"]] - del self.idx_cache_task_dict[task["current_id"]] - self.engine_worker_queue.finish_send_cache_barrier.wait() - continue - block_start_end_list = [] current_prefilled_token_num_list = [] for engine_index, current_step_prefilled_token_num in batch_engine_signals: diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 892d6668859..24607fa5902 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -637,6 +637,10 @@ def __post_init__(self): "kvcache_storage_backend is only supported when ENABLE_V1_KVCACHE_SCHEDULER=1" ) + if envs.FD_PD_TRANSFER_VIA_STORAGE: + if self.kvcache_storage_backend is None: + raise ValueError("Must set kvcache_storage_backend when FD_PD_TRANSFER_VIA_STORAGE=1") + valid_model_impls = ["auto", "fastdeploy", "paddleformers"] if self.model_impl not in valid_model_impls: raise NotImplementedError( diff --git a/fastdeploy/splitwise/splitwise_connector.py b/fastdeploy/splitwise/splitwise_connector.py index 77f75ee4de7..8ad78543bfa 100644 --- a/fastdeploy/splitwise/splitwise_connector.py +++ b/fastdeploy/splitwise/splitwise_connector.py @@ -389,7 +389,7 @@ def _process_message(self, frames: List[bytes]): self.current_request_ids[task["request_id"]] = current_status if self.enable_decode_cache_task: del self.current_request_ids[task["request_id"]] - if current_status == "finished": + if current_status == "finished" and not envs.FD_PD_TRANSFER_VIA_STORAGE: self.engine_worker_queue.put_cache_info(payload) except Exception as e: From 51affd23b8bca24a541eab0c2e53021d9ae742e6 Mon Sep 17 00:00:00 2001 From: juncaipeng <13006307475@163.com> Date: Mon, 18 May 2026 09:15:34 +0000 Subject: [PATCH 4/6] consider write cache error --- .../cache_manager/cache_transfer_manager.py | 68 +++++-- .../cache_manager/prefix_cache_manager.py | 187 +++++++++++------- fastdeploy/output/token_processor.py | 15 +- .../test_prefix_cache_manager.py | 1 + 4 files changed, 186 insertions(+), 85 deletions(-) diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index 36306ee5dc6..d3f26511372 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -1131,13 +1131,42 @@ def _run_write_back_storage( target_sizes.extend([self.scale_buffer_stride_bytes] * block_num * 2) start_time = time.time() - self.storage_backend.batch_set(keys=keys, target_locations=target_locations, target_sizes=target_sizes) + result = self.storage_backend.batch_set( + keys=keys, target_locations=target_locations, target_sizes=target_sizes + ) write_cost_time = time.time() - start_time + # Per-block success validation (same pattern as _run_read_storage) + # batch_set returns List[int]: 0 = success, negative = error + if k_scale_keys and v_scale_keys: + k_result = result[:block_num] + v_result = result[block_num : 2 * block_num] + k_scale_result = result[2 * block_num : 3 * block_num] + v_scale_result = result[3 * block_num :] + success_block_num = 0 + for k, v, ks, vs in zip(k_result, v_result, k_scale_result, v_scale_result): + if not (k == 0 and v == 0 and ks == 0 and vs == 0): + break + success_block_num += 1 + else: + k_result = result[:block_num] + v_result = result[block_num : 2 * block_num] + success_block_num = 0 + for k, v in zip(k_result, v_result): + if not (k == 0 and v == 0): + break + success_block_num += 1 + + if success_block_num < block_num: + logger.error( + f"_run_write_back_storage partial failure: " + f"{success_block_num}/{block_num} blocks written, task_id: {task_id}" + ) + logger.debug( f"_run_write_back_storage, swap_cost_time: {swap_cost_time:.6f}s, write_cost_time: {write_cost_time:.6f}s" ) - return block_num + return success_block_num elif self.storage_backend_type == "attention_store": key_cache = [] @@ -1222,14 +1251,13 @@ def write_back_storage_task(self, task: WriteStorageTask): if match_block_num >= len(k_cache_keys): logger.info(f"No uncached keys found for task {task.task_id}") - gpu_block_ids = [] else: try: k_cache_keys = k_cache_keys[match_block_num:] v_cache_keys = v_cache_keys[match_block_num:] k_scale_keys = k_scale_keys[match_block_num:] if k_scale_keys else None v_scale_keys = v_scale_keys[match_block_num:] if v_scale_keys else None - gpu_block_ids = gpu_block_ids[match_block_num:] + write_gpu_block_ids = gpu_block_ids[match_block_num:] cpu_block_ids = cpu_block_ids[match_block_num:] # TODO: support timeout with actual block count write_block_num = self._run_write_back_storage( @@ -1240,19 +1268,28 @@ def write_back_storage_task(self, task: WriteStorageTask): v_cache_keys, k_scale_keys, v_scale_keys, - gpu_block_ids, + write_gpu_block_ids, cpu_block_ids, task.timeout, ) logger.info( f"Successfully wrote {write_block_num} blocks to cache storage for task {task.task_id}" ) - # Write routing data to storage (shares dedup with KVCache) - remaining_keys = task.keys[match_block_num:] - self._write_routing_to_storage(remaining_keys, gpu_block_ids) + # Check for partial write failure + if write_block_num < len(write_gpu_block_ids): + logger.error( + f"Partial write failure for task {task.task_id}: " + f"{write_block_num}/{len(write_gpu_block_ids)} blocks written" + ) + # Report: match_block_num (already cached) + write_block_num (newly written) + gpu_block_ids = gpu_block_ids[: match_block_num + write_block_num] + # Write routing data to storage only for actually-written blocks + written_block_ids = write_gpu_block_ids[:write_block_num] + remaining_keys = task.keys[match_block_num : match_block_num + len(written_block_ids)] + self._write_routing_to_storage(remaining_keys, written_block_ids) except Exception as e: logger.error(f"Error in write back storage task: {e}, traceback:{traceback.format_exc()}") - gpu_block_ids = [] + gpu_block_ids = gpu_block_ids[:match_block_num] finally: try: if (self.rank == 0) and self.storage_backend_type == "attention_store": @@ -1265,14 +1302,19 @@ def write_back_storage_task(self, task: WriteStorageTask): result = (CacheStatus.GPU2STORAGE, task.task_id, task.keys, gpu_block_ids) self.cache_task_queue.swap_to_storage_barrier.wait() - if self.rank == 0: # 只有当rank为0时执行同步操作 - self.cache_task_queue.swap_to_storage_barrier.reset() - self.cache_task_queue.put_transfer_done_signal(result) # 发送传输完成信号 - logger.debug(f"write_back_storage_task: put_transfer_done_signal {result}") + self.cache_task_queue.put_transfer_done_signal(result) + logger.debug(f"write_back_storage_task: put_transfer_done_signal {result}") except Exception as e: logger.error( f"An error occurred in write_back_storage_task, " f"error: {e}, traceback:\n{traceback.format_exc()}" ) + # Prevent caller from blocking forever: send empty done signal + try: + result = (CacheStatus.GPU2STORAGE, task.task_id, task.keys, []) + self.cache_task_queue.swap_to_storage_barrier.wait() + self.cache_task_queue.put_transfer_done_signal(result) + except Exception as barrier_err: + logger.error(f"Failed to send failure signal for task {task.task_id}: {barrier_err}") def _do_swap_to_cpu_task( self, diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index 5560e331048..b21b4349172 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -97,6 +97,7 @@ def __init__( self.kvcache_storage_backend = self.cache_config.kvcache_storage_backend self.write_policy = self.cache_config.write_policy self.task_write_back_event = {} + self.storage_write_back_result = {} self.task_prefetch_event = {} self.storage_prefetch_block_ids = {} @@ -1186,9 +1187,15 @@ def write_cache_to_storage(self, request: Request): ) logger.debug(f"issue write storage task: {write_storage_task}") tic = time.time() - self.issue_write_back_storage_task(write_storage_task, is_sync=True) + success = self.issue_write_back_storage_task(write_storage_task, is_sync=True) cost_time = time.time() - tic - logger.info(f"finish write cache back to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s") + if not success: + logger.error( + f"write cache back to storage FAILED, req_id: {req_id}, " + f"block num: {len(keys)}, cost_time: {cost_time:.6f}s" + ) + else: + logger.info(f"finish write cache back to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s") trace_print(LoggingEventName.WRITE_CACHE_TO_STORAGE_END, request.request_id, getattr(request, "user", "")) def write_cache_to_storage_decode(self, request: Request): @@ -1267,41 +1274,30 @@ def write_cache_to_storage_decode(self, request: Request): ) tic = time.time() - self.issue_write_back_storage_task(write_storage_task, is_sync=True) + success = self.issue_write_back_storage_task(write_storage_task, is_sync=True) cost_time = time.time() - tic - logger.info(f"[D instance] finish write cache to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s") + if not success: + logger.error( + f"[D instance] write cache to storage FAILED, req_id: {req_id}, " + f"block num: {len(keys)}, cost_time: {cost_time:.6f}s" + ) + else: + logger.info(f"[D instance] finish write cache to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s") trace_print(LoggingEventName.WRITE_CACHE_TO_STORAGE_END, request.request_id, getattr(request, "user", "")) - def write_all_cache_to_storage(self, request: Request, include_output=True): + def _compute_pd_storage_keys(self, request: Request, input_token_ids: list): """ - Write ALL token cache (including last incomplete block) to storage. - Used in PD storage-pool mode where P writes to storage instead of RDMA to D, - and D writes back all cache (including output tokens) on request completion. - - Unlike write_cache_to_storage_decode which skips incomplete blocks, this method - writes the last incomplete block by padding it to block_size in the storage key - computation (using a ':partial:N' suffix on the key). - - The actual GPU block is still full-sized, so swap_cache_layout works normally. + Compute cache keys (including :partial:N suffix for last incomplete block) + for PD storage-pool mode. Used by both write_all_cache_to_storage (P/D) and + read_cache_from_storage_for_pd (D) to ensure consistent key computation. Args: - request: The request object. - include_output: If True, include output_token_ids in the write (used by D). - If False, only write prompt_token_ids (used by P). - """ - if self.kvcache_storage_backend is None: - return + request: The request object (needed for get_block_hash_extra_keys). + input_token_ids: The token IDs to compute keys for. - # 1. Get complete token_ids - token_ids = request.prompt_token_ids - if isinstance(token_ids, np.ndarray): - token_ids = token_ids.tolist() - else: - token_ids = list(token_ids) - - input_token_ids = token_ids + request.output_token_ids if include_output else token_ids - - # 2. Calculate cache keys using chained hash, including last partial block + Returns: + list: The computed hash keys for each block. + """ keys = [] prefix_block_key = [] block_size = self.config.cache_config.block_size @@ -1317,7 +1313,7 @@ def write_all_cache_to_storage(self, request: Request, include_output=True): key = f"{key}:partial:{actual_token_num}" keys.append(key) else: - # Full block: compute key normally (same as write_cache_to_storage_decode) + # Full block: compute key normally mm_idx, extra_keys = self.get_block_hash_extra_keys( request=request, start_idx=i, @@ -1330,8 +1326,45 @@ def write_all_cache_to_storage(self, request: Request, include_output=True): prefix_block_key = [key] + return keys + + def write_all_cache_to_storage(self, request: Request, include_output=True): + """ + Write ALL token cache (including last incomplete block) to storage. + Used in PD storage-pool mode where P writes to storage instead of RDMA to D, + and D writes back all cache (including output tokens) on request completion. + + Unlike write_cache_to_storage_decode which skips incomplete blocks, this method + writes the last incomplete block by padding it to block_size in the storage key + computation (using a ':partial:N' suffix on the key). + + The actual GPU block is still full-sized, so swap_cache_layout works normally. + + Args: + request: The request object. + include_output: If True, include output_token_ids in the write (used by D). + If False, only write prompt_token_ids (used by P). + + Returns: + bool: True if all blocks written successfully, False otherwise. + """ + if self.kvcache_storage_backend is None: + return True + + # 1. Get complete token_ids + token_ids = request.prompt_token_ids + if isinstance(token_ids, np.ndarray): + token_ids = token_ids.tolist() + else: + token_ids = list(token_ids) + + input_token_ids = token_ids + request.output_token_ids if include_output else token_ids + + # 2. Calculate cache keys using shared helper + keys = self._compute_pd_storage_keys(request, input_token_ids) + if not keys: - return + return True # 3. Get corresponding gpu_block_ids gpu_block_ids = request.block_tables[: len(keys)] @@ -1352,10 +1385,19 @@ def write_all_cache_to_storage(self, request: Request, include_output=True): ) tic = time.time() - self.issue_write_back_storage_task(write_storage_task, is_sync=True) + success = self.issue_write_back_storage_task(write_storage_task, is_sync=True) cost_time = time.time() - tic - logger.info(f"[PD Storage] finish write all cache to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s") + if not success: + logger.error( + f"[PD Storage] write all cache to storage FAILED, req_id: {req_id}, " + f"block num: {len(keys)}, cost_time: {cost_time:.6f}s" + ) + else: + logger.info( + f"[PD Storage] finish write all cache to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s" + ) trace_print(LoggingEventName.WRITE_CACHE_TO_STORAGE_END, request.request_id, getattr(request, "user", "")) + return success def read_cache_from_storage_for_pd(self, request: Request): """ @@ -1379,35 +1421,10 @@ def read_cache_from_storage_for_pd(self, request: Request): token_ids = token_ids.tolist() else: token_ids = list(token_ids) - input_token_ids = token_ids - # 2. Calculate cache keys using same algorithm as write_all_cache_to_storage - keys = [] - prefix_block_key = [] - block_size = self.config.cache_config.block_size - mm_idx = 0 - - for i in range(0, len(input_token_ids), block_size): - block_token_ids = input_token_ids[i : i + block_size] - actual_token_num = len(block_token_ids) - - if actual_token_num < block_size: - key = get_hash_str(block_token_ids, prefix_block_key) - key = f"{key}:partial:{actual_token_num}" - keys.append(key) - else: - mm_idx, extra_keys = self.get_block_hash_extra_keys( - request=request, - start_idx=i, - end_idx=i + block_size, - mm_idx=mm_idx, - ) - prefix_block_key.extend(extra_keys) - key = get_hash_str(block_token_ids, prefix_block_key) - keys.append(key) - - prefix_block_key = [key] + # 2. Calculate cache keys using shared helper (same algorithm as write_all_cache_to_storage) + keys = self._compute_pd_storage_keys(request, token_ids) if not keys: return [] @@ -1448,8 +1465,12 @@ def read_cache_from_storage_for_pd(self, request: Request): return storage_block_ids def issue_write_back_storage_task(self, task: WriteStorageTask, is_sync=True): + """ + Issue a write-back storage task. + Returns True if all blocks written successfully (sync mode), True always (async mode). + """ if self.kvcache_storage_backend is None: - return + return True if not envs.FD_AS_ONLY_FLUSH and len(task.keys) != len(task.gpu_block_ids): err_msg = ( @@ -1462,15 +1483,37 @@ def issue_write_back_storage_task(self, task: WriteStorageTask, is_sync=True): self.task_write_back_event[task.task_id] = Event() self.cache_task_queue.put_transfer_task((CacheStatus.GPU2STORAGE, task)) if is_sync: - self.wait_write_storage_task(task.task_id) + return self.wait_write_storage_task(task.task_id, expected_block_num=len(task.gpu_block_ids)) + return True - def wait_write_storage_task(self, req_id): + def wait_write_storage_task(self, req_id, expected_block_num=0, timeout=60.0): """ - Sync write back task + Sync write back task. + Returns True if all expected blocks written successfully across all TP ranks. + + Args: + req_id: request ID + expected_block_num: number of blocks expected to be written + timeout: max wait time in seconds """ if req_id in self.task_write_back_event: - self.task_write_back_event[req_id].wait() + success = self.task_write_back_event[req_id].wait(timeout=timeout) del self.task_write_back_event[req_id] + if not success: + logger.error(f"[PD Storage] write storage task timeout after {timeout}s, req_id: {req_id}") + self.storage_write_back_result.pop(req_id, None) + return False + # Check actual written block count vs expected + written_block_ids = self.storage_write_back_result.pop(req_id, []) + actual_written = len(written_block_ids) + if expected_block_num > 0 and actual_written < expected_block_num: + logger.error( + f"[PD Storage] write storage incomplete: {actual_written}/{expected_block_num} blocks, " + f"req_id: {req_id}" + ) + return False + return True + return True def issue_prefetch_storage_task(self, task: ReadStorageTask, is_sync=True): """ @@ -2403,8 +2446,16 @@ def recv_data_transfer_result(self): elif event_type.value == CacheStatus.GPU2STORAGE.value: logger.debug(f"recv_data_transfer_result: {data}") task_id, hash_keys, block_ids = data[1:] - if task_id in self.task_write_back_event: - self.task_write_back_event[task_id].set() + # Collect results from all TP ranks (same pattern as STORAGE2GPU path) + if task_id not in self.storage_write_back_result: + self.storage_write_back_result[task_id] = [] + saved_results = self.storage_write_back_result[task_id] + saved_results.append(block_ids) + if len(saved_results) == self.tensor_parallel_size: + # Take minimum across all ranks (conservative, same as read path) + self.storage_write_back_result[task_id] = min(saved_results, key=len) + if task_id in self.task_write_back_event: + self.task_write_back_event[task_id].set() else: ( event_type, diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index c545492c1ee..da0a3fc2b55 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -640,11 +640,18 @@ def _recycle_resources(self, task_id, index, task, result=None, is_prefill=False if result.error_code == 200: write_cache_start_time = time.time() llm_logger.info(f"[PD Storage] P writing cache to storage (direct), request_id: {task_id}") - self.resource_manager.cache_manager.write_all_cache_to_storage(task, include_output=False) - llm_logger.info( - f"[PD Storage] P finished writing cache to storage (direct), " - f"request_id: {task_id}, cost: {time.time()-write_cache_start_time:.5f}s" + write_success = self.resource_manager.cache_manager.write_all_cache_to_storage( + task, include_output=False ) + if not write_success: + result.error_code = 501 + result.error_msg = f"P instance failed to write cache to storage for request {task_id}" + llm_logger.error(f"[PD Storage] {result.error_msg}") + else: + llm_logger.info( + f"[PD Storage] P finished writing cache to storage (direct), " + f"request_id: {task_id}, cost: {time.time()-write_cache_start_time:.5f}s" + ) trace_print(LoggingEventName.CHECK_CACHE_TRANSFER_END, task_id, getattr(task, "user", "")) result.metrics.send_request_output_to_decode_time = time.time() self.split_connector.send_first_token(task.disaggregate_info, [result]) diff --git a/tests/cache_manager/test_prefix_cache_manager.py b/tests/cache_manager/test_prefix_cache_manager.py index f2a4a5fa116..8dd9b5162c4 100644 --- a/tests/cache_manager/test_prefix_cache_manager.py +++ b/tests/cache_manager/test_prefix_cache_manager.py @@ -1485,6 +1485,7 @@ def test_recv_data_transfer_result_handles_storage_events(self): (CacheStatus.STORAGE2GPU, "pref", ["h1"], [1, 2]), (CacheStatus.STORAGE2GPU, "pref", ["h2"], [1]), (CacheStatus.GPU2STORAGE, "write", ["h3"], [9]), + (CacheStatus.GPU2STORAGE, "write", ["h3"], [9]), ] manager.cache_task_queue = _FakeTransferQueue(payloads) with self.assertRaises(SystemExit): From eafc5799cb64981ae4d453d72424217b2970ad8f Mon Sep 17 00:00:00 2001 From: juncaipeng <13006307475@163.com> Date: Mon, 18 May 2026 14:29:47 +0000 Subject: [PATCH 5/6] fix ci --- tests/cache_manager/test_cache_transfer_manager.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/cache_manager/test_cache_transfer_manager.py b/tests/cache_manager/test_cache_transfer_manager.py index 599e0b8c5e0..5d2a054761e 100644 --- a/tests/cache_manager/test_cache_transfer_manager.py +++ b/tests/cache_manager/test_cache_transfer_manager.py @@ -449,7 +449,7 @@ def test_write_back_storage_task_skips_cached_keys(self): self.manager._run_write_back_storage.assert_not_called() self.manager.cache_task_queue.put_transfer_done_signal.assert_called_once_with( - (cache_transfer_manager.CacheStatus.GPU2STORAGE, "5", ["k1", "k2"], []) + (cache_transfer_manager.CacheStatus.GPU2STORAGE, "5", ["k1", "k2"], [0, 1]) ) def test_read_storage_task_no_matches(self): @@ -737,7 +737,7 @@ class LocalArgs(Args): def test_write_back_storage_task_nonzero_rank_no_signal(self): self.manager.cache_task_queue.swap_to_storage_barrier = MagicMock() self.manager.cache_task_queue.put_transfer_done_signal = MagicMock() - self.manager._run_write_back_storage = MagicMock() + self.manager._run_write_back_storage = MagicMock(return_value=1) self.manager.rank = 1 # Mock storage backend to return 0 matches (no keys exist) @@ -761,7 +761,10 @@ def test_write_back_storage_task_nonzero_rank_no_signal(self): [0], 0.1, ) - self.manager.cache_task_queue.put_transfer_done_signal.assert_not_called() + # After the refactor, the done signal is always sent regardless of rank. + self.manager.cache_task_queue.put_transfer_done_signal.assert_called_once_with( + (cache_transfer_manager.CacheStatus.GPU2STORAGE, "9", ["k1"], [0]) + ) def test_get_key_prefix_from_version(self): with patch("fastdeploy.cache_manager.cache_transfer_manager.yaml.safe_load") as mock_load: From 2eedbec3637f3b67cfaa855f98e2e65c641b059c Mon Sep 17 00:00:00 2001 From: juncaipeng <13006307475@163.com> Date: Tue, 19 May 2026 02:42:00 +0000 Subject: [PATCH 6/6] up --- custom_ops/gpu_ops/swap_cache_layout.cu | 55 ++++++++++++++++++------- 1 file changed, 40 insertions(+), 15 deletions(-) diff --git a/custom_ops/gpu_ops/swap_cache_layout.cu b/custom_ops/gpu_ops/swap_cache_layout.cu index 8fd497f2362..08f64197f9b 100644 --- a/custom_ops/gpu_ops/swap_cache_layout.cu +++ b/custom_ops/gpu_ops/swap_cache_layout.cu @@ -155,26 +155,51 @@ void SwapCacheImpLayout(const std::vector& cache_gpu_tensors, bool use_optimized = is_cpu_block_ids_sequential(cpu_block_ids); + // float4 vectorized kernels require block_bytes to be 16-byte aligned + // and cache_cpu_base to be 16-byte aligned for correct float4 access. + if (use_optimized && (block_bytes % sizeof(float4) != 0)) { + use_optimized = false; + } + if (use_optimized) { + int64_t cpu_start_block = cpu_block_ids[0]; + uintptr_t cpu_base_addr = + static_cast(cache_cpu_pointer) + + cpu_start_block * layer_number * cache_block_stride * sizeof(DataType_); + if (cpu_base_addr % sizeof(float4) != 0) { + use_optimized = false; + } + } + if (use_optimized) { ensure_device_block_ids(n_blocks * sizeof(int64_t)); ensure_device_layer_ptrs(layer_number * sizeof(DataType_*)); - cudaMemcpyAsync(g_device_block_ids, - gpu_block_ids.data(), - n_blocks * sizeof(int64_t), - cudaMemcpyHostToDevice, - stream); + cudaError_t status = cudaMemcpyAsync(g_device_block_ids, + gpu_block_ids.data(), + n_blocks * sizeof(int64_t), + cudaMemcpyHostToDevice, + stream); + PADDLE_ENFORCE_EQ( + status, + cudaSuccess, + phi::errors::External("cudaMemcpyAsync block_ids H2D failed: %s", + cudaGetErrorString(status))); std::vector h_layer_ptrs(layer_number); for (int64_t i = 0; i < layer_number; i++) { h_layer_ptrs[i] = reinterpret_cast( const_cast(cache_gpu_tensors[i].data())); } - cudaMemcpyAsync(g_device_layer_ptrs, - h_layer_ptrs.data(), - layer_number * sizeof(DataType_*), - cudaMemcpyHostToDevice, - stream); + status = cudaMemcpyAsync(g_device_layer_ptrs, + h_layer_ptrs.data(), + layer_number * sizeof(DataType_*), + cudaMemcpyHostToDevice, + stream); + PADDLE_ENFORCE_EQ( + status, + cudaSuccess, + phi::errors::External("cudaMemcpyAsync layer_ptrs H2D failed: %s", + cudaGetErrorString(status))); int64_t cpu_start_block = cpu_block_ids[0]; auto* cache_cpu_base = reinterpret_cast(cache_cpu_pointer) + @@ -196,11 +221,11 @@ void SwapCacheImpLayout(const std::vector& cache_gpu_tensors, // CPU→GPU: DMA memcpy to staging then scatter kernel ensure_staging_buffer(total_bytes); - cudaError_t status = cudaMemcpyAsync(g_staging_buffer, - cache_cpu_base, - total_bytes, - cudaMemcpyHostToDevice, - stream); + status = cudaMemcpyAsync(g_staging_buffer, + cache_cpu_base, + total_bytes, + cudaMemcpyHostToDevice, + stream); PADDLE_ENFORCE_EQ(status, cudaSuccess, phi::errors::External("cudaMemcpyAsync H2D failed: %s",