Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions python/infinilm/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,14 @@ def _update_requests(
sampled_tokens: List[int],
):
"""Update request status after inference step."""
# Only reset req blocks for paged cache
if is_prefill and self.cache_type == "paged":
self.scheduler.cache_manager.reset_req_blocks()
if is_prefill:
match self.cache_type:
case "paged":
self.scheduler.cache_manager.reset_req_blocks()
case "static":
self.scheduler.update_cache()
case _:
raise ValueError(f"Unsupported cache_type: {self.cache_type}")

for req, token_id in zip(requests, sampled_tokens):

Expand Down
96 changes: 85 additions & 11 deletions python/infinilm/llm/static_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import janus
from typing import List, Optional

from infinilm.llm.cache_manager import BlockManager
from infinilm.llm.request import (
RequestStatus,
InferenceRequest,
Expand All @@ -16,6 +17,8 @@

logger = logging.getLogger(__name__)

_BLOCK_SIZE = 16


class StaticSchedulerOutput:
"""Static scheduler output containing single request and execution phase info."""
Expand All @@ -24,10 +27,12 @@ def __init__(
self,
scheduled_requests: List[InferenceRequest],
is_prefill: bool = False,
prefix_hit_len: int = 0,
):
self.scheduled_requests = scheduled_requests
self.num_requests = len(scheduled_requests)
self.is_prefill = is_prefill
self.prefix_hit_len = prefix_hit_len

def build_model_inputs(
self, temperature: float = 1.0, top_p: float = 0.8, top_k: int = 1
Expand All @@ -36,29 +41,30 @@ def build_model_inputs(

Static cache model inputs:

Prefill phase:
- input_ids: All prompt tokens [1, prompt_length]
- position_ids: [0, 1, 2, ..., prompt_length-1]
- past_kv_lengths: [0] (no cached tokens initially)
Prefill phase (with prefix cache reuse):
- input_ids: Tokens after the cached prefix [1, prompt_length - prefix_hit_len]
- position_ids: [prefix_hit_len, ..., prompt_length-1]
- past_kv_lengths: [prefix_hit_len] (reuse cached prefix)
- total_kv_lengths: [prompt_length]

Decode phase:
- input_ids: Only the last generated token [1, 1]
- position_ids: [current_position] (position in full sequence)
- past_kv_lengths: [num_cached_tokens]
- total_kv_lengths: [total_tokens]
-
"""
req = self.scheduled_requests[0]

if self.is_prefill:
# Prefill: send all prompt tokens
# Prefill: only send tokens not already in cache
tokens = req.get_input_tokens()
input_ids = [tokens]
position_ids = [list(range(len(tokens)))]
past_kv_len = 0
prefix_hit_len = self.prefix_hit_len
input_tokens = tokens[prefix_hit_len:]
input_ids = [input_tokens]
position_ids = [list(range(prefix_hit_len, len(tokens)))]
past_kv_len = prefix_hit_len
total_kv_len = len(tokens)
input_offsets = [0, len(tokens)]
input_offsets = [0, len(input_tokens)]
else:
# Decode: send only the last generated token
last_token = req.generated_token_ids[-1]
Expand Down Expand Up @@ -91,12 +97,15 @@ class StaticScheduler:
- Only handles one request at a time
- No cache block management needed
- Simple waiting queue for incoming requests
- Prefix cache reuse via chained block hashing (block size = _BLOCK_SIZE)
"""

def __init__(self, max_cache_len: int = 4096):
self.waiting_queue = janus.Queue()
self.running_request: Optional[InferenceRequest] = None
self.max_cache_len = max_cache_len
self.cached_block_hashes: List[int] = []
self.pending_block_hashes: List[int] = []

def add_request(self, request: InferenceRequest):
if request is not None:
Expand Down Expand Up @@ -138,6 +147,23 @@ def schedule(self) -> Optional[StaticSchedulerOutput]:
)
continue

total_length = req.get_total_length()
if total_length % _BLOCK_SIZE == 1 and total_length > _BLOCK_SIZE:
block_index = total_length // _BLOCK_SIZE - 1
if len(self.cached_block_hashes) <= block_index:
all_tokens = req.get_all_token_ids()
block_tokens = all_tokens[-(_BLOCK_SIZE + 1) : -1]
prev_h = (
self.cached_block_hashes[-1]
if self.cached_block_hashes
else -1
)
new_h = BlockManager.compute_hash(block_tokens, prev_h)
self.cached_block_hashes.append(new_h)
logger.debug(
f"Decode: appended block hash at index {block_index}"
)

return StaticSchedulerOutput(scheduled_requests=[req], is_prefill=False)

# Case 2: Get new request from waiting queue (prefill phase)
Expand Down Expand Up @@ -175,9 +201,55 @@ def schedule(self) -> Optional[StaticSchedulerOutput]:
)
continue

tokens = req.prompt_token_ids
num_full_blocks = prompt_len // _BLOCK_SIZE
matched = 0

self.pending_block_hashes.clear()

for i in range(num_full_blocks):
prev_h = self.cached_block_hashes[i - 1] if i > 0 else -1
h = BlockManager.compute_hash(
tokens[i * _BLOCK_SIZE : (i + 1) * _BLOCK_SIZE], prev_h
)
if (
i < len(self.cached_block_hashes)
and h == self.cached_block_hashes[i]
):
matched = i + 1
else:
del self.cached_block_hashes[i:]
cur_h = h
self.pending_block_hashes.append(cur_h)
for j in range(i + 1, num_full_blocks):
cur_h = BlockManager.compute_hash(
tokens[j * _BLOCK_SIZE : (j + 1) * _BLOCK_SIZE],
cur_h,
)
self.pending_block_hashes.append(cur_h)
break
else:
del self.cached_block_hashes[matched:]

prefix_hit_len = matched * _BLOCK_SIZE
logger.info(
f"Prefill cache match: {matched}/{num_full_blocks} blocks "
f"({prefix_hit_len} tokens reused, {len(self.pending_block_hashes)} pending)"
)

req.status = RequestStatus.RUNNING
self.running_request = req
return StaticSchedulerOutput(scheduled_requests=[req], is_prefill=True)
return StaticSchedulerOutput(
scheduled_requests=[req], is_prefill=True, prefix_hit_len=prefix_hit_len
)

def update_cache(self):
    """Promote block hashes computed during prefill into the confirmed cache.

    The scheduler stages hashes in ``pending_block_hashes`` while a prefill
    is in flight; once the model step succeeds, this moves them into
    ``cached_block_hashes`` so later requests can match against them.
    """
    pending = self.pending_block_hashes
    self.cached_block_hashes += pending
    pending.clear()
    logger.debug(
        f"update_cache: cached_block_hashes now has {len(self.cached_block_hashes)} blocks"
    )

def complete_requests(self, requests: List[InferenceRequest]):
"""Handle completed requests."""
Expand All @@ -190,6 +262,8 @@ def get_cache_stats(self) -> dict:
"""Get cache statistics."""
return {
"max_cache_len": self.max_cache_len,
"cached_blocks": len(self.cached_block_hashes),
"cached_tokens": len(self.cached_block_hashes) * _BLOCK_SIZE,
"running_request": (
self.running_request.request_id if self.running_request else None
),
Expand Down