Replace control-token KV hiding with token-exchange (#8) #34
Open
AlonMalach wants to merge 14 commits into
Conversation
Adapter control tokens were padded into every Q/K/V in every decoder layer
via `control_dims=32` and masked with `finfo.min`. This bloated the KV
cache head_dim by 25-50% and forced FlashAttention onto padded 160/192-wide
vectors when only `num_hiding_groups` (typically 1) of the 32 extra dims
were ever non-zero.
Switch to token-exchange: after the switch reads `input_ids` and detects
which adapter to activate, replace each control token's embedding with a
substitute real-token embedding before the decoder runs. Control tokens
become ordinary content tokens in the residual stream and `control_dims`
collapses to 0, dropping the expansion entirely.
Substitute ids are computed at compose time:
- ALoRA adapters -> first token of alora_invocation_tokens
- LoRA/builtin adapters -> tokenizer.bos_token_id
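A minimal sketch of this compose-time rule (function and field names here are illustrative, not the composer's actual API; note that a later commit in this PR replaces the BOS fallback with a chat-template probe):

```python
def substitute_token_id(adapter_config: dict, bos_token_id: int) -> int:
    """Pick the substitute id whose embedding replaces the control token.

    ALoRA adapters advertise an invocation sequence; its first token is
    what would naturally follow the control token. LoRA/builtin adapters
    fall back to BOS (revised to a template probe later in this PR).
    """
    invocation = adapter_config.get("alora_invocation_tokens")
    if invocation:
        return invocation[0]
    return bos_token_id
```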
New config field `adapter_substitute_token_ids` is persisted in config.json
and drives a `use_token_exchange` property read by both backends. Default
`control_dims` flips from 32 to 0.
The legacy KV-hiding path is preserved as an opt-in escape hatch via the
new `--legacy-hiding` composer flag; any adapter that regresses under
token-exchange can be composed with the old semantics unchanged.
Key validation:
- Reject num_adapters>0 with neither hiding nor substitute ids (would
leak raw control-token embeddings into attention).
- Reject duplicate adapter_token_ids (LUT collision).
- Reject negative / wrong-length substitute ids.
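The three rejection rules can be sketched roughly like this (a standalone illustration; the real validator lives in the config class and its argument names may differ):

```python
def validate_adapter_config(num_adapters, use_legacy_hiding,
                            adapter_token_ids, substitute_token_ids):
    """Illustrative version of the three rejection rules."""
    # 1. Adapters with neither hiding nor substitutes would leak raw
    #    control-token embeddings into attention.
    if num_adapters > 0 and not use_legacy_hiding and not substitute_token_ids:
        raise ValueError(
            "adapters without hiding or substitute ids would leak raw "
            "control-token embeddings into attention")
    # 2. Duplicate control-token ids would collide in the LUT.
    if len(set(adapter_token_ids)) != len(adapter_token_ids):
        raise ValueError("duplicate adapter_token_ids (LUT collision)")
    # 3. Substitute ids must be one-per-adapter and valid token ids.
    if substitute_token_ids:
        if len(substitute_token_ids) != num_adapters:
            raise ValueError("substitute ids must match num_adapters")
        if any(t < 0 for t in substitute_token_ids):
            raise ValueError("substitute ids must be non-negative")
```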
Position correction via `hidden_count` is skipped in token-exchange mode
since control tokens are real positions.
Design: docs/KV_CACHE_OVERHEAD_REMOVAL.md
Tracks issue #8.
Measures four metrics per position, teacher-forced, across a list of
prompts to compare legacy KV-hiding (control_dims>0) vs. token-exchange
(control_dims=0 + substitute ids):
1. KL(p_old || p_new) per position (log_softmax based to avoid underflow)
2. Top-1 agreement (tagged "(noisy)" on wide nuclei)
3. Nucleus (top-p=0.9) Jaccard (sampling-set overlap)
4. Mass under old nucleus by new (the actionable gate)
Results are partitioned into overall / pre-control / adapter-active buckets.
The pre-control bucket must be bit-for-bit identical (KL max == 0, top-1
agree == 1.0); any drift there signals a bug in the embedding-swap gating
rather than a mode trade-off.
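The per-position KL in metric 1 can be computed from raw logits without ever materializing probabilities in linear space, along these lines (a self-contained sketch, not the harness's actual code):

```python
import math

def log_softmax(logits):
    # subtract the max before exponentiating to avoid overflow,
    # and stay in log space to avoid underflow on long tails
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def kl_old_new(old_logits, new_logits):
    # KL(p_old || p_new) = sum_i p_old(i) * (log p_old(i) - log p_new(i))
    lo = log_softmax(old_logits)
    ln = log_softmax(new_logits)
    return sum(math.exp(a) * (a - b) for a, b in zip(lo, ln))
```

Identical logits give KL of exactly 0, which is what the pre-control bucket's bit-for-bit requirement checks for.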
Two modes:
- Synthetic (CPU): builds two HF models with identical base weights, one
in legacy hiding and one in token-exchange. Useful as a plumbing check
and regression guard. Runs as a standard pytest.
- Real-model (GPU, opt-in): set GRANITE_SWITCH_PARITY_MODELS='{"old":...,
"new":...}'. Loads composed checkpoints and uses demo-script prompts
(14 adapter-specific prompts from run_adapter_generation_direct.py)
rendered through the composed tokenizer's chat template. Thresholds:
top-1 >= 0.95, mean KL <= 0.02, mean mass-under-old-nucleus >= 0.88.
Also exposes build_demo_prompts() in the demo script. Short-circuits
_generate via a module-level capture flag so prompt text is collected
without touching model.generate. Used by the parity eval to pull
realistic adapter inputs without duplicating the demo prompt data.
CLI usage:
python -m tests.integration.test_token_exchange_parity \
--old /path/to/legacy_build --new /path/to/te_build --json-out report.json
Granite tokenizers alias bos_token_id to <|end_of_text|> (EOS), so the
previous BOS-based substitute for LoRA/builtin adapters would have
injected an end-of-text signal mid-prompt — a stop-generation marker
in a place the model was not trained to see it.
The chat template places the LoRA control token at sequence start,
immediately followed by <|start_of_role|>user<|end_of_role|>... — so
<|start_of_role|> is the deterministic "token that naturally follows"
for every LoRA adapter, and its embedding is well-trained in the base
model (part of the base vocab on Granite 4.0 and 4.1).
Parallels the ALoRA path (substitute = first invocation token).
Both paths now pick "the token that comes right after the control
token in the rendered chat prompt" — single principle, two sources.
Validated:
- tokenizer.convert_tokens_to_ids('<|start_of_role|>') == 100264
on ibm-granite/granite-4.1-3b and granite-4.0-micro (part of
base vocab, not composer-added).
- bos_token_id == eos_token_id == 100257 ('<|end_of_text|>') on
all three Granite tokenizers tested — confirming the prior
default was semantically wrong.
The runtime embedding swap replaces each adapter control token's embedding with a substitute token's embedding — for LoRA adapters this is <|start_of_role|>, and for assistant-boundary ALoRA adapters it's also <|start_of_role|> (the first token of their invocation sequence). But the chat template then emits a *real* <|start_of_role|> at the next position: the user or assistant role marker that naturally follows the control-token prefix.
Result before this change: two consecutive positions carrying <|start_of_role|>'s embedding. The model has never seen that pattern during pretraining — a duplicate-embedding OOD right at the start of the decoder's residual stream.
Fix: add a skip-once Jinja flag (ns.skip_next_start_of_role). Arm it when lora_prefix_insertion emits the LoRA control token, or when alora_insertion fires the fallback path for assistant-boundary ALoRAs. Wrap every <|start_of_role|> emission in the base Granite template with a skip-once block that consumes the flag. The flag is single-shot — only the very first <|start_of_role|> after the control token is suppressed; all later role markers emit normally.
Not addressed in this PR: ALoRAs whose invocation text is in-message content text (<requirements>, <guardian>, <certainty>). The first token of these invocations is the single character '<', and the rest of the invocation text cannot be cleanly sliced at the template level without changing what 'requirements>' (or 'guardian>', etc.) tokenizes to. Those adapters retain the duplicate-embedding pattern until a runtime-level drop lands in a follow-up.
Backward compatibility: old checkpoints (composed before this change) load unchanged — the template edit only runs at compose time and affects only newly-composed models. Their rendered output for LoRA and assistant-boundary ALoRA is now one token shorter than before (the suppressed <|start_of_role|>). Update the three test_chat_template tests whose assertions encoded the old contract.
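The single-shot semantics of the flag can be illustrated outside Jinja with a plain-Python analogue (this is not the template code itself, just the control flow it implements):

```python
def emit_role_markers(markers, arm_skip_once=False):
    """Emit role markers, suppressing only the first one after arming.

    Mirrors ns.skip_next_start_of_role: armed once when the control
    token is emitted, consumed by the very next role-marker emission.
    """
    skip_next = arm_skip_once
    out = []
    for marker in markers:
        if skip_next:
            skip_next = False  # consume the flag: single-shot
            continue
        out.append(marker)
    return out
```

With the flag armed, only the first marker is dropped; all later role markers emit normally, matching the described contract.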
Closes the remaining duplicate-embedding OOD at the swap site. Complements the skip-once <|start_of_role|> edit from the previous commit by extending the same principle to ALoRA adapters whose invocation text lives inside a user message (<requirements>, <certainty>, <guardian>, <context>, etc.).
Change: in alora_pass2, after inserting the control token before the invocation text, also drop the first character of the invocation text. Example: "Please <|req_check|><requirements>" becomes "Please <|req_check|>requirements>". At runtime the embedding-swap replaces the control token's embedding with the first invocation token's embedding — the embedding of '<'. The decoder then sees [<|req_check|>→e_<, requirements, >] — exactly what "<requirements>" tokenizes to in isolation, with no duplicate.
Why this is safe on the Granite tokenizer: verified empirically via a new property test (test_first_char_drop_equals_first_token_drop). For every ALoRA invocation in the standard library, tokenizing the full invocation and dropping the first token ID yields the same sequence as tokenizing the string with its first character removed. BPE's greedy merge would break this property if the second-byte merges depended on the leading '<'; it doesn't, because '<' tokenizes as its own single-character token in every case.
The accompanying test test_first_token_is_single_character asserts the complementary invariant: the first token of each invocation decodes to exactly one character. If a future invocation text starts with a multi-character first token, that test catches it — the Jinja edit (invocation_text[1:] drops one character) would otherwise silently produce a wrong-length drop.
Combined with the previous commit (skip-once <|start_of_role|>), the duplicate-embedding pattern is now eliminated across all adapter types in the Granite adapter library: LoRA, assistant-boundary ALoRA, and in-message ALoRA.
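The string-level half of the edit is simple enough to show directly (a sketch of what alora_pass2 does; the function name and signature here are illustrative):

```python
def insert_control_token(text: str, control_token: str,
                         invocation_text: str) -> str:
    """Prepend the control token and drop the invocation's first character.

    At runtime the embedding swap gives the control token the first
    invocation token's embedding (the '<'), so the decoder sees the full
    invocation with no duplicated embedding.
    """
    return text + control_token + invocation_text[1:]
```

Usage matching the commit's example:

```python
insert_control_token("Please ", "<|req_check|>", "<requirements>")
# -> "Please <|req_check|>requirements>"
```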
Previously the composer hardcoded _LORA_SUBSTITUTE_TOKEN =
"<|start_of_role|>". That's the right answer for Granite 4.x but it
ties the default-path composer to a Granite-specific token name. Any
base model with a different chat template (different role marker,
different turn-open convention) would silently get the wrong
substitute — a token the base model knows, but not the one sitting
at position 1 of its rendered prompt.
Replace the hardcode with a compose-time probe: render a minimal
no-adapter user turn through tokenizer.apply_chat_template, tokenize,
and read input_ids[0]. That's by construction whatever the template
emits at the start of a normal turn, which is exactly what sits at
position 1 after a LoRA-prepended control token. The substitute and
the template's own behavior are now derived from the same source of
truth.
Verified: the probe returns 100264 (<|start_of_role|>) on
granite-4.1-3b, granite-4.0-micro, and granite-switch-4.1-3b-preview
— identical to the previous hardcoded value. Behavior on Granite is
unchanged; the door is open for non-Granite base models.
Error paths give actionable messages:
- Tokenizer has no chat_template → suggest --legacy-hiding
- Template render fails → report the Jinja error, suggest
--legacy-hiding
- First token is <unk> → report that the template emits something
outside the vocab
- Probe returns an empty id list → same
Tests:
- tests/composer/test_lora_substitute_probe.py (7 cases):
* Real tokenizer round-trip on granite-4.1-3b and 4.0-micro
* Synthetic tokenizer with a non-Granite template returns
the custom template's first-token id
* All four error paths raise ValueError with matching messages
Refactor: the runtime substitution LUT and the embedding-swap step
move out of each backend's decoder and into SingleSwitch (HF + vLLM).
The switch now performs both halves of token-exchange:
1. Adapter selection — read input_ids, detect control tokens via
input_ids == adapter_token_ids, emit per-token adapter_indices
(unchanged).
2. Token rewrite — replace each control token's id in input_ids
with its substitute id (from a switch-owned LUT). New.
The switch's forward signature changes from
-> adapter_indices
to
-> (adapter_indices, modified_input_ids)
The decoder consumes both: adapter_indices feeds the LoRA layers as
before, modified_input_ids feeds embed_tokens / get_input_embeddings
exactly once. There is no longer a decoder-side LUT, no scatter, no
clone-guard, no use_token_exchange branch in the embedding path.
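The two-output contract can be sketched in plain Python (the real switches do the same thing with tensor ops against a registered `control_to_substitute_lut` buffer; the dict-based LUTs here are illustrative):

```python
def switch_forward(input_ids, control_to_adapter, control_to_substitute_lut):
    """Sketch of the switch's two halves.

    control_to_adapter:        control token id -> adapter index
    control_to_substitute_lut: control token id -> substitute token id
    Non-control positions get adapter index -1 and keep their ids.
    """
    adapter_indices = [control_to_adapter.get(t, -1) for t in input_ids]
    modified_input_ids = [control_to_substitute_lut.get(t, t)
                          for t in input_ids]
    return adapter_indices, modified_input_ids
```

The decoder then feeds `adapter_indices` to the LoRA layers and embeds `modified_input_ids` exactly once, with no substitution logic of its own.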
Why this is cleaner:
- Single source of truth for the substitution. The switch already
knows which positions are control tokens; rewriting input_ids at
those positions is a natural extension of "decide which adapter is
active." The decoder is genuinely token-exchange-agnostic — it
just embeds whatever input_ids it receives.
- HF and vLLM converge to the same control flow. Both backends now
call switch(...), unpack two outputs, embed once. Previously each
backend had a near-identical but layout-specific (B,S,H vs T,H)
embedding-swap block + clone-guard that needed to be maintained
separately.
- Smaller diff for any future change to the substitution logic.
Whether to ship a different substitute strategy (e.g. learned
embedding, per-adapter rules) becomes a one-place change in the
switch instead of a two-place change across both decoders.
HF model forward also reorders slightly: switch runs before
embed_tokens, so we embed exactly once on modified_input_ids.
create_causal_mask now receives a stub embedding tensor of the right
shape and dtype (it only uses the tensor for batch/query/dtype
inference per the upstream docstring), since the real embedding
hasn't been computed yet.
Tests:
- tests/hf/test_single_switch.py: _run helper unpacks the new tuple
return; TestBatchProcessing similarly.
- tests/hf/test_token_exchange.py: LUT presence assertion now reads
model.model.switch.control_to_substitute_lut instead of
model.model.control_to_substitute_lut.
No behavior change verified by 756 passing tests (= same count as
before the refactor; +0 -0 after fixture updates).
Token-exchange has been the default for several commits. This change
deletes the dead-but-still-callable KV-hiding code path entirely:
Config:
- Drop control_dims, hiding_groups, hiding_policy, adapter_third_party
parameters and the corresponding state.
- Drop expanded_head_dim, num_hiding_groups, hiding_group_names,
use_token_exchange properties (token-exchange is now always on when
num_adapters > 0).
- Drop get_hiding_group_token_ids, get_third_party_adapter_mask,
get_adapter_hiding_policy_matrix methods.
- adapter_substitute_token_ids becomes required when num_adapters > 0.
- Net: -150 LoC (config.py 345 → 195).
Models:
- HF and vLLM both drop token_to_group_mask / adapter_hiding_matrix
buffers, hidden_count / adjusted_position_ids logic, and the
token_group_membership / query_group_suppression plumbing through
decoder layers.
- The HF decoder layer's forward signature drops two kwargs.
Attention layers (hf/core/lora.py, vllm/core/decoder.py):
- Drop expand_control_dims / control_dims / expanded_head_dim fields.
- Delete _expand_with_control_dimensions method entirely (~85 LoC each).
- Delete the expansion / trim-back branches in forward.
- vllm/core/decoder.py: attn_head_dim is unconditionally head_dim.
Switches:
- Drop config.expanded_head_dim references; head_dim is
config.projection_head_dim everywhere.
vllm/__init__.py:
- ModelArchConfigConvertor.get_head_size() returns
config.projection_head_dim (no expansion logic).
Composer:
- compose_granite_switch.py: drop --control-dims and --legacy-hiding
CLI flags. Delete the legacy-hiding branch in build(); always
token-exchange.
- compose_utils.py: drop hiding_groups / hiding_policy /
adapter_third_party kwargs.
- model_card.py: drop control_dims / legacy_hiding /
use_token_exchange reporting fields.
Tests deleted entirely:
- tests/unit/test_hiding_constant.py
- tests/hf/test_kv_hiding_gap_equivalence.py
- tests/vllm/test_kv_hiding_gap_equivalence.py
- tests/vllm/_kv_hiding_gap_tests.py
- tests/hf/test_position_zero_nan.py
- tests/vllm/_position_zero_nan_tests.py
- tests/integration/test_token_exchange_parity.py (compared old vs new
modes; with no old mode, nothing to compare).
- tests/composer/test_built_in_adapters.py (entire file tested removed
Mode A / Mode B distinction).
Tests rewritten:
- tests/conftest.py, tests/unit/test_config{,_edge_cases}.py,
tests/unit/test_token_exchange.py, tests/hf/test_model_forward.py,
tests/hf/test_token_exchange.py, tests/hf/test_qk_norm.py,
tests/shared/granite4_equivalence.py, tests/shared/generation_models.py:
fixtures and assertions updated for the simpler config surface.
Net diff: ~3000 LoC deleted, ~200 LoC added (test rewrites). 643
tests pass on CPU after the refactor (was 756; the difference is
parameterized hiding-equivalence tests + the parity harness, all
deleted).
Breaking change for any externally-composed checkpoint that was using
control_dims > 0: those checkpoints are unloadable under this version.
The token-exchange path has been the documented default since #8 and
the only path that received the chat-template drops, so any in-flight
build should already be on it.
The new switch buffer was failing compose-pipeline validation because buffer_keywords still listed the deleted legacy buffer names instead of the new one. Replace token_to_group_mask / adapter_hiding_matrix / all_hiding_group_token_ids with control_to_substitute_lut in arch.py and in the two test_granite4_mini parameter-allowlist assertions.
The report described safety margins for the finfo.min K-side hiding constant. Hiding is gone, so the section is meaningless. Drop the module, the call site in compose_report.py, and the package re-exports.
Replace control_dims / hiding_groups / hiding_policy / adapter_third_party references with adapter_substitute_token_ids in test fixtures, and drop TestControlTokenKVInvisibility (tested the deleted hiding mechanism). This is a partial sweep — vLLM workers, hf/test_single_switch_e2e.py, and shared/granite4_equivalence.py still need follow-up edits.
- tests/hf/test_single_switch_e2e.py: drop the CONTROL_DIMS_MODES axis; one parametrization on attention_multiplier only. Fixture returns a 3-tuple.
- tests/vllm/_generation_equivalence_worker.py and _tp_integration_worker.py: remove control_dims/hiding_groups/hiding_policy/adapter_third_party from composer calls; pass adapter_substitute_token_ids instead.
- tests/vllm/_single_switch_worker.py: mock_config uses projection_head_dim.
- tests/vllm/test_generation_equivalence.py: docstring updated.
- tests/shared/granite4_equivalence.py: rationale comments updated for token-exchange (no behavior change).
- src/granite_switch/composer/compose_utils.py: docstring/comment cleanup.
The vLLM decoder is wrapped in @support_torch_compile; Dynamo cannot trace data-dependent branching like ``if is_control.any()``. The gate broke engine init on GPU runs. Replace it with an unconditional torch.where in both backends — keeps HF and vLLM symmetric, costs one indexed gather + one elementwise select per forward, and makes the switch compile-safe.
Three fixes uncovered by a GPU run:
1. tests/vllm/_single_switch_worker.py: switch.forward now returns (adapter_indices, modified_input_ids); unpack and return only the indices. The worker was calling .cpu() on a tuple, so every parametrized test in tests/vllm/test_single_switch.py failed at the same point.
2. tests/vllm/test_model_forward.py: drop the TestControlTokenKVInvisibility class stub. The inner class was deleted with the legacy hiding tests in 0ddaf0e, but the parametrized runner still referenced it.
3. tests/vllm/test_position_zero_nan.py: deleted. The inner _position_zero_nan_tests.py was removed (it only existed for the legacy hiding path); the runner became an orphan and pytest reported "file or directory not found" on every parametrized variant.
The flash_api.cpp:697 "no kernel image" failures in test_model_forward are pre-existing GPU/FlashAttention environment issues, not branch bugs.
Summary
Closes #8. Replaces the legacy KV-hiding scheme (which padded every Q/K/V in every decoder layer with `control_dims=32` extra dimensions and masked the control token's K with `finfo.min`) with token exchange: at the switch layer, each control token's id is rewritten in `input_ids` to a substitute id whose embedding the model was already trained to handle. The decoder embeds the rewritten ids once and runs natively — no Q/K/V padding, no KV cache expansion, no hidden-count position correction.
Outcome on Granite 4.1-8b: KV cache `head_dim` returns to native 128 (from expanded 160, padded to 192). FlashAttention runs on native vectors. ~20% less KV memory, ~33% less attention compute per layer. No retraining required.
How it works
Two paired layers:
- Compose time (`tokenizer_setup.py`): a skip-once Jinja flag (`ns.skip_next_start_of_role`) suppresses the role-marker token that would normally follow each control token, and `alora_pass2` drops the first character of in-message ALoRA invocation text (BPE-equivalent to dropping one tokenized piece). The rendered sequence is one token shorter than before.
- Runtime (`hf/switch/single.py` + `vllm/switch/single.py`): when emitting `adapter_indices`, also rewrite each control token's id via a `control_to_substitute_lut` buffer. Returns `(adapter_indices, modified_input_ids)`. The decoder unpacks and embeds `modified_input_ids`.
Substitute derivation:
- ALoRA adapters: the first token of `alora_invocation_tokens` (read from the adapter's `adapter_config.json`).
- LoRA/builtin adapters: `_probe_lora_substitute_token_id(tokenizer)` — render a minimal probe chat, tokenize, take `input_ids[0]`. On Granite 4.x this resolves to `<|start_of_role|>` (id 100264).
What's removed
The legacy KV-hiding path is gone, not gated.
- `_expand_with_control_dimensions` deleted from both backends.
- `control_dims`, `hiding_groups`, `hiding_policy`, `adapter_third_party`, `expanded_head_dim`, `num_hiding_groups`, `get_hiding_group_token_ids`, `get_third_party_adapter_mask`, `get_adapter_hiding_policy_matrix` deleted from config.
- `adapter_substitute_token_ids` is required when `num_adapters > 0`. ~3000 LoC deleted net.
Backward compatibility
Breaking. Checkpoints composed under the legacy scheme cannot load — `from_pretrained` raises `ValueError` from the new `adapter_substitute_token_ids is required` validator. Users must recompose with the current `compose_granite_switch.py` against the same base + adapter sources (minutes per checkpoint).