diff --git a/python/infinilm/llm/llm.py b/python/infinilm/llm/llm.py index 7293a2fe..d31b47cc 100644 --- a/python/infinilm/llm/llm.py +++ b/python/infinilm/llm/llm.py @@ -248,9 +248,14 @@ def _update_requests( sampled_tokens: List[int], ): """Update request status after inference step.""" - # Only reset req blocks for paged cache - if is_prefill and self.cache_type == "paged": - self.scheduler.cache_manager.reset_req_blocks() + if is_prefill: + match self.cache_type: + case "paged": + self.scheduler.cache_manager.reset_req_blocks() + case "static": + self.scheduler.update_cache() + case _: + raise ValueError(f"Unsupported cache_type: {self.cache_type}") for req, token_id in zip(requests, sampled_tokens): diff --git a/python/infinilm/llm/static_scheduler.py b/python/infinilm/llm/static_scheduler.py index e7336242..de4d9d35 100644 --- a/python/infinilm/llm/static_scheduler.py +++ b/python/infinilm/llm/static_scheduler.py @@ -7,6 +7,7 @@ import janus from typing import List, Optional +from infinilm.llm.cache_manager import BlockManager from infinilm.llm.request import ( RequestStatus, InferenceRequest, @@ -16,6 +17,8 @@ logger = logging.getLogger(__name__) +_BLOCK_SIZE = 16 + class StaticSchedulerOutput: """Static scheduler output containing single request and execution phase info.""" @@ -24,10 +27,12 @@ def __init__( self, scheduled_requests: List[InferenceRequest], is_prefill: bool = False, + prefix_hit_len: int = 0, ): self.scheduled_requests = scheduled_requests self.num_requests = len(scheduled_requests) self.is_prefill = is_prefill + self.prefix_hit_len = prefix_hit_len def build_model_inputs( self, temperature: float = 1.0, top_p: float = 0.8, top_k: int = 1 @@ -36,10 +41,10 @@ def build_model_inputs( Static cache model inputs: - Prefill phase: - - input_ids: All prompt tokens [1, prompt_length] - - position_ids: [0, 1, 2, ..., prompt_length-1] - - past_kv_lengths: [0] (no cached tokens initially) + Prefill phase (with prefix cache reuse): + - input_ids: Tokens 
after the cached prefix [1, prompt_length - prefix_hit_len] + - position_ids: [prefix_hit_len, ..., prompt_length-1] + - past_kv_lengths: [prefix_hit_len] (reuse cached prefix) - total_kv_lengths: [prompt_length] Decode phase: @@ -47,18 +52,19 @@ def build_model_inputs( - position_ids: [current_position] (position in full sequence) - past_kv_lengths: [num_cached_tokens] - total_kv_lengths: [total_tokens] - - """ req = self.scheduled_requests[0] if self.is_prefill: - # Prefill: send all prompt tokens + # Prefill: only send tokens not already in cache tokens = req.get_input_tokens() - input_ids = [tokens] - position_ids = [list(range(len(tokens)))] - past_kv_len = 0 + prefix_hit_len = self.prefix_hit_len + input_tokens = tokens[prefix_hit_len:] + input_ids = [input_tokens] + position_ids = [list(range(prefix_hit_len, len(tokens)))] + past_kv_len = prefix_hit_len total_kv_len = len(tokens) - input_offsets = [0, len(tokens)] + input_offsets = [0, len(input_tokens)] else: # Decode: send only the last generated token last_token = req.generated_token_ids[-1] @@ -91,12 +97,15 @@ class StaticScheduler: - Only handles one request at a time - No cache block management needed - Simple waiting queue for incoming requests + - Prefix cache reuse via chained block hashing (block size = _BLOCK_SIZE) """ def __init__(self, max_cache_len: int = 4096): self.waiting_queue = janus.Queue() self.running_request: Optional[InferenceRequest] = None self.max_cache_len = max_cache_len + self.cached_block_hashes: List[int] = [] + self.pending_block_hashes: List[int] = [] def add_request(self, request: InferenceRequest): if request is not None: @@ -138,6 +147,23 @@ def schedule(self) -> Optional[StaticSchedulerOutput]: ) continue + total_length = req.get_total_length() + if total_length % _BLOCK_SIZE == 1 and total_length > _BLOCK_SIZE: + block_index = total_length // _BLOCK_SIZE - 1 + if len(self.cached_block_hashes) <= block_index: + all_tokens = req.get_all_token_ids() + block_tokens = 
all_tokens[-(_BLOCK_SIZE + 1) : -1] + prev_h = ( + self.cached_block_hashes[-1] + if self.cached_block_hashes + else -1 + ) + new_h = BlockManager.compute_hash(block_tokens, prev_h) + self.cached_block_hashes.append(new_h) + logger.debug( + f"Decode: appended block hash at index {block_index}" + ) + return StaticSchedulerOutput(scheduled_requests=[req], is_prefill=False) # Case 2: Get new request from waiting queue (prefill phase) @@ -175,9 +201,55 @@ def schedule(self) -> Optional[StaticSchedulerOutput]: ) continue + tokens = req.prompt_token_ids + num_full_blocks = prompt_len // _BLOCK_SIZE + matched = 0 + + self.pending_block_hashes.clear() + + for i in range(num_full_blocks): + prev_h = self.cached_block_hashes[i - 1] if i > 0 else -1 + h = BlockManager.compute_hash( + tokens[i * _BLOCK_SIZE : (i + 1) * _BLOCK_SIZE], prev_h + ) + if ( + i < len(self.cached_block_hashes) + and h == self.cached_block_hashes[i] + ): + matched = i + 1 + else: + del self.cached_block_hashes[i:] + cur_h = h + self.pending_block_hashes.append(cur_h) + for j in range(i + 1, num_full_blocks): + cur_h = BlockManager.compute_hash( + tokens[j * _BLOCK_SIZE : (j + 1) * _BLOCK_SIZE], + cur_h, + ) + self.pending_block_hashes.append(cur_h) + break + else: + del self.cached_block_hashes[matched:] + + prefix_hit_len = min(matched * _BLOCK_SIZE, max(prompt_len - 1, 0))  # never reuse the full prompt: prefill must feed >= 1 token to produce logits + logger.info( + f"Prefill cache match: {matched}/{num_full_blocks} blocks " + f"({prefix_hit_len} tokens reused, {len(self.pending_block_hashes)} pending)" + ) + req.status = RequestStatus.RUNNING self.running_request = req - return StaticSchedulerOutput(scheduled_requests=[req], is_prefill=True) + return StaticSchedulerOutput( + scheduled_requests=[req], is_prefill=True, prefix_hit_len=prefix_hit_len + ) + + def update_cache(self): + """Commit hashes computed during prefill into the confirmed cache hash list.""" + self.cached_block_hashes.extend(self.pending_block_hashes) + self.pending_block_hashes.clear() + logger.debug( + f"update_cache: 
cached_block_hashes now has {len(self.cached_block_hashes)} blocks" + ) def complete_requests(self, requests: List[InferenceRequest]): """Handle completed requests.""" @@ -190,6 +262,8 @@ def get_cache_stats(self) -> dict: """Get cache statistics.""" return { "max_cache_len": self.max_cache_len, + "cached_blocks": len(self.cached_block_hashes), + "cached_tokens": len(self.cached_block_hashes) * _BLOCK_SIZE, "running_request": ( self.running_request.request_id if self.running_request else None ),