
Replace control-token KV hiding with token-exchange (#8) #34

Open

AlonMalach wants to merge 14 commits into main from feature/reduce-control-token-kv-overhead

Conversation

@AlonMalach (Collaborator) commented May 15, 2026

Summary

Closes #8. Replaces the legacy KV-hiding scheme (which padded every Q/K/V in every decoder layer with control_dims=32 extra dimensions and masked the control token's K with finfo.min) with token exchange: at the switch layer, each control token's id is rewritten in input_ids to a substitute id whose embedding the model was already trained to handle. The decoder embeds the rewritten ids once and runs natively: no Q/K/V padding, no KV cache expansion, no hidden-count position correction.

Outcome on Granite 4.1-8b: KV cache head_dim returns to native 128 (from expanded 160, padded to 192). FlashAttention runs on native vectors. ~20% less KV memory, ~33% less attention compute per layer. No retraining required.

How it works

Two paired layers:

  1. Chat template (compose-time, tokenizer_setup.py): a skip-once Jinja flag (ns.skip_next_start_of_role) suppresses the role-marker token that would normally follow each control token, and alora_pass2 drops the first character of in-message ALoRA invocation text (BPE-equivalent to dropping one tokenized piece). The rendered sequence is one token shorter than before.
  2. Switch (runtime, hf/switch/single.py + vllm/switch/single.py): when emitting adapter_indices, also rewrite each control token's id via a control_to_substitute_lut buffer. Returns (adapter_indices, modified_input_ids). The decoder unpacks and embeds modified_input_ids.
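
A minimal sketch of the switch-side exchange in step 2, with illustrative names (the real SingleSwitch keeps the LUT as a registered buffer and handles each backend's batch layout):

    import torch

    def switch_forward(input_ids, adapter_token_ids, control_to_substitute_lut):
        # which adapter's control token (if any) sits at each position
        matches = input_ids.unsqueeze(-1) == adapter_token_ids   # (..., num_adapters)
        is_control = matches.any(-1)
        adapter_indices = torch.where(
            is_control, matches.long().argmax(-1), torch.full_like(input_ids, -1)
        )
        # token rewrite: control ids -> substitute ids, everything else untouched
        substitutes = control_to_substitute_lut[adapter_indices.clamp(min=0)]
        modified_input_ids = torch.where(is_control, substitutes, input_ids)
        return adapter_indices, modified_input_ids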

Substitute derivation:

  • ALoRA → first token of alora_invocation_tokens (read from adapter's adapter_config.json).
  • LoRA / built-in → whatever the tokenizer's chat template emits at the start of a no-adapter user turn. Derived at compose time by _probe_lora_substitute_token_id(tokenizer) — render a minimal probe chat, tokenize, take input_ids[0]. On Granite 4.x this resolves to <|start_of_role|> (id 100264).
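
A sketch of how such a probe can work (assumed shape; the real composer helper adds the error handling described in the commits below):

    def probe_lora_substitute_token_id(tokenizer) -> int:
        """Return the first token id the chat template emits for a plain user turn."""
        rendered = tokenizer.apply_chat_template(
            [{"role": "user", "content": "probe"}],
            tokenize=False,
            add_generation_prompt=False,
        )
        input_ids = tokenizer(rendered, add_special_tokens=False)["input_ids"]
        if not input_ids:
            raise ValueError("chat template rendered an empty token sequence")
        return input_ids[0]  # <|start_of_role|> (100264) on Granite 4.x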

What's removed

The legacy KV-hiding path is gone, not gated. _expand_with_control_dimensions deleted from both backends. control_dims, hiding_groups, hiding_policy, adapter_third_party, expanded_head_dim, num_hiding_groups, get_hiding_group_token_ids, get_third_party_adapter_mask, get_adapter_hiding_policy_matrix deleted from config. adapter_substitute_token_ids is required when num_adapters > 0. ~3000 LoC deleted net.

Backward compatibility

Breaking. Checkpoints composed under the legacy scheme cannot load: from_pretrained raises a ValueError from the new "adapter_substitute_token_ids is required" validator. Users must recompose with the current compose_granite_switch.py against the same base + adapter sources (minutes per checkpoint).

AlonMalach added 14 commits May 15, 2026 10:35
Adapter control tokens were padded into every Q/K/V in every decoder layer
via `control_dims=32` and masked with `finfo.min`. This bloated the KV
cache head_dim by 25-50% and forced FlashAttention onto padded 160/192-wide
vectors when only `num_hiding_groups` (typically 1) of the 32 extra dims
were ever non-zero.

Switch to token-exchange: after the switch reads `input_ids` and detects
which adapter to activate, replace each control token's embedding with a
substitute real-token embedding before the decoder runs. Control tokens
become ordinary content tokens in the residual stream and `control_dims`
collapses to 0, dropping the expansion entirely.

Substitute ids are computed at compose time:
  - ALoRA adapters -> first token of alora_invocation_tokens
  - LoRA/builtin adapters -> tokenizer.bos_token_id

New config field `adapter_substitute_token_ids` is persisted in config.json
and drives a `use_token_exchange` property read by both backends. Default
`control_dims` flips from 32 to 0.

The legacy KV-hiding path is preserved as an opt-in escape hatch via the
new `--legacy-hiding` composer flag; any adapter that regresses under
token-exchange can be composed with the old semantics unchanged.

Key validation:
  - Reject num_adapters>0 with neither hiding nor substitute ids (would
    leak raw control-token embeddings into attention).
  - Reject duplicate adapter_token_ids (LUT collision).
  - Reject negative / wrong-length substitute ids.
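
Illustrative shape of those checks (not the actual config code; the legacy_hiding flag exists only while the escape hatch remains):

    def validate_adapter_config(num_adapters, adapter_token_ids,
                                adapter_substitute_token_ids, legacy_hiding, vocab_size):
        # neither hiding nor substitutes -> raw control embeddings would reach attention
        if num_adapters > 0 and not legacy_hiding and not adapter_substitute_token_ids:
            raise ValueError("num_adapters > 0 requires hiding or substitute ids")
        if len(set(adapter_token_ids)) != len(adapter_token_ids):
            raise ValueError("duplicate adapter_token_ids would collide in the LUT")
        if adapter_substitute_token_ids and (
            len(adapter_substitute_token_ids) != len(adapter_token_ids)
            or any(t < 0 or t >= vocab_size for t in adapter_substitute_token_ids)
        ):
            raise ValueError("substitute ids must be valid vocab ids, one per adapter")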

Position correction via `hidden_count` is skipped in token-exchange mode
since control tokens are real positions.

Design: docs/KV_CACHE_OVERHEAD_REMOVAL.md
Tracks issue #8.

Measures four metrics per position, teacher-forced, across a list of
prompts to compare legacy KV-hiding (control_dims>0) vs. token-exchange
(control_dims=0 + substitute ids):

  1. KL(p_old || p_new) per position  (log_softmax based to avoid underflow)
  2. Top-1 agreement                    (tagged "(noisy)" on wide nuclei)
  3. Nucleus (top-p=0.9) Jaccard        (sampling-set overlap)
  4. Mass under old nucleus by new      (the actionable gate)
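
For concreteness, a rough sketch of the four metrics over teacher-forced logits of shape (seq_len, vocab) from each mode (names and bucketing are illustrative):

    import torch
    import torch.nn.functional as F

    def parity_metrics(logits_old, logits_new, top_p=0.9):
        logp_old = F.log_softmax(logits_old, dim=-1)   # log-space avoids underflow
        logp_new = F.log_softmax(logits_new, dim=-1)
        kl = (logp_old.exp() * (logp_old - logp_new)).sum(-1)            # 1. KL(p_old || p_new)
        top1 = (logits_old.argmax(-1) == logits_new.argmax(-1)).float()  # 2. top-1 agreement

        def nucleus_sets(logp):
            probs, idx = logp.exp().sort(-1, descending=True)
            keep = probs.cumsum(-1) - probs < top_p   # smallest prefix with mass >= top_p
            return [set(idx[t][keep[t]].tolist()) for t in range(logp.shape[0])]

        old_n, new_n = nucleus_sets(logp_old), nucleus_sets(logp_new)
        jaccard = [len(a & b) / len(a | b) for a, b in zip(old_n, new_n)]      # 3. nucleus Jaccard
        mass_under_old = [logp_new[t].exp()[sorted(old_n[t])].sum().item()     # 4. mass under old nucleus
                          for t in range(len(old_n))]
        return kl, top1, jaccard, mass_under_old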

Results are partitioned into overall / pre-control / adapter-active buckets.
The pre-control bucket must be bit-for-bit identical (KL max == 0, top-1
agree == 1.0); any drift there signals a bug in the embedding-swap gating
rather than a mode trade-off.

Two modes:
  - Synthetic (CPU): builds two HF models with identical base weights, one
    in legacy hiding and one in token-exchange. Useful as a plumbing check
    and regression guard. Runs as a standard pytest.
  - Real-model (GPU, opt-in): set GRANITE_SWITCH_PARITY_MODELS='{"old":...,
    "new":...}'. Loads composed checkpoints and uses demo-script prompts
    (14 adapter-specific prompts from run_adapter_generation_direct.py)
    rendered through the composed tokenizer's chat template. Thresholds:
    top-1 >= 0.95, mean KL <= 0.02, mean mass-under-old-nucleus >= 0.88.

Also exposes build_demo_prompts() in the demo script. Short-circuits
_generate via a module-level capture flag so prompt text is collected
without touching model.generate. Used by the parity eval to pull
realistic adapter inputs without duplicating the demo prompt data.

CLI usage:
    python -m tests.integration.test_token_exchange_parity \
        --old /path/to/legacy_build --new /path/to/te_build --json-out report.json

Granite tokenizers alias bos_token_id to <|end_of_text|> (EOS), so the
previous BOS-based substitute for LoRA/builtin adapters would have
injected an end-of-text signal mid-prompt — a stop-generation marker
in a place the model was not trained to see it.

The chat template places the LoRA control token at sequence start,
immediately followed by <|start_of_role|>user<|end_of_role|>... — so
<|start_of_role|> is the deterministic "token that naturally follows"
for every LoRA adapter, and its embedding is well-trained in the base
model (part of the base vocab on Granite 4.0 and 4.1).

Parallels the ALoRA path (substitute = first invocation token).
Both paths now pick "the token that comes right after the control
token in the rendered chat prompt" — single principle, two sources.

Validated:
  - tokenizer.convert_tokens_to_ids('<|start_of_role|>') == 100264
    on ibm-granite/granite-4.1-3b and granite-4.0-micro (part of
    base vocab, not composer-added).
  - bos_token_id == eos_token_id == 100257 ('<|end_of_text|>') on
    all three Granite tokenizers tested — confirming the prior
    default was semantically wrong.
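
Both facts reduce to a couple of tokenizer calls (model ids as cited above; values assume the Granite 4.x tokenizers):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("ibm-granite/granite-4.0-micro")
    print(tok.bos_token_id, tok.eos_token_id)              # both 100257 -> '<|end_of_text|>'
    print(tok.convert_tokens_to_ids("<|start_of_role|>"))  # 100264, part of the base vocab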

The runtime embedding swap replaces each adapter control token's
embedding with a substitute token's embedding — for LoRA adapters this
is <|start_of_role|>, for assistant-boundary ALoRA adapters it's also
<|start_of_role|> (the first token of their invocation sequence). But
the chat template then emits a *real* <|start_of_role|> at the next
position: the user or assistant role marker that naturally follows the
control-token prefix.

Result before this change: two consecutive positions carrying
<|start_of_role|>'s embedding. The model has never seen that pattern
during pretraining — a duplicate-embedding OOD right at the start of
the decoder's residual stream.

Fix: add a skip-once Jinja flag (ns.skip_next_start_of_role). Arm it
when lora_prefix_insertion emits the LoRA control token, or when
alora_insertion fires the fallback path for assistant-boundary ALoRAs.
Wrap every <|start_of_role|> emission in the base Granite template
with a skip-once block that consumes the flag. The flag is single-shot
— only the very first <|start_of_role|> after the control token is
suppressed; all later role markers emit normally.

Not addressed in this PR: ALoRAs whose invocation text is in-message
content text (<requirements>, <guardian>, <certainty>). The first
token of these invocations is the single character '<', and the rest
of the invocation text cannot be cleanly sliced at the template level
without changing what 'requirements>' (or 'guardian>', etc.)
tokenizes to. Those adapters retain the duplicate-embedding pattern
until a runtime-level drop lands in a follow-up.

Backward compatibility: old checkpoints (composed before this change)
load unchanged; the template edit only runs at compose time and affects
only newly-composed models, whose rendered output for LoRA and
assistant-boundary ALoRA is now one token shorter than before (the
suppressed <|start_of_role|>). Update the three test_chat_template
tests whose assertions encoded the old contract.

Closes the remaining duplicate-embedding OOD at the swap site. Complements
the skip-once <|start_of_role|> edit from the previous commit by extending
the same principle to ALoRA adapters whose invocation text lives inside a
user message (<requirements>, <certainty>, <guardian>, <context>, etc.).

Change: in alora_pass2, after inserting the control token before the
invocation text, also drop the first CHARACTER of the invocation text.
Example: "Please <|req_check|><requirements>" becomes
"Please <|req_check|>requirements>". At runtime the embedding-swap
replaces the control token's embedding with the first invocation token's
embedding — the embedding of '<'. The decoder then sees
[<|req_check|>→e_<, requirements, >] — exactly what "<requirements>"
tokenizes to in isolation, with no duplicate.

Why this is safe on the Granite tokenizer: verified empirically via a
new property test (test_first_char_drop_equals_first_token_drop). For
every ALoRA invocation in the standard library, tokenizing the full
invocation and dropping the first token ID yields the same sequence as
tokenizing the string with its first character removed. BPE's greedy-
merge would break this property if the second-byte merges depended on
the leading '<'; it doesn't, because '<' tokenizes as its own single-
character token in every case.

The accompanying test test_first_token_is_single_character asserts the
complementary invariant: the first token of each invocation decodes to
exactly one character. If a future invocation text starts with a
multi-character first token, that test catches it — the Jinja edit
(invocation_text[1:] drops one character) would otherwise silently
produce a wrong-length drop.
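
Both invariants reduce to a few lines against a real tokenizer (a sketch of the property checks, not the exact tests):

    def first_char_drop_equals_first_token_drop(tokenizer, invocation_text):
        full = tokenizer(invocation_text, add_special_tokens=False)["input_ids"]
        sliced = tokenizer(invocation_text[1:], add_special_tokens=False)["input_ids"]
        return full[1:] == sliced   # dropping the first id == dropping the first char

    def first_token_is_single_character(tokenizer, invocation_text):
        first_id = tokenizer(invocation_text, add_special_tokens=False)["input_ids"][0]
        return len(tokenizer.decode([first_id])) == 1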

Combined with the previous commit (skip-once <|start_of_role|>), the
duplicate-embedding pattern is now eliminated across all adapter types
in the Granite adapter library: LoRA, assistant-boundary ALoRA, and
in-message ALoRA.

Previously the composer hardcoded _LORA_SUBSTITUTE_TOKEN =
"<|start_of_role|>". That's the right answer for Granite 4.x but it
ties the default-path composer to a Granite-specific token name. Any
base model with a different chat template (different role marker,
different turn-open convention) would silently get the wrong
substitute — a token the base model knows, but not the one sitting
at position 1 of its rendered prompt.

Replace the hardcode with a compose-time probe: render a minimal
no-adapter user turn through tokenizer.apply_chat_template, tokenize,
and read input_ids[0]. That's by construction whatever the template
emits at the start of a normal turn, which is exactly what sits at
position 1 after a LoRA-prepended control token. The substitute and
the template's own behavior are now derived from the same source of
truth.

Verified: the probe returns 100264 (<|start_of_role|>) on
granite-4.1-3b, granite-4.0-micro, and granite-switch-4.1-3b-preview
— identical to the previous hardcoded value. Behavior on Granite is
unchanged; the door is open for non-Granite base models.

Error paths give actionable messages:
  - Tokenizer has no chat_template → suggest --legacy-hiding
  - Template render fails → report the Jinja error, suggest
    --legacy-hiding
  - First token is <unk> → report that the template emits something
    outside the vocab
  - Probe returns an empty id list → same

Tests:
  - tests/composer/test_lora_substitute_probe.py (7 cases):
    * Real tokenizer round-trip on granite-4.1-3b and 4.0-micro
    * Synthetic tokenizer with a non-Granite template returns
      the custom template's first-token id
    * All four error paths raise ValueError with matching messages

Refactor: the runtime substitution LUT and the embedding-swap step
move out of each backend's decoder and into SingleSwitch (HF + vLLM).
The switch now performs both halves of token-exchange:

  1. Adapter selection — read input_ids, detect control tokens via
     input_ids == adapter_token_ids, emit per-token adapter_indices
     (unchanged).
  2. Token rewrite — replace each control token's id in input_ids
     with its substitute id (from a switch-owned LUT). New.

The switch's forward signature changes from
  -> adapter_indices
to
  -> (adapter_indices, modified_input_ids)

The decoder consumes both: adapter_indices feeds the LoRA layers as
before, modified_input_ids feeds embed_tokens / get_input_embeddings
exactly once. There is no longer a decoder-side LUT, no scatter, no
clone-guard, no use_token_exchange branch in the embedding path.

Why this is cleaner:

- Single source of truth for the substitution. The switch already
  knows which positions are control tokens; rewriting input_ids at
  those positions is a natural extension of "decide which adapter is
  active." The decoder is genuinely token-exchange-agnostic — it
  just embeds whatever input_ids it receives.

- HF and vLLM converge to the same control flow. Both backends now
  call switch(...), unpack two outputs, embed once. Previously each
  backend had a near-identical but layout-specific (B,S,H vs T,H)
  embedding-swap block + clone-guard that needed to be maintained
  separately.

- Smaller diff for any future change to the substitution logic.
  Whether to ship a different substitute strategy (e.g. learned
  embedding, per-adapter rules) becomes a one-place change in the
  switch instead of a two-place change across both decoders.

HF model forward also reorders slightly: switch runs before
embed_tokens, so we embed exactly once on modified_input_ids.
create_causal_mask now receives a stub embedding tensor of the right
shape and dtype (it only uses the tensor for batch/query/dtype
inference per the upstream docstring), since the real embedding
hasn't been computed yet.
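
Illustrative ordering only (build_causal_mask stands in for the upstream helper and is not its real signature):

    import torch

    def reordered_forward(switch, embed_tokens, build_causal_mask, input_ids, hidden_size):
        adapter_indices, modified_input_ids = switch(input_ids)
        # stub tensor: the mask helper only reads batch size, query length, and dtype
        stub = torch.empty(*modified_input_ids.shape, hidden_size,
                           dtype=embed_tokens.weight.dtype)
        causal_mask = build_causal_mask(stub)             # mask built before embedding
        inputs_embeds = embed_tokens(modified_input_ids)  # embed exactly once
        return inputs_embeds, causal_mask, adapter_indices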

Tests:
- tests/hf/test_single_switch.py: _run helper unpacks the new tuple
  return; TestBatchProcessing similarly.
- tests/hf/test_token_exchange.py: LUT presence assertion now reads
  model.model.switch.control_to_substitute_lut instead of
  model.model.control_to_substitute_lut.

No behavior change: 756 tests pass after the refactor, the same count
as before (+0 / -0 after fixture updates).

Token-exchange has been the default for several commits. This change
deletes the dead-but-still-callable KV-hiding code path entirely:

Config:
- Drop control_dims, hiding_groups, hiding_policy, adapter_third_party
  parameters and the corresponding state.
- Drop expanded_head_dim, num_hiding_groups, hiding_group_names,
  use_token_exchange properties (token-exchange is now always on when
  num_adapters > 0).
- Drop get_hiding_group_token_ids, get_third_party_adapter_mask,
  get_adapter_hiding_policy_matrix methods.
- adapter_substitute_token_ids becomes required when num_adapters > 0.
- Net: -150 LoC (config.py 345 → 195).

Models:
- HF and vLLM both drop token_to_group_mask / adapter_hiding_matrix
  buffers, hidden_count / adjusted_position_ids logic, and the
  token_group_membership / query_group_suppression plumbing through
  decoder layers.
- The HF decoder layer's forward signature drops two kwargs.

Attention layers (hf/core/lora.py, vllm/core/decoder.py):
- Drop expand_control_dims / control_dims / expanded_head_dim fields.
- Delete _expand_with_control_dimensions method entirely (~85 LoC each).
- Delete the expansion / trim-back branches in forward.
- vllm/core/decoder.py: attn_head_dim is unconditionally head_dim.

Switches:
- Drop config.expanded_head_dim references; head_dim is
  config.projection_head_dim everywhere.

vllm/__init__.py:
- ModelArchConfigConvertor.get_head_size() returns
  config.projection_head_dim (no expansion logic).

Composer:
- compose_granite_switch.py: drop --control-dims and --legacy-hiding
  CLI flags. Delete the legacy-hiding branch in build(); always
  token-exchange.
- compose_utils.py: drop hiding_groups / hiding_policy /
  adapter_third_party kwargs.
- model_card.py: drop control_dims / legacy_hiding /
  use_token_exchange reporting fields.

Tests deleted entirely:
- tests/unit/test_hiding_constant.py
- tests/hf/test_kv_hiding_gap_equivalence.py
- tests/vllm/test_kv_hiding_gap_equivalence.py
- tests/vllm/_kv_hiding_gap_tests.py
- tests/hf/test_position_zero_nan.py
- tests/vllm/_position_zero_nan_tests.py
- tests/integration/test_token_exchange_parity.py (compared old vs new
  modes; with no old mode, nothing to compare).
- tests/composer/test_built_in_adapters.py (entire file tested removed
  Mode A / Mode B distinction).

Tests rewritten:
- tests/conftest.py, tests/unit/test_config{,_edge_cases}.py,
  tests/unit/test_token_exchange.py, tests/hf/test_model_forward.py,
  tests/hf/test_token_exchange.py, tests/hf/test_qk_norm.py,
  tests/shared/granite4_equivalence.py, tests/shared/generation_models.py:
  fixtures and assertions updated for the simpler config surface.

Net diff: ~3000 LoC deleted, ~200 LoC added (test rewrites). 643
tests pass on CPU after the refactor (was 756; the difference is
parameterized hiding-equivalence tests + the parity harness, all
deleted).

Breaking change for any externally-composed checkpoint that was using
control_dims > 0: those checkpoints are unloadable under this version.
The token-exchange path has been the documented default since #8 and
the only path that received the chat-template drops, so any in-flight
build should already be on it.

The new switch buffer was failing compose-pipeline validation because
buffer_keywords still listed the deleted legacy buffer names instead of
the new one. Replace token_to_group_mask / adapter_hiding_matrix /
all_hiding_group_token_ids with control_to_substitute_lut in arch.py
and in the two test_granite4_mini parameter-allowlist assertions.

The report described safety margins for the finfo.min K-side hiding
constant. Hiding is gone, so the section is meaningless. Drop the
module, the call site in compose_report.py, and the package re-exports.

Replace control_dims / hiding_groups / hiding_policy / adapter_third_party
references with adapter_substitute_token_ids in test fixtures, and drop
TestControlTokenKVInvisibility (tested the deleted hiding mechanism).

This is a partial sweep — vLLM workers, hf/test_single_switch_e2e.py,
and shared/granite4_equivalence.py still need follow-up edits.

- tests/hf/test_single_switch_e2e.py: drop CONTROL_DIMS_MODES axis; one
  parametrization on attention_multiplier only. Fixture returns a 3-tuple.
- tests/vllm/_generation_equivalence_worker.py and _tp_integration_worker.py:
  remove control_dims/hiding_groups/hiding_policy/adapter_third_party from
  composer calls; pass adapter_substitute_token_ids instead.
- tests/vllm/_single_switch_worker.py: mock_config uses projection_head_dim.
- tests/vllm/test_generation_equivalence.py: docstring updated.
- tests/shared/granite4_equivalence.py: rationale comments updated for
  token-exchange (no behavior change).
- src/granite_switch/composer/compose_utils.py: docstring/comment cleanup.

The vLLM decoder is wrapped in @support_torch_compile; Dynamo cannot
trace data-dependent branching like ``if is_control.any()``. The gate
broke engine init on GPU runs.

Replace it with an unconditional torch.where in both backends — keeps
HF and vLLM symmetric, costs one indexed gather + one elementwise
select per forward, and makes the switch compile-safe.
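
A sketch of the difference (names illustrative): the gated form is what Dynamo cannot trace, the unconditional form is what both backends use now.

    import torch

    def gated_rewrite(input_ids, is_control, adapter_indices, lut):
        if is_control.any():                             # data-dependent branch: breaks tracing
            substitutes = lut[adapter_indices.clamp(min=0)]
            return torch.where(is_control, substitutes, input_ids)
        return input_ids

    def compile_safe_rewrite(input_ids, is_control, adapter_indices, lut):
        substitutes = lut[adapter_indices.clamp(min=0)]          # unconditional gather
        return torch.where(is_control, substitutes, input_ids)   # unconditional select
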
Three fixes uncovered by GPU run:

1. tests/vllm/_single_switch_worker.py: switch.forward now returns
   (adapter_indices, modified_input_ids); unpack and return only the
   indices. Worker was calling .cpu() on a tuple → every parametrized
   test in tests/vllm/test_single_switch.py failed at the same point.

2. tests/vllm/test_model_forward.py: drop the TestControlTokenKVInvisibility
   class stub. The inner class was deleted with the legacy hiding tests
   in 0ddaf0e, but the parametrized runner still referenced it.

3. tests/vllm/test_position_zero_nan.py: deleted. The inner
   _position_zero_nan_tests.py was removed (only existed for the legacy
   hiding path); the runner became orphan and pytest reported "file or
   directory not found" on every parametrized variant.

The flash_api.cpp:697 "no kernel image" failures in test_model_forward
are pre-existing GPU/FlashAttention environment issues, not branch bugs.

@AlonMalach (Collaborator, Author)

A guide to the changes in this PR:
KV_CACHE_OVERHEAD_REMOVAL.html
