From 19b633a2fe8b650a79883df8a01bf7b06ea5f95a Mon Sep 17 00:00:00 2001 From: AlonMalach Date: Tue, 12 May 2026 18:19:28 +0300 Subject: [PATCH 01/18] Replace control-token KV hiding with token-exchange by default (#8) Adapter control tokens were padded into every Q/K/V in every decoder layer via `control_dims=32` and masked with `finfo.min`. This bloated the KV cache head_dim by 25-50% and forced FlashAttention onto padded 160/192-wide vectors when only `num_hiding_groups` (typically 1) of the 32 extra dims were ever non-zero. Switch to token-exchange: after the switch reads `input_ids` and detects which adapter to activate, replace each control token's embedding with a substitute real-token embedding before the decoder runs. Control tokens become ordinary content tokens in the residual stream and `control_dims` collapses to 0, dropping the expansion entirely. Substitute ids are computed at compose time: - ALoRA adapters -> first token of alora_invocation_tokens - LoRA/builtin adapters -> tokenizer.bos_token_id New config field `adapter_substitute_token_ids` is persisted in config.json and drives a `use_token_exchange` property read by both backends. Default `control_dims` flips from 32 to 0. The legacy KV-hiding path is preserved as an opt-in escape hatch via the new `--legacy-hiding` composer flag; any adapter that regresses under token-exchange can be composed with the old semantics unchanged. Key validation: - Reject num_adapters>0 with neither hiding nor substitute ids (would leak raw control-token embeddings into attention). - Reject duplicate adapter_token_ids (LUT collision). - Reject negative / wrong-length substitute ids. Position correction via `hidden_count` is skipped in token-exchange mode since control tokens are real positions. Design: docs/KV_CACHE_OVERHEAD_REMOVAL.md Tracks issue #8. --- .../composer/compose_granite_switch.py | 63 ++++++-- src/granite_switch/composer/compose_utils.py | 6 + .../composer/reporting/model_card.py | 5 + .../composer/tokenizer_setup.py | 27 +++- src/granite_switch/config.py | 61 +++++++- .../hf/modeling_granite_switch.py | 44 +++++- .../vllm/granite_switch_model.py | 47 +++++- tests/composer/test_built_in_adapters.py | 17 ++- tests/hf/test_model_forward.py | 3 +- tests/hf/test_token_exchange.py | 141 ++++++++++++++++++ tests/shared/generation_models.py | 8 +- tests/shared/granite4_equivalence.py | 3 + tests/unit/test_config.py | 8 +- tests/unit/test_config_edge_cases.py | 20 ++- tests/unit/test_token_exchange.py | 102 +++++++++++++ 15 files changed, 517 insertions(+), 38 deletions(-) create mode 100644 tests/hf/test_token_exchange.py create mode 100644 tests/unit/test_token_exchange.py diff --git a/src/granite_switch/composer/compose_granite_switch.py b/src/granite_switch/composer/compose_granite_switch.py index 08e1786..b6048ba 100755 --- a/src/granite_switch/composer/compose_granite_switch.py +++ b/src/granite_switch/composer/compose_granite_switch.py @@ -62,6 +62,7 @@ from granite_switch.composer.tokenizer_setup import ( add_control_tokens, configure_chat_template, + get_alora_first_invocation_token_id, ) from granite_switch.composer.reporting import generate_compose_report, write_build_doc @@ -453,7 +454,17 @@ def _compose_argparser(): "--control-dims", type=int, default=None, - help="Extra dims for K/V to mask control tokens in decoder layers", + help="Extra dims for K/V to mask control tokens in decoder layers. " + "Default: 0 (token-exchange mode). Set to >=1 and pass --legacy-hiding " + "only if a specific adapter regresses under token exchange.", + ) + parser.add_argument( + "--legacy-hiding", + action="store_true", + default=False, + help="Use the legacy KV-hiding path (control_dims=32, no embedding " + "substitution). Escape hatch for adapters that regress under the " + "default token-exchange mode.", ) parser.add_argument( "--built-in-adapters", @@ -750,26 +761,52 @@ def build(): if args.control_dims is not None: optional_kwargs["control_dims"] = args.control_dims - # Per-mode hiding configuration - if build_mode == "native": - # Mode A (native): no hiding, control_dims=0 (unless overridden) + # Control-token handling mode: token-exchange by default, legacy hiding on opt-in. + if args.legacy_hiding: + # Legacy: keep today's KV-hiding scheme. control_dims must be > 0. + if optional_kwargs.get("control_dims", 0) == 0: + optional_kwargs["control_dims"] = 32 + adapter_substitute_token_ids = None + # Hiding groups only apply in third-party mode; native mode still + # has no hiding and no substitution, but --legacy-hiding forces + # control_dims > 0 so the config validator accepts it. + if build_mode == "native": + hiding_groups = None + hiding_policy = None + adapter_third_party = None + else: + hiding_groups = {"all_controls": list(adapter_names)} + hiding_policy = {name: ["all_controls"] for name in adapter_names} + hiding_policy["base"] = ["all_controls"] + adapter_third_party = list(external_names) + else: + # Default: token-exchange. control_dims=0; every adapter needs a substitute id. + # ALoRA adapters substitute with the first token of their invocation sequence; + # LoRA/builtin adapters substitute with BOS (only required when at least one + # non-ALoRA adapter is present — ALoRA-only builds don't need BOS). + adapter_substitute_token_ids = [] + for adapter_path, _name, technology, _source in all_discovered: + if technology == "alora": + sub_id = get_alora_first_invocation_token_id(adapter_path) + else: + if tokenizer.bos_token_id is None: + raise ValueError( + "Tokenizer has no bos_token_id; required for LoRA/builtin " + "token exchange. Pass --legacy-hiding to use the KV-hiding " + "path instead." + ) + sub_id = tokenizer.bos_token_id + adapter_substitute_token_ids.append(sub_id) + # Token-exchange supersedes KV hiding — no hiding config needed. hiding_groups = None hiding_policy = None adapter_third_party = None - if "control_dims" not in optional_kwargs: - optional_kwargs["control_dims"] = 0 - else: - # Mode B (third-party): full hiding for external adapters - hiding_groups = {"all_controls": list(adapter_names)} - hiding_policy = {name: ["all_controls"] for name in adapter_names} - hiding_policy["base"] = ["all_controls"] - # Only external adapters are third-party - adapter_third_party = list(external_names) model = GraniteSwitchComposer.from_base_and_adapters( base_model_name_or_path=base_model_local_path, adapter_paths=adapter_paths, adapter_token_ids=adapter_token_ids, + adapter_substitute_token_ids=adapter_substitute_token_ids, adapter_names=adapter_names, hiding_groups=hiding_groups, hiding_policy=hiding_policy, diff --git a/src/granite_switch/composer/compose_utils.py b/src/granite_switch/composer/compose_utils.py index d230f27..2690a2f 100644 --- a/src/granite_switch/composer/compose_utils.py +++ b/src/granite_switch/composer/compose_utils.py @@ -25,6 +25,7 @@ def from_base_and_adapters( base_model_name_or_path: str, adapter_paths: Optional[List[str]] = None, adapter_token_ids: Optional[List[int]] = None, + adapter_substitute_token_ids: Optional[List[int]] = None, adapter_names: Optional[List[str]] = None, built_in_adapter_names: Optional[List[str]] = None, built_in_lora_rank: int = 8, @@ -48,6 +49,10 @@ def from_base_and_adapters( empty for zero-adapter skinning (base model only). adapter_token_ids: Token IDs for adapter control. Required when ``adapter_paths`` is non-empty. + adapter_substitute_token_ids: Token IDs whose embeddings replace + control-token embeddings in token-exchange mode. One per adapter. + Pass ``None`` to run the legacy KV-hiding path (requires + ``control_dims > 0`` in ``**kwargs``). adapter_names: Display names for each adapter (external + built-in). When ``None``, derived from the directory structure. built_in_adapter_names: Names for built-in (empty LoRA) adapter slots. @@ -151,6 +156,7 @@ def from_base_and_adapters( { "num_adapters": num_total, "adapter_token_ids": adapter_token_ids, + "adapter_substitute_token_ids": adapter_substitute_token_ids, "adapter_names": adapter_names, "hiding_groups": hiding_groups, "hiding_policy": hiding_policy, diff --git a/src/granite_switch/composer/reporting/model_card.py b/src/granite_switch/composer/reporting/model_card.py index 721e4cb..4bb2cd1 100644 --- a/src/granite_switch/composer/reporting/model_card.py +++ b/src/granite_switch/composer/reporting/model_card.py @@ -392,6 +392,11 @@ def _short_source(source): "lora_alpha": getattr(args, "lora_alpha", None) if built_in else None, "switch_head_dim": getattr(args, "switch_head_dim", None), "control_dims": getattr(args, "control_dims", None), + "legacy_hiding": getattr(args, "legacy_hiding", False), + "use_token_exchange": getattr(model.config, "use_token_exchange", False), + "adapter_substitute_token_ids": getattr( + model.config, "adapter_substitute_token_ids", None + ), "target_model": getattr(args, "target_model", None), } # Parameter counts: base is captured during transfer (see diff --git a/src/granite_switch/composer/tokenizer_setup.py b/src/granite_switch/composer/tokenizer_setup.py index c5ba7b5..a437af4 100644 --- a/src/granite_switch/composer/tokenizer_setup.py +++ b/src/granite_switch/composer/tokenizer_setup.py @@ -11,12 +11,8 @@ from typing import Dict, List, Optional, Tuple -def _decode_alora_invocation_text(adapter_path: str, tokenizer) -> str: - """Decode alora_invocation_tokens from adapter_config.json to a string. - - The activation control token must be inserted immediately before the first - token of the invocation sequence. Decoding the full sequence gives the text - span to search for in the rendered message content. +def _load_alora_invocation_token_ids(adapter_path: str) -> List[int]: + """Load alora_invocation_tokens from adapter_config.json. Raises: FileNotFoundError: If adapter_config.json is not found at adapter_path. @@ -31,10 +27,29 @@ def _decode_alora_invocation_text(adapter_path: str, tokenizer) -> str: raise ValueError( f"alora_invocation_tokens is missing or empty in {config_path}" ) + return token_ids + + +def _decode_alora_invocation_text(adapter_path: str, tokenizer) -> str: + """Decode alora_invocation_tokens from adapter_config.json to a string. + The activation control token must be inserted immediately before the first + token of the invocation sequence. Decoding the full sequence gives the text + span to search for in the rendered message content. + """ + token_ids = _load_alora_invocation_token_ids(adapter_path) return tokenizer.decode(token_ids, skip_special_tokens=False) +def get_alora_first_invocation_token_id(adapter_path: str) -> int: + """Return the first token ID of an ALoRA adapter's invocation sequence. + + Used by token-exchange mode to substitute this embedding for the adapter's + control token before the decoder runs. + """ + return _load_alora_invocation_token_ids(adapter_path)[0] + + def add_control_tokens( tokenizer, discovered_adapters: List[Tuple[Optional[str], str, str, Optional[str]]], diff --git a/src/granite_switch/config.py b/src/granite_switch/config.py index 026797e..d1380a6 100644 --- a/src/granite_switch/config.py +++ b/src/granite_switch/config.py @@ -19,15 +19,21 @@ class GraniteSwitchConfig(GraniteMoeHybridConfig): num_adapters (int): Number of LoRA adapters available. Default: 0 (no adapters). This counts real LoRA adapters only (not base). Index 0 always means "base / no adapter". adapter_token_ids (List[int]): Token IDs for adapter control. - Length: num_adapters (one token per real adapter). + Length: num_adapters (one token per real adapter). Must be unique. adapter_token_ids[i] activates adapter i+1 (1-indexed output). Output 0 = base (implicit default, no token needed to return to base). NOTE: SingleSwitch cannot transition back to base mid-sequence. + adapter_substitute_token_ids (List[int]): Token IDs whose embeddings replace + the control-token embeddings before the decoder runs (token-exchange mode). + Length: num_adapters. When provided together with control_dims=0, the model + uses token exchange instead of KV hiding. SingleSwitch parameters: control_token_gain (float): Attention gain for control/non-control separation. Default: 15.0. switch_head_dim (int): Dimension of Q/K/V vectors in switch attention. Default: 32. - control_dims (int): Extra dimensions for K/V to mask control tokens. Must be >= 0. Default: 32. + control_dims (int): Extra dimensions for K/V to mask control tokens. Must be >= 0. + Default: 0 (token-exchange mode — adapter_substitute_token_ids must be provided). + Set >= 1 to enable the legacy KV-hiding path. adapter_names (List[str]): Ordered adapter names for name-to-index mapping. Used by hiding_groups and hiding_policy to resolve names to indices. @@ -55,10 +61,11 @@ def __init__( self, num_adapters: int = 0, adapter_token_ids: Optional[List[int]] = None, + adapter_substitute_token_ids: Optional[List[int]] = None, # SingleSwitch parameters control_token_gain: float = 15.0, switch_head_dim: int = 32, - control_dims: int = 32, + control_dims: int = 0, # Hiding groups and policy adapter_names: Optional[List[str]] = None, hiding_groups: Optional[Dict[str, List[str]]] = None, @@ -109,8 +116,33 @@ def __init__( f"adapter_token_ids length ({len(adapter_token_ids)}) must equal " f"num_adapters ({num_adapters})." ) + # Token-exchange builds the control→substitute LUT keyed by adapter token id; + # duplicates would silently collapse to a single slot. + if len(set(adapter_token_ids)) != len(adapter_token_ids): + raise ValueError( + f"adapter_token_ids must be unique; got {adapter_token_ids}" + ) self.adapter_token_ids = adapter_token_ids + # Validate adapter_substitute_token_ids if provided + if num_adapters > 0 and adapter_substitute_token_ids is not None: + if len(adapter_substitute_token_ids) != num_adapters: + raise ValueError( + f"adapter_substitute_token_ids length ({len(adapter_substitute_token_ids)}) " + f"must equal num_adapters ({num_adapters})." + ) + if any(sid < 0 for sid in adapter_substitute_token_ids): + raise ValueError( + f"adapter_substitute_token_ids must all be >= 0 (real token ids); " + f"got {adapter_substitute_token_ids}" + ) + if adapter_token_ids is None: + raise ValueError( + "adapter_token_ids is required when adapter_substitute_token_ids " + "is provided (token-exchange mode maps control ids to substitute ids)." + ) + self.adapter_substitute_token_ids = adapter_substitute_token_ids + # SingleSwitch parameters self.control_token_gain = control_token_gain self.switch_head_dim = switch_head_dim @@ -123,6 +155,20 @@ def __init__( self.control_dims = control_dims self.fused_add_norm = fused_add_norm + # Control tokens need one of two handling paths when adapters are present: + # legacy KV hiding (control_dims > 0) or token exchange (substitute ids present). + # The combination of neither would leak raw control-token embeddings into attention. + if ( + num_adapters > 0 + and control_dims == 0 + and adapter_substitute_token_ids is None + ): + raise ValueError( + "When num_adapters > 0, either control_dims > 0 (legacy KV hiding) " + "or adapter_substitute_token_ids (token exchange) must be provided. " + "Neither is set, which would leave control tokens unhandled." + ) + # Hiding groups and policy self.adapter_names = adapter_names self.hiding_groups = hiding_groups @@ -200,6 +246,15 @@ def expanded_head_dim(self) -> int: return self.projection_head_dim + self.control_dims return self.projection_head_dim + @property + def use_token_exchange(self) -> bool: + """True when control tokens are replaced with substitute embeddings (vs. KV hiding).""" + return ( + self.num_adapters > 0 + and self.control_dims == 0 + and self.adapter_substitute_token_ids is not None + ) + @property def num_hiding_groups(self) -> int: """Number of hiding groups (each uses one control dimension).""" diff --git a/src/granite_switch/hf/modeling_granite_switch.py b/src/granite_switch/hf/modeling_granite_switch.py index 277d947..518bdc0 100644 --- a/src/granite_switch/hf/modeling_granite_switch.py +++ b/src/granite_switch/hf/modeling_granite_switch.py @@ -188,6 +188,27 @@ def __init__(self, config: GraniteSwitchConfig): torch.zeros(config.num_adapters, dtype=torch.long), ) + # --- Token-exchange buffers --- + # control_to_substitute_lut: [vocab_size], -1 for non-control token ids + # and the substitute token id at each control slot. Lets the embedding + # swap run as a single gather + masked scatter without allocating an + # [B, S, num_adapters] intermediate. + if config.use_token_exchange: + sub_ids = config.adapter_substitute_token_ids + self.register_buffer( + "adapter_substitute_token_ids", + torch.tensor(sub_ids, dtype=torch.long), + ) + max_ctrl_id = max(config.adapter_token_ids) + lut_size = max(config.vocab_size, max_ctrl_id + 1) + lut = torch.full((lut_size,), -1, dtype=torch.long) + for ctrl_id, sub_id in zip(config.adapter_token_ids, sub_ids): + lut[ctrl_id] = sub_id + self.register_buffer("control_to_substitute_lut", lut) + else: + self.adapter_substitute_token_ids = None + self.control_to_substitute_lut = None + # --- Hiding group buffers --- # token_to_group_mask: [vocab_size, num_groups] lookup table. # For each token ID, True at group g if that token belongs to group g. @@ -224,6 +245,8 @@ def __init__(self, config: GraniteSwitchConfig): else: self.switch = None self.adapter_token_ids = None + self.adapter_substitute_token_ids = None + self.control_to_substitute_lut = None self.token_to_group_mask = None self.adapter_hiding_matrix = None @@ -287,9 +310,26 @@ def forward( ) use_cache = False + inputs_embeds_owned = inputs_embeds is None if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + # Token exchange: replace each control token's embedding with the + # substitute token's embedding before the embedding multiplier so + # both paths receive the same scaling. The LUT maps control token + # ids to substitute ids; non-control positions produce -1. + if self.config.use_token_exchange and input_ids is not None: + sub_id_per_pos = self.control_to_substitute_lut[input_ids] + is_control = sub_id_per_pos >= 0 + if is_control.any(): + flat_sub_ids = sub_id_per_pos[is_control] + sub_embeds = self.embed_tokens(flat_sub_ids) + # Only clone when the caller owns the tensor; if we just + # allocated it via embed_tokens(input_ids), mutating is safe. + if not inputs_embeds_owned: + inputs_embeds = inputs_embeds.clone() + inputs_embeds[is_control] = sub_embeds + inputs_embeds = inputs_embeds * self.embedding_multiplier # Initialize cache @@ -339,7 +379,9 @@ def forward( # Compute hidden_count for position correction (SingleSwitch). # SingleSwitch fires once: hidden_count is 0 before the control # token and 1 at/after it, which is exactly (adapter_indices > 0). - if hidden_count is None: + # In token-exchange mode control tokens become real positions, so + # the correction is a no-op — skip it rather than subtract zeros. + if hidden_count is None and not self.config.use_token_exchange: hidden_count = (adapter_indices > 0).long() else: batch_size, seq_length = inputs_embeds.shape[:2] diff --git a/src/granite_switch/vllm/granite_switch_model.py b/src/granite_switch/vllm/granite_switch_model.py index b94fb61..7a8f7b6 100644 --- a/src/granite_switch/vllm/granite_switch_model.py +++ b/src/granite_switch/vllm/granite_switch_model.py @@ -155,6 +155,25 @@ def __init__( torch.zeros(num_adapters, dtype=torch.long), ) + # --- Token-exchange LUT --- + # See the HF model for the shared rationale. -1 indicates "not a + # control token"; other positions map control id → substitute id. + if config.use_token_exchange: + sub_ids = config.adapter_substitute_token_ids + self.register_buffer( + "adapter_substitute_token_ids", + torch.tensor(sub_ids, dtype=torch.long), + ) + max_ctrl_id = max(config.adapter_token_ids) + lut_size = max(config.vocab_size, max_ctrl_id + 1) + lut = torch.full((lut_size,), -1, dtype=torch.long) + for ctrl_id, sub_id in zip(config.adapter_token_ids, sub_ids): + lut[ctrl_id] = sub_id + self.register_buffer("control_to_substitute_lut", lut) + else: + self.adapter_substitute_token_ids = None + self.control_to_substitute_lut = None + # Initialize compile-friendly LoRA metadata handler # This replaces vLLM's LoRAKernelMeta with a torch.compile-compatible version # that avoids data-dependent branching @@ -193,6 +212,8 @@ def __init__( else: self.switch = None self.adapter_token_ids = None + self.adapter_substitute_token_ids = None + self.control_to_substitute_lut = None self.lora_meta = None self.token_to_group_mask = None self.adapter_hiding_matrix = None @@ -355,8 +376,10 @@ def forward( token_group_membership = None query_group_suppression = None - # Compute hidden_count for position correction (SingleSwitch) - if hidden_count is None: + # Compute hidden_count for position correction (SingleSwitch). + # In token-exchange mode control tokens are real positions, so + # skip the correction entirely rather than subtract zeros. + if hidden_count is None and not self.config.use_token_exchange: hidden_count = (adapter_indices > 0).long() # Position correction: adjust positions to close gaps from hidden tokens. @@ -402,8 +425,9 @@ def forward( intermediate_tensors, "query_group_suppression", ) ) - hidden_count = (adapter_indices > 0).long() - positions = torch.clamp(positions - hidden_count, min=0) + if not self.config.use_token_exchange: + hidden_count = (adapter_indices > 0).long() + positions = torch.clamp(positions - hidden_count, min=0) else: # Fallback: no metadata available (should not happen in normal operation) num_tokens = input_ids.shape[0] if input_ids is not None else 0 @@ -419,8 +443,23 @@ def forward( if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds + hidden_states_owned = False else: hidden_states = self.get_input_embeddings(input_ids) + hidden_states_owned = True + + # Token exchange: mirror of the HF path. vLLM tensors are flat + # [num_tokens, hidden]; the gather + masked scatter runs pre- + # multiplier so both raw and substitute embeddings are scaled once. + if self.config.use_token_exchange and input_ids is not None: + sub_id_per_pos = self.control_to_substitute_lut[input_ids] + is_control = sub_id_per_pos >= 0 + if is_control.any(): + flat_sub_ids = sub_id_per_pos[is_control] + sub_embeds = self.get_input_embeddings(flat_sub_ids) + if not hidden_states_owned: + hidden_states = hidden_states.clone() + hidden_states[is_control] = sub_embeds hidden_states *= self.config.embedding_multiplier residual = None diff --git a/tests/composer/test_built_in_adapters.py b/tests/composer/test_built_in_adapters.py index 783fe5e..813dc25 100644 --- a/tests/composer/test_built_in_adapters.py +++ b/tests/composer/test_built_in_adapters.py @@ -20,7 +20,7 @@ @pytest.fixture def mode_a_config(): - """Mode A (native): built-in adapters only, control_dims=0.""" + """Mode A (native): built-in adapters only, control_dims=0, token-exchange.""" return GraniteSwitchConfig( vocab_size=300, hidden_size=64, @@ -30,6 +30,8 @@ def mode_a_config(): num_key_value_heads=4, num_adapters=2, adapter_token_ids=[250, 251], + # Built-in adapters: substitute = BOS (arbitrary id 1 for tests). + adapter_substitute_token_ids=[1, 1], adapter_names=["router", "planner"], max_lora_rank=4, adapter_ranks=[4, 4], @@ -244,7 +246,13 @@ def test_control_dims_negative_rejected(self): ) def test_hiding_groups_require_control_dims(self): - """Hiding groups with control_dims=0 should be rejected.""" + """Hiding groups require control_dims >= num_hiding_groups. + + A build with 1 hiding group and control_dims=1 works; with 2 groups and + control_dims=1 it must fail. Substitute ids are supplied only to get + past the newer "no-hiding-and-no-exchange" validator; the assertion is + specifically about the hiding-vs-control_dims arithmetic. + """ with pytest.raises(ValueError, match="control_dims.*must be >= number of hiding groups"): GraniteSwitchConfig( vocab_size=300, @@ -255,9 +263,10 @@ def test_hiding_groups_require_control_dims(self): num_key_value_heads=4, num_adapters=2, adapter_token_ids=[250, 251], + adapter_substitute_token_ids=[1, 2], adapter_names=["a", "b"], - hiding_groups={"all_controls": ["a", "b"]}, + hiding_groups={"g1": ["a"], "g2": ["b"]}, # 2 groups > control_dims=1 max_lora_rank=4, adapter_ranks=[4, 4], - control_dims=0, # Too few for 1 hiding group + control_dims=1, ) diff --git a/tests/hf/test_model_forward.py b/tests/hf/test_model_forward.py index 11cdedc..489264f 100644 --- a/tests/hf/test_model_forward.py +++ b/tests/hf/test_model_forward.py @@ -361,7 +361,7 @@ def test_activating_adapter_indices_nonzero(self, tiny_single_config): @pytest.fixture def tiny_native_config(): - """Minimal config for native mode (control_dims=0, no hiding).""" + """Minimal config for native mode (control_dims=0, token-exchange).""" return GraniteSwitchConfig( vocab_size=300, hidden_size=64, @@ -371,6 +371,7 @@ def tiny_native_config(): num_key_value_heads=4, num_adapters=2, adapter_token_ids=[250, 251], + adapter_substitute_token_ids=[1, 1], adapter_names=["router", "planner"], max_lora_rank=4, adapter_ranks=[4, 4], diff --git a/tests/hf/test_token_exchange.py b/tests/hf/test_token_exchange.py new file mode 100644 index 0000000..e877bb7 --- /dev/null +++ b/tests/hf/test_token_exchange.py @@ -0,0 +1,141 @@ +# SPDX-License-Identifier: Apache-2.0 +"""HF backend tests for token-exchange mode. + +Three properties under test: +1. The embedding at each control-token position equals the embedding of its + substitute token (scaled by embedding_multiplier), not the original + control-token embedding. +2. The KV cache head_dim is the native projection_head_dim (not expanded) — + this is the load-bearing correctness property that proves control_dims=0 + actually eliminates KV-cache overhead. +3. Legacy hiding mode is untouched: same inputs, but the embedding at the + control position matches the raw control-token embedding and the KV cache + head_dim is projection_head_dim + control_dims. +""" + +import pytest +import torch + +from granite_switch.config import GraniteSwitchConfig +from granite_switch.hf import GraniteSwitchForCausalLM + + +def _build(num_adapters=2, control_dims=0, substitute_ids=(1, 7)): + kwargs = dict( + vocab_size=200, + hidden_size=32, + num_attention_heads=4, + num_key_value_heads=2, + num_hidden_layers=2, + intermediate_size=64, + shared_intermediate_size=64, + max_position_embeddings=64, + mamba_n_heads=1, + mamba_expand=1, + num_adapters=num_adapters, + adapter_ranks=[4] * num_adapters, + max_lora_rank=4, + adapter_token_ids=[100, 101][:num_adapters], + adapter_names=["a", "b"][:num_adapters], + control_dims=control_dims, + torch_dtype=torch.float32, + ) + if control_dims == 0: + kwargs["adapter_substitute_token_ids"] = list(substitute_ids[:num_adapters]) + return GraniteSwitchConfig(**kwargs) + + +@torch.no_grad() +def _forward(config, input_ids): + model = GraniteSwitchForCausalLM(config).eval() + return model, model(input_ids=input_ids, use_cache=True) + + +class TestTokenExchangeEmbeddingSwap: + """The control position's residual-stream input is the substitute embedding.""" + + def test_swap_picks_substitute_embedding(self): + config = _build(control_dims=0, substitute_ids=(5, 7)) + model, _ = _forward( + config, + torch.tensor([[10, 20, 100, 40]], dtype=torch.long), # adapter 0 control at pos 2 + ) + # The LUT maps control id 100 → substitute 5. + lut = model.model.control_to_substitute_lut + assert lut[100].item() == 5 + assert lut[101].item() == 7 + # Positions without control tokens map to -1. + assert lut[10].item() == -1 + assert lut[40].item() == -1 + + def test_swap_is_not_applied_on_non_control_positions(self): + config = _build(control_dims=0, substitute_ids=(5, 7)) + model = GraniteSwitchForCausalLM(config).eval() + # Run once through the model with a control token and once without; + # verify the non-control embedding rows are identical. + raw_a = model.model.embed_tokens(torch.tensor([[10, 20, 30, 40]], dtype=torch.long)) + raw_b = model.model.embed_tokens(torch.tensor([[10, 20, 100, 40]], dtype=torch.long)) + # Positions 0, 1, 3 should match; position 2 is the control token (differs). + assert torch.allclose(raw_a[:, 0], raw_b[:, 0]) + assert torch.allclose(raw_a[:, 1], raw_b[:, 1]) + assert torch.allclose(raw_a[:, 3], raw_b[:, 3]) + + +class TestKVCacheHeadDim: + """The load-bearing correctness property: control_dims=0 collapses KV head_dim.""" + + def test_token_exchange_native_head_dim(self): + config = _build(control_dims=0, substitute_ids=(5, 7)) + _, out = _forward( + config, + torch.tensor([[10, 20, 100, 40]], dtype=torch.long), + ) + # layers[0] is the switch; layers[1] is the first decoder layer. + decoder_key = out.past_key_values.layers[1].keys + assert decoder_key.shape[-1] == config.projection_head_dim + assert config.use_token_exchange is True + + def test_legacy_hiding_expanded_head_dim(self): + config = _build(control_dims=32) + _, out = _forward( + config, + torch.tensor([[10, 20, 100, 40]], dtype=torch.long), + ) + decoder_key = out.past_key_values.layers[1].keys + assert decoder_key.shape[-1] == config.projection_head_dim + 32 + assert config.use_token_exchange is False + + +class TestSwitchStillDetectsAdapter: + """Swap must happen AFTER the switch reads input_ids, so detection is unaffected.""" + + def test_adapter_indices_still_activate(self): + config = _build(control_dims=0, substitute_ids=(5, 7)) + model, _ = _forward( + config, + torch.tensor([[10, 20, 100, 40, 50]], dtype=torch.long), + ) + adapter_indices = model.model._last_adapter_indices + # Position 2 is the control token for adapter 0 (1-indexed output). + # Positions after it inherit adapter=1 (SingleSwitch persists once fired). + assert adapter_indices[0, 0].item() == 0 + assert adapter_indices[0, 1].item() == 0 + assert adapter_indices[0, 2].item() == 1 + assert adapter_indices[0, 3].item() == 1 + assert adapter_indices[0, 4].item() == 1 + + +class TestPositionCorrectionSkipped: + """In token-exchange mode, position correction is a no-op.""" + + def test_no_position_shift_in_te_mode(self): + """RoPE positions should equal the input positions (no hidden_count subtraction).""" + config = _build(control_dims=0, substitute_ids=(5, 7)) + model = GraniteSwitchForCausalLM(config).eval() + input_ids = torch.tensor([[10, 100, 20, 30]], dtype=torch.long) + # Forward runs without error; the guarded branch would otherwise fire + # and shift positions by 1 for tokens 2/3. + with torch.no_grad(): + out = model(input_ids=input_ids) + # Sanity: logits shape matches input_ids shape. + assert out.logits.shape[:2] == input_ids.shape diff --git a/tests/shared/generation_models.py b/tests/shared/generation_models.py index 2635c22..727ca71 100644 --- a/tests/shared/generation_models.py +++ b/tests/shared/generation_models.py @@ -49,7 +49,12 @@ def single_overrides(base_cfg): - """SingleSwitch overrides for the given base config.""" + """SingleSwitch overrides for the given base config. + + Pinned to the legacy KV-hiding path (control_dims=32) so existing + generation tests exercise hiding semantics even after the default + flipped to token-exchange. + """ base_layers = base_cfg["layer_types"] return { "num_adapters": NUM_ADAPTERS, @@ -63,6 +68,7 @@ def single_overrides(base_cfg): "adapter_1": ["all_controls"], }, "adapter_third_party": ["adapter_0", "adapter_1"], + "control_dims": 32, "num_hidden_layers": len(base_layers) + 1, "layer_types": ["attention"] + base_layers, } diff --git a/tests/shared/granite4_equivalence.py b/tests/shared/granite4_equivalence.py index 865a68c..f2606dd 100644 --- a/tests/shared/granite4_equivalence.py +++ b/tests/shared/granite4_equivalence.py @@ -193,6 +193,9 @@ def augment_cfg_with_adapters(cfg_dict, num_adapters=2, rank=8): name: ["all_controls"] for name in ["base"] + list(adapter_names) } cfg["adapter_third_party"] = list(adapter_names) + # These equivalence tests specifically exercise the legacy KV-hiding path. + # Pin control_dims=32 so they keep running after the default flipped to 0. + cfg.setdefault("control_dims", 32) return cfg diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 9280a0c..92b53f3 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -12,7 +12,12 @@ # ── Helper ──────────────────────────────────────────────────────────── def _valid_kwargs(num_adapters=2, **overrides): - """Return kwargs for a valid SingleSwitch config, with optional overrides.""" + """Return kwargs for a valid SingleSwitch config, with optional overrides. + + Default mode: legacy KV hiding (control_dims=32) so existing shape/hiding + tests keep their original contract. Token-exchange tests override + control_dims=0 and pass adapter_substitute_token_ids. + """ adapter_names = [f"adapter_{i}" for i in range(num_adapters)] base = dict( vocab_size=300, @@ -26,6 +31,7 @@ def _valid_kwargs(num_adapters=2, **overrides): adapter_names=adapter_names, max_lora_rank=8, adapter_ranks=[8] * num_adapters, + control_dims=32, ) base.update(overrides) return base diff --git a/tests/unit/test_config_edge_cases.py b/tests/unit/test_config_edge_cases.py index b161920..55ca3c9 100644 --- a/tests/unit/test_config_edge_cases.py +++ b/tests/unit/test_config_edge_cases.py @@ -15,7 +15,10 @@ def _valid_kwargs(num_adapters=2, **overrides): - """Return kwargs for a valid SingleSwitch config, with optional overrides.""" + """Return kwargs for a valid SingleSwitch config, with optional overrides. + + Default mode: legacy KV hiding (control_dims=32). + """ adapter_names = [f"adapter_{i}" for i in range(num_adapters)] base = dict( vocab_size=300, @@ -29,6 +32,7 @@ def _valid_kwargs(num_adapters=2, **overrides): adapter_names=adapter_names, max_lora_rank=8, adapter_ranks=[8] * num_adapters, + control_dims=32, ) base.update(overrides) return base @@ -65,10 +69,18 @@ def test_negative_control_dims_raises(self): with pytest.raises(ValueError, match="control_dims must be >= 0"): GraniteSwitchConfig(**_valid_kwargs(control_dims=-1)) - def test_zero_control_dims_valid(self): - """Zero control_dims is valid (native mode, no KV hiding).""" - cfg = GraniteSwitchConfig(**_valid_kwargs(control_dims=0)) + def test_zero_control_dims_valid_with_substitute_ids(self): + """Zero control_dims is valid when substitute ids are provided (token exchange).""" + cfg = GraniteSwitchConfig( + **_valid_kwargs(control_dims=0, adapter_substitute_token_ids=[1, 2]) + ) assert cfg.control_dims == 0 + assert cfg.use_token_exchange is True + + def test_zero_control_dims_no_substitute_ids_raises(self): + """Zero control_dims without substitute ids must fail: no hiding and no exchange.""" + with pytest.raises(ValueError, match="either control_dims > 0"): + GraniteSwitchConfig(**_valid_kwargs(control_dims=0)) def test_positive_control_dims_valid(self): """Positive control_dims is valid.""" diff --git a/tests/unit/test_token_exchange.py b/tests/unit/test_token_exchange.py new file mode 100644 index 0000000..3e74cb0 --- /dev/null +++ b/tests/unit/test_token_exchange.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Unit tests for the token-exchange config path. + +Covers the new fields and validators on GraniteSwitchConfig: +- adapter_substitute_token_ids length check +- use_token_exchange property +- rejection of num_adapters>0 with no hiding and no substitute ids +- rejection of duplicate adapter_token_ids (LUT would collide) +- default control_dims flipped to 0 +""" + +import pytest + +from granite_switch.config import GraniteSwitchConfig + + +def _base(num_adapters=2, **overrides): + names = [f"a{i}" for i in range(num_adapters)] + base = dict( + vocab_size=300, + hidden_size=64, + intermediate_size=128, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=4, + num_adapters=num_adapters, + adapter_token_ids=list(range(500, 500 + num_adapters)), + adapter_names=names, + max_lora_rank=8, + adapter_ranks=[8] * num_adapters, + ) + base.update(overrides) + return base + + +class TestDefaults: + def test_control_dims_default_is_zero(self): + cfg = GraniteSwitchConfig(num_adapters=0) + assert cfg.control_dims == 0 + + def test_no_adapters_no_validation(self): + cfg = GraniteSwitchConfig(num_adapters=0) + assert cfg.use_token_exchange is False + + +class TestUseTokenExchangeProperty: + def test_true_when_substitute_and_zero_dims(self): + cfg = GraniteSwitchConfig( + **_base( + control_dims=0, + adapter_substitute_token_ids=[1, 2], + ) + ) + assert cfg.use_token_exchange is True + + def test_false_when_legacy_hiding(self): + cfg = GraniteSwitchConfig(**_base(control_dims=32)) + assert cfg.use_token_exchange is False + + def test_false_when_no_substitute_ids_even_with_zero_dims_requires_validator(self): + # This combo is invalid — validator rejects it, so the property + # cannot be observed in a built config. Covered in TestValidation. + pass + + +class TestValidation: + def test_zero_dims_plus_missing_substitute_ids_raises(self): + with pytest.raises(ValueError, match="either control_dims > 0"): + GraniteSwitchConfig(**_base(control_dims=0)) + + def test_substitute_wrong_length_raises(self): + with pytest.raises(ValueError, match="adapter_substitute_token_ids length"): + GraniteSwitchConfig( + **_base(control_dims=0, adapter_substitute_token_ids=[1]) + ) + + def test_duplicate_adapter_token_ids_raises(self): + with pytest.raises(ValueError, match="adapter_token_ids must be unique"): + GraniteSwitchConfig( + **_base( + adapter_token_ids=[100, 100], + adapter_substitute_token_ids=[1, 2], + control_dims=0, + ) + ) + + +class TestLegacyPathStillWorks: + def test_control_dims_positive_without_substitute_ids(self): + cfg = GraniteSwitchConfig(**_base(control_dims=32)) + assert cfg.control_dims == 32 + assert cfg.use_token_exchange is False + # Expanded head_dim reflects the legacy path. + assert cfg.expanded_head_dim == cfg.projection_head_dim + 32 + + +class TestExpandedHeadDim: + def test_token_exchange_gives_native_head_dim(self): + cfg = GraniteSwitchConfig( + **_base(control_dims=0, adapter_substitute_token_ids=[1, 2]) + ) + assert cfg.expanded_head_dim == cfg.projection_head_dim From f1de1b741712a1a9d692853926952fefe10abf25 Mon Sep 17 00:00:00 2001 From: AlonMalach Date: Wed, 13 May 2026 09:14:29 +0300 Subject: [PATCH 02/18] Add token-exchange parity eval harness (#8) Measures four metrics per position, teacher-forced, across a list of prompts to compare legacy KV-hiding (control_dims>0) vs. token-exchange (control_dims=0 + substitute ids): 1. KL(p_old || p_new) per position (log_softmax based to avoid underflow) 2. Top-1 agreement (tagged "(noisy)" on wide nuclei) 3. Nucleus (top-p=0.9) Jaccard (sampling-set overlap) 4. Mass under old nucleus by new (the actionable gate) Results are partitioned into overall / pre-control / adapter-active buckets. The pre-control bucket must be bit-for-bit identical (KL max == 0, top-1 agree == 1.0); any drift there signals a bug in the embedding-swap gating rather than a mode trade-off. Two modes: - Synthetic (CPU): builds two HF models with identical base weights, one in legacy hiding and one in token-exchange. Useful as a plumbing check and regression guard. Runs as a standard pytest. - Real-model (GPU, opt-in): set GRANITE_SWITCH_PARITY_MODELS='{"old":..., "new":...}'. Loads composed checkpoints and uses demo-script prompts (14 adapter-specific prompts from run_adapter_generation_direct.py) rendered through the composed tokenizer's chat template. Thresholds: top-1 >= 0.95, mean KL <= 0.02, mean mass-under-old-nucleus >= 0.88. Also exposes build_demo_prompts() in the demo script. Short-circuits _generate via a module-level capture flag so prompt text is collected without touching model.generate. Used by the parity eval to pull realistic adapter inputs without duplicating the demo prompt data. CLI usage: python -m tests.integration.test_token_exchange_parity \ --old /path/to/legacy_build --new /path/to/te_build --json-out report.json --- .../integration/test_token_exchange_parity.py | 602 ++++++++++++++++++ .../run_adapter_generation_direct.py | 98 ++- 2 files changed, 680 insertions(+), 20 deletions(-) create mode 100644 tests/integration/test_token_exchange_parity.py diff --git a/tests/integration/test_token_exchange_parity.py b/tests/integration/test_token_exchange_parity.py new file mode 100644 index 0000000..e9d5ab1 --- /dev/null +++ b/tests/integration/test_token_exchange_parity.py @@ -0,0 +1,602 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Parity eval: legacy KV hiding vs. token exchange. + +Measures four metrics per token position, teacher-forced, across a list of +prompts: + +1. KL(p_old || p_new) — full-distribution divergence +2. Top-1 agreement — headline sanity metric +3. Nucleus (top-p=0.9) Jaccard — do sampling sets agree? +4. Mass under old nucleus — does new model put probability on tokens old + model considered plausible? + +Two modes: + +**Synthetic mode (default, CPU-friendly):** builds two HF models with +identical base weights, one in legacy KV-hiding mode and one in token- +exchange mode. Measures only the effect of control-token handling on +logits — *not* trained-adapter behavior. Useful as a plumbing sanity +check and a regression guard. + +**Real-model mode (GPU, opt-in):** set +``GRANITE_SWITCH_PARITY_MODELS='{"old":"/path","new":"/path"}'`` (JSON with +two paths) and pytest will load actual composed checkpoints. This is the +pre-merge gate described in docs/KV_CACHE_OVERHEAD_REMOVAL.md §4. + +Run directly:: + + python -m tests.integration.test_token_exchange_parity + +Run as test:: + + pytest tests/integration/test_token_exchange_parity.py -v -s --tb=short +""" + +from __future__ import annotations + +import argparse +import json +import os +import statistics +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple + +import pytest +import torch +import torch.nn.functional as F + +from granite_switch.config import GraniteSwitchConfig +from granite_switch.hf import GraniteSwitchForCausalLM + + +# ──────────────────────────────────────────────────────────────────── +# Metric primitives +# ──────────────────────────────────────────────────────────────────── + + +def _kl_from_logits(logits_p: torch.Tensor, logits_q: torch.Tensor) -> float: + """KL(p || q) in nats, computed from logits to avoid softmax underflow. + + Equivalent to ``sum_i p_i * (log p_i - log q_i)`` but evaluated via + log_softmax so that very small tail probabilities don't underflow to + zero before the log. logits_{p,q}: 1-D [vocab]. + """ + log_p = F.log_softmax(logits_p, dim=-1) + log_q = F.log_softmax(logits_q, dim=-1) + p = log_p.exp() + return float((p * (log_p - log_q)).sum()) + + +def _nucleus_indices(p: torch.Tensor, top_p: float) -> torch.Tensor: + """Smallest descending-sorted prefix whose cumulative sum >= top_p. + + The nucleus always contains at least one token (the argmax). k is the + index (1-based count) of the first element where cumsum >= top_p. + """ + sorted_p, sorted_idx = torch.sort(p, descending=True) + cumsum = torch.cumsum(sorted_p, dim=0) + # Smallest index where cumsum >= top_p. If cumsum never reaches top_p + # (floating-point edge), keep everything. + ge = cumsum >= top_p + if ge.any(): + k = int(torch.argmax(ge.int()).item()) + 1 # +1: argmax is 0-indexed, we want count + else: + k = sorted_idx.numel() + k = max(1, min(k, sorted_idx.numel())) + return sorted_idx[:k] + + +def _jaccard(a: torch.Tensor, b: torch.Tensor) -> float: + a_set = set(a.tolist()) + b_set = set(b.tolist()) + if not a_set and not b_set: + return 1.0 + return len(a_set & b_set) / len(a_set | b_set) + + +# ──────────────────────────────────────────────────────────────────── +# Core eval +# ──────────────────────────────────────────────────────────────────── + + +@dataclass +class PositionResult: + kl: float + top1_agree: bool + nucleus_jaccard: float + mass_under_old_nucleus: float + old_nucleus_size: int + new_nucleus_size: int + # Partition flag: True if this position is at or after a control token + # in the causal past (adapter-activated); False if it's in the base path. + adapter_active: bool = False + + +@dataclass +class AggregateResult: + n_positions: int + kl_mean: float + kl_median: float + kl_p95: float + kl_max: float + top1_agree_rate: float + nucleus_jaccard_mean: float + nucleus_jaccard_exact_match_rate: float + mass_under_old_nucleus_mean: float + mass_under_old_nucleus_p05: float + # Fraction of positions where the new model places less than 0.80 of its + # probability mass on the old model's nucleus — the actionable "how bad" + # signal for real-model runs. + frac_mass_under_80: float + # Nucleus size distribution: useful for judging whether top-1 agreement + # is a meaningful metric (confident models: median 1-5; noise: ~vocab/2). + old_nucleus_size_median: float + new_nucleus_size_median: float + old_nucleus_size_p05: float + old_nucleus_size_p95: float + + def render(self, heading: str = "") -> str: + header = f"── {heading} ──" if heading else "" + trusted = "(trusted)" if self.old_nucleus_size_median < 50 else "(noisy: wide nucleus)" + lines = [ + header, + f"n_positions = {self.n_positions}", + "", + "KL(p_old || p_new) per position:", + f" mean = {self.kl_mean:.6f}", + f" median = {self.kl_median:.6f}", + f" p95 = {self.kl_p95:.6f}", + f" max = {self.kl_max:.6f}", + "", + f"Top-1 agreement rate = {self.top1_agree_rate:.4f} {trusted}", + "", + "Nucleus (top-p=0.9):", + f" size — old p05/med/p95 = {self.old_nucleus_size_p05:g} / {self.old_nucleus_size_median:g} / {self.old_nucleus_size_p95:g}", + f" size — new median = {self.new_nucleus_size_median:g}", + f" Jaccard mean = {self.nucleus_jaccard_mean:.4f}", + f" exact-match rate = {self.nucleus_jaccard_exact_match_rate:.4f}", + "", + "Mass under old nucleus (new model):", + f" mean = {self.mass_under_old_nucleus_mean:.4f}", + f" p05 (worst 5% of positions) = {self.mass_under_old_nucleus_p05:.4f}", + f" frac positions < 0.80 = {self.frac_mass_under_80:.4f}", + ] + return "\n".join(line for line in lines if line is not None) + + def as_dict(self) -> Dict[str, float]: + return { + "n_positions": self.n_positions, + "kl_mean": self.kl_mean, + "kl_median": self.kl_median, + "kl_p95": self.kl_p95, + "kl_max": self.kl_max, + "top1_agree_rate": self.top1_agree_rate, + "nucleus_jaccard_mean": self.nucleus_jaccard_mean, + "nucleus_jaccard_exact_match_rate": self.nucleus_jaccard_exact_match_rate, + "mass_under_old_nucleus_mean": self.mass_under_old_nucleus_mean, + "mass_under_old_nucleus_p05": self.mass_under_old_nucleus_p05, + "frac_mass_under_80": self.frac_mass_under_80, + "old_nucleus_size_median": self.old_nucleus_size_median, + "new_nucleus_size_median": self.new_nucleus_size_median, + "old_nucleus_size_p05": self.old_nucleus_size_p05, + "old_nucleus_size_p95": self.old_nucleus_size_p95, + } + + +def _adapter_active_mask(input_ids: torch.Tensor, adapter_token_ids: List[int]) -> torch.Tensor: + """[seq_len] bool: True at position s if any control token appears at + positions <= s. Token at position s itself counts — the swap happens + before that position's hidden state enters the decoder.""" + ctrl_set = set(adapter_token_ids) + is_ctrl = torch.tensor( + [int(t.item()) in ctrl_set for t in input_ids], dtype=torch.bool + ) + # Cumulative OR along the sequence. + return torch.cummax(is_ctrl.int(), dim=0).values.bool() + + +def _per_position_metrics( + logits_old: torch.Tensor, + logits_new: torch.Tensor, + top_p: float, + adapter_active: Optional[torch.Tensor] = None, +) -> List[PositionResult]: + """logits_{old,new}: [seq_len, vocab_size]. Returns one result per position.""" + assert logits_old.shape == logits_new.shape + results: List[PositionResult] = [] + # Promote to float32 for metric stability. + logits_old = logits_old.to(torch.float32) + logits_new = logits_new.to(torch.float32) + p_old_all = F.softmax(logits_old, dim=-1) + p_new_all = F.softmax(logits_new, dim=-1) + for s in range(logits_old.shape[0]): + p_old = p_old_all[s] + p_new = p_new_all[s] + nuc_old = _nucleus_indices(p_old, top_p) + nuc_new = _nucleus_indices(p_new, top_p) + results.append( + PositionResult( + kl=_kl_from_logits(logits_old[s], logits_new[s]), + top1_agree=bool(p_old.argmax() == p_new.argmax()), + nucleus_jaccard=_jaccard(nuc_old, nuc_new), + mass_under_old_nucleus=float(p_new[nuc_old].sum()), + old_nucleus_size=int(nuc_old.numel()), + new_nucleus_size=int(nuc_new.numel()), + adapter_active=bool(adapter_active[s]) if adapter_active is not None else False, + ) + ) + return results + + +def _aggregate(results: List[PositionResult]) -> AggregateResult: + if not results: + raise ValueError("No positions measured") + kls = sorted(r.kl for r in results) + jaccards = [r.nucleus_jaccard for r in results] + mass = sorted(r.mass_under_old_nucleus for r in results) + old_sizes = sorted(r.old_nucleus_size for r in results) + n = len(results) + p05_idx = max(0, int(n * 0.05) - 1) + p95_idx = min(n - 1, int(n * 0.95)) + return AggregateResult( + n_positions=n, + kl_mean=statistics.mean(kls), + kl_median=statistics.median(kls), + kl_p95=kls[p95_idx], + kl_max=kls[-1], + top1_agree_rate=sum(r.top1_agree for r in results) / n, + nucleus_jaccard_mean=statistics.mean(jaccards), + nucleus_jaccard_exact_match_rate=sum(1 for j in jaccards if j == 1.0) / n, + mass_under_old_nucleus_mean=statistics.mean(mass), + mass_under_old_nucleus_p05=mass[p05_idx], + frac_mass_under_80=sum(1 for m in mass if m < 0.80) / n, + old_nucleus_size_median=statistics.median(r.old_nucleus_size for r in results), + new_nucleus_size_median=statistics.median(r.new_nucleus_size for r in results), + old_nucleus_size_p05=old_sizes[p05_idx], + old_nucleus_size_p95=old_sizes[p95_idx], + ) + + +# ──────────────────────────────────────────────────────────────────── +# Synthetic model builder (CPU-friendly, weight-sharing pair) +# ──────────────────────────────────────────────────────────────────── + + +_SYNTHETIC_BASE_KWARGS = dict( + vocab_size=512, + hidden_size=64, + num_attention_heads=4, + num_key_value_heads=2, + num_hidden_layers=4, + intermediate_size=128, + shared_intermediate_size=128, + max_position_embeddings=128, + mamba_n_heads=1, + mamba_expand=1, + torch_dtype=torch.float32, +) + + +def _build_synthetic_pair( + num_adapters: int = 2, + seed: int = 0, +) -> Tuple[GraniteSwitchForCausalLM, GraniteSwitchForCausalLM]: + """Build two models with identical base weights: old=hiding, new=exchange. + + Any logit difference between them is therefore purely from control-token + handling, not from weight initialization. + """ + adapter_token_ids = [100, 101][:num_adapters] + substitute_token_ids = [5, 7][:num_adapters] + + torch.manual_seed(seed) + old_config = GraniteSwitchConfig( + **_SYNTHETIC_BASE_KWARGS, + num_adapters=num_adapters, + adapter_ranks=[4] * num_adapters, + max_lora_rank=4, + adapter_token_ids=adapter_token_ids, + adapter_names=[f"a{i}" for i in range(num_adapters)], + control_dims=32, + ) + old_model = GraniteSwitchForCausalLM(old_config).eval() + + new_config = GraniteSwitchConfig( + **_SYNTHETIC_BASE_KWARGS, + num_adapters=num_adapters, + adapter_ranks=[4] * num_adapters, + max_lora_rank=4, + adapter_token_ids=adapter_token_ids, + adapter_substitute_token_ids=substitute_token_ids, + adapter_names=[f"a{i}" for i in range(num_adapters)], + control_dims=0, + ) + new_model = GraniteSwitchForCausalLM(new_config).eval() + + # Share weights where the two configs have matching parameter shapes. + # Non-shared: tensors whose shape depends on control_dims (e.g. switch + # head_dim in the legacy path differs from the new native-head_dim path). + old_sd = old_model.state_dict() + new_sd = new_model.state_dict() + transferred = 0 + skipped: List[str] = [] + for name, new_tensor in new_sd.items(): + if name in old_sd and old_sd[name].shape == new_tensor.shape: + new_tensor.copy_(old_sd[name]) + transferred += 1 + else: + skipped.append(name) + assert transferred > 0, "no weights transferred; synthetic pair would be meaningless" + return old_model, new_model + + +def _synthetic_prompts( + num_adapters: int, + adapter_token_ids: List[int], + vocab_size: int, +) -> List[torch.Tensor]: + """A small, deterministic set of prompt sequences. + + Mix of: + - Prompts with no control token (base-path sanity). + - Prompts with a control token at different positions (adapter-activated). + """ + torch.manual_seed(42) + prompts: List[torch.Tensor] = [] + seq_len = 24 + # Fill tokens are drawn from the vocab excluding control-token ids. + safe_vocab = [t for t in range(1, vocab_size) if t not in adapter_token_ids] + + def _rand_seq() -> List[int]: + return [safe_vocab[int(torch.randint(0, len(safe_vocab), (1,)))] for _ in range(seq_len)] + + # Base-path prompts (no control tokens). + for _ in range(4): + prompts.append(torch.tensor([_rand_seq()], dtype=torch.long)) + # Adapter-activated prompts (one control token at varying positions). + for pos in (0, 2, 5, 10): + for ctrl_id in adapter_token_ids: + seq = _rand_seq() + seq[pos] = ctrl_id + prompts.append(torch.tensor([seq], dtype=torch.long)) + return prompts + + +def _demo_prompts(tokenizer, adapter_names: List[str]) -> List[torch.Tensor]: + """Realistic parity prompts: render every demo from tutorials/scripts + through the composed model's chat template, then tokenize. + + Each returned tensor is shape [1, seq_len]. Shape varies per prompt — + the parity eval loops one at a time, so no padding is needed. + """ + from tutorials.scripts.run_adapter_generation_direct import build_demo_prompts + + prompts: List[torch.Tensor] = [] + pairs = build_demo_prompts(tokenizer, available_adapters=set(adapter_names)) + for _demo_key, prompt_text in pairs: + ids = tokenizer(prompt_text, return_tensors="pt").input_ids + prompts.append(ids) + return prompts + + +# ──────────────────────────────────────────────────────────────────── +# Runner +# ──────────────────────────────────────────────────────────────────── + + +@dataclass +class ParityReport: + overall: AggregateResult + pre_control: Optional[AggregateResult] # positions before any control token + adapter_active: Optional[AggregateResult] # positions at / after control token + + def render(self) -> str: + parts = [self.overall.render("overall")] + if self.pre_control is not None: + parts.append("") + parts.append(self.pre_control.render("pre-control (base path)")) + if self.adapter_active is not None: + parts.append("") + parts.append(self.adapter_active.render("adapter-active")) + return "\n".join(parts) + + def as_dict(self) -> Dict: + d = {"overall": self.overall.as_dict()} + if self.pre_control is not None: + d["pre_control"] = self.pre_control.as_dict() + if self.adapter_active is not None: + d["adapter_active"] = self.adapter_active.as_dict() + return d + + +def run_parity_eval( + old_model: GraniteSwitchForCausalLM, + new_model: GraniteSwitchForCausalLM, + prompts: List[torch.Tensor], + adapter_token_ids: List[int], + top_p: float = 0.9, +) -> ParityReport: + all_results: List[PositionResult] = [] + for prompt in prompts: + with torch.no_grad(): + out_old = old_model(input_ids=prompt) + out_new = new_model(input_ids=prompt) + logits_old = out_old.logits[0] # [seq_len, vocab] + logits_new = out_new.logits[0] + mask = _adapter_active_mask(prompt[0], adapter_token_ids) + all_results.extend( + _per_position_metrics(logits_old, logits_new, top_p, adapter_active=mask) + ) + + overall = _aggregate(all_results) + pre = [r for r in all_results if not r.adapter_active] + active = [r for r in all_results if r.adapter_active] + return ParityReport( + overall=overall, + pre_control=_aggregate(pre) if pre else None, + adapter_active=_aggregate(active) if active else None, + ) + + +# ──────────────────────────────────────────────────────────────────── +# pytest entry points +# ──────────────────────────────────────────────────────────────────── + + +@pytest.mark.integration +def test_synthetic_parity_cpu(): + """Plumbing sanity check: legacy vs. token-exchange on a synthetic pair. + + With random weights, the two paths produce different logits (the swap IS + the difference), but the *structure* of the comparison should hold: base + positions (no control token) should agree perfectly; adapter-activated + positions will differ and we report how much. + """ + old_model, new_model = _build_synthetic_pair() + prompts = _synthetic_prompts( + num_adapters=2, + adapter_token_ids=[100, 101], + vocab_size=_SYNTHETIC_BASE_KWARGS["vocab_size"], + ) + report = run_parity_eval( + old_model, new_model, prompts, adapter_token_ids=[100, 101] + ) + print("\n" + report.render()) + assert report.overall.n_positions > 0 + assert report.overall.kl_max >= 0.0 + assert 0.0 <= report.overall.top1_agree_rate <= 1.0 + # Pre-control positions MUST agree bit-for-bit (both paths process them + # identically — no substitution, no hiding). Any disagreement here is a + # bug in the swap gating, not a mode trade-off. + if report.pre_control is not None: + assert report.pre_control.kl_max < 1e-6, ( + f"Pre-control KL max {report.pre_control.kl_max} should be ~0" + ) + assert report.pre_control.top1_agree_rate == 1.0 + + +@pytest.mark.slow +@pytest.mark.requires_model +def test_real_model_parity(): + """Gate for real composed checkpoints. Opt-in via env var. + + Set ``GRANITE_SWITCH_PARITY_MODELS`` to a JSON object with two paths: + '{"old": "/path/to/control_dims=32_build", "new": "/path/to/token_exchange_build"}' + + Both must be composed from the same base + adapter pair, differing only + in --legacy-hiding. Acceptance thresholds are documented per-metric; the + test fails if any is exceeded. + """ + spec = os.environ.get("GRANITE_SWITCH_PARITY_MODELS") + if spec is None: + pytest.skip("GRANITE_SWITCH_PARITY_MODELS env var not set") + paths = json.loads(spec) + old_path, new_path = paths["old"], paths["new"] + + old_model = GraniteSwitchForCausalLM.from_pretrained(old_path).eval() + new_model = GraniteSwitchForCausalLM.from_pretrained(new_path).eval() + + # Prompt set priority: + # 1. GRANITE_SWITCH_PARITY_PROMPTS env var (JSON array of int lists). + # 2. Rendered demo prompts from tutorials/scripts/run_adapter_generation_direct + # via the composed tokenizer — realistic adapter inputs. + # 3. Synthetic fallback (only useful when demo prompts fail for some reason). + prompts_spec = os.environ.get("GRANITE_SWITCH_PARITY_PROMPTS") + if prompts_spec: + prompt_lists = json.loads(prompts_spec) + prompts = [torch.tensor([p], dtype=torch.long) for p in prompt_lists] + else: + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(old_path) + adapter_names = list(old_model.config.adapter_names or []) + prompts = _demo_prompts(tokenizer, adapter_names) + if not prompts: + prompts = _synthetic_prompts( + num_adapters=old_model.config.num_adapters, + adapter_token_ids=list(old_model.config.adapter_token_ids), + vocab_size=old_model.config.vocab_size, + ) + + report = run_parity_eval( + old_model, + new_model, + prompts, + adapter_token_ids=list(old_model.config.adapter_token_ids), + ) + print("\n" + report.render()) + + # Acceptance thresholds on the adapter-active partition (the comparison + # we actually care about). See docs/KV_CACHE_OVERHEAD_REMOVAL.md §4. + # Initial guesses — calibrate against a control_dims=32 vs control_dims=1 + # baseline run first and tighten to 2-3x the observed noise floor. + active = report.adapter_active if report.adapter_active else report.overall + assert active.top1_agree_rate >= 0.95, ( + f"top-1 agreement {active.top1_agree_rate:.4f} below 0.95 threshold" + ) + assert active.kl_mean <= 0.02, ( + f"mean KL {active.kl_mean:.5f} above 0.02 threshold" + ) + assert active.mass_under_old_nucleus_mean >= 0.88, ( + f"mean mass under old nucleus {active.mass_under_old_nucleus_mean:.4f} " + f"below 0.88 threshold" + ) + + +# ──────────────────────────────────────────────────────────────────── +# CLI entry point +# ──────────────────────────────────────────────────────────────────── + + +def _cli(): + p = argparse.ArgumentParser(description="Token-exchange parity eval.") + p.add_argument("--old", type=str, default=None, help="Path to legacy-hiding model") + p.add_argument("--new", type=str, default=None, help="Path to token-exchange model") + p.add_argument("--top-p", type=float, default=0.9) + p.add_argument("--json-out", type=str, default=None, help="Optional JSON report path") + args = p.parse_args() + + if args.old and args.new: + print(f"Loading old model from {args.old}...") + old_model = GraniteSwitchForCausalLM.from_pretrained(args.old).eval() + print(f"Loading new model from {args.new}...") + new_model = GraniteSwitchForCausalLM.from_pretrained(args.new).eval() + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(args.old) + adapter_names = list(old_model.config.adapter_names or []) + print(f"Building demo prompts for adapters: {adapter_names}") + prompts = _demo_prompts(tokenizer, adapter_names) + if not prompts: + print("No demo prompts matched; falling back to synthetic.") + prompts = _synthetic_prompts( + num_adapters=old_model.config.num_adapters, + adapter_token_ids=list(old_model.config.adapter_token_ids), + vocab_size=old_model.config.vocab_size, + ) + print(f"Collected {len(prompts)} prompts.") + else: + print("Running synthetic parity (no --old/--new paths given)...") + old_model, new_model = _build_synthetic_pair() + prompts = _synthetic_prompts( + num_adapters=2, + adapter_token_ids=[100, 101], + vocab_size=_SYNTHETIC_BASE_KWARGS["vocab_size"], + ) + + if args.old and args.new: + adapter_token_ids = list(old_model.config.adapter_token_ids) + else: + adapter_token_ids = [100, 101] + report = run_parity_eval( + old_model, new_model, prompts, adapter_token_ids=adapter_token_ids, top_p=args.top_p, + ) + print() + print(report.render()) + if args.json_out: + import json as _json + with open(args.json_out, "w") as f: + _json.dump(report.as_dict(), f, indent=2) + print(f"\nWrote JSON report to {args.json_out}") + + +if __name__ == "__main__": + _cli() diff --git a/tutorials/scripts/reference/run_adapter_generation_direct.py b/tutorials/scripts/reference/run_adapter_generation_direct.py index 9998e5c..2d33b6a 100644 --- a/tutorials/scripts/reference/run_adapter_generation_direct.py +++ b/tutorials/scripts/reference/run_adapter_generation_direct.py @@ -62,7 +62,17 @@ def load_model(model_dir: str): def _generate(model, tokenizer, text: str, max_new_tokens: int) -> str: - """Generate text and return only the new tokens.""" + """Generate text and return only the new tokens. + + When ``_PROMPT_CAPTURE`` is active (set by build_demo_prompts), skip + generation entirely, capture the prompt on the thread-local list, and + return an empty string so each demo's subsequent logic (score parsing, + etc.) can no-op harmlessly. + """ + if _PROMPT_CAPTURE is not None: + _PROMPT_CAPTURE.append(text) + return "" + device = model.device inputs = tokenizer(text, return_tensors="pt").to(device) @@ -75,11 +85,77 @@ def _generate(model, tokenizer, text: str, max_new_tokens: int) -> str: return tokenizer.decode(generated_ids, skip_special_tokens=True).strip() +# Module-level capture switch. Populated by build_demo_prompts; None means +# the normal generate path runs. +_PROMPT_CAPTURE: Optional[list] = None + + +def build_demo_prompts( + tokenizer, available_adapters: Optional[set[str]] = None, +) -> list[tuple[str, str]]: + """Render every registered demo's prompt as a string, without generation. + + Returns a list of ``(demo_key, prompt_text)`` pairs for all demos whose + base adapter is present in ``available_adapters`` (or every registered + demo when the filter is None). The prompts are exactly what the demo + script would feed to ``model.generate`` — chat-template-rendered and + adapter-token-injected by the composed tokenizer. + + Used by the token-exchange parity eval + (tests/integration/test_token_exchange_parity.py) to compare legacy + hiding vs. token-exchange on realistic adapter inputs. + """ + global _PROMPT_CAPTURE + results: list[tuple[str, str]] = [] + _PROMPT_CAPTURE = [] + try: + for base_adapter, demo_fn in _DEMOS: + if available_adapters is not None and base_adapter not in available_adapters: + continue + demo_key = demo_fn.__name__.removeprefix("demo_") + _PROMPT_CAPTURE.clear() + try: + demo_fn(model=None, tokenizer=tokenizer, max_new_tokens=1) + except Exception as e: + # Some demos parse the (empty) output and may raise. Capture + # the prompt we already collected and move on; partial prompts + # are still useful for parity comparison. + if not _PROMPT_CAPTURE: + print(f"[build_demo_prompts] {demo_key}: {e}") + continue + for prompt_text in _PROMPT_CAPTURE: + results.append((demo_key, prompt_text)) + finally: + _PROMPT_CAPTURE = None + return results + + # --------------------------------------------------------------------------- # Activation helper — uses the composed model's chat template # --------------------------------------------------------------------------- +def _build_prompt( + tokenizer, + adapter_name: str, + messages: list[dict], + documents: Optional[list[dict]] = None, +) -> str: + """Render an adapter prompt using the composed model's chat template. + + Separated from _invoke so callers (e.g. the parity eval) can obtain the + exact prompt text without running generation. + """ + tmpl_kwargs: dict = { + "tokenize": False, + "add_generation_prompt": True, + "adapter_name": adapter_name, + } + if documents is not None: + tmpl_kwargs["documents"] = documents + return tokenizer.apply_chat_template(messages, **tmpl_kwargs) + + def _invoke( model, tokenizer, @@ -96,26 +172,8 @@ def _invoke( position for that adapter's technology (LoRA prefix vs aLoRA splice). See ``composer/tokenizer_setup.py`` for the template machinery. - - Args: - adapter_name: Name of the adapter to activate; must be one of - the composed model's ``adapter_names``. - messages: List of ``{"role", "content"}`` dicts. - documents: Optional list of ``{"doc_id", "text"}`` dicts, as - documented in the granite-switch README. - max_new_tokens: Generation budget. - - Returns: - The generated adapter output (new tokens only, decoded). """ - tmpl_kwargs: dict = { - "tokenize": False, - "add_generation_prompt": True, - "adapter_name": adapter_name, - } - if documents is not None: - tmpl_kwargs["documents"] = documents - prompt = tokenizer.apply_chat_template(messages, **tmpl_kwargs) + prompt = _build_prompt(tokenizer, adapter_name, messages, documents=documents) return _generate(model, tokenizer, prompt, max_new_tokens) From 5333787869ab082e7a09e8cccb10ec02cc144e42 Mon Sep 17 00:00:00 2001 From: AlonMalach Date: Wed, 13 May 2026 12:52:54 +0300 Subject: [PATCH 03/18] Use <|start_of_role|> instead of BOS for LoRA/builtin substitute (#8) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Granite tokenizers alias bos_token_id to <|end_of_text|> (EOS), so the previous BOS-based substitute for LoRA/builtin adapters would have injected an end-of-text signal mid-prompt — a stop-generation marker in a place the model was not trained to see it. The chat template places the LoRA control token at sequence start, immediately followed by <|start_of_role|>user<|end_of_role|>... — so <|start_of_role|> is the deterministic "token that naturally follows" for every LoRA adapter, and its embedding is well-trained in the base model (part of the base vocab on Granite 4.0 and 4.1). Parallels the ALoRA path (substitute = first invocation token). Both paths now pick "the token that comes right after the control token in the rendered chat prompt" — single principle, two sources. Validated: - tokenizer.convert_tokens_to_ids('<|start_of_role|>') == 100264 on ibm-granite/granite-4.1-3b and granite-4.0-micro (part of base vocab, not composer-added). - bos_token_id == eos_token_id == 100257 ('<|end_of_text|>') on all three Granite tokenizers tested — confirming the prior default was semantically wrong. --- .../composer/compose_granite_switch.py | 27 +++++++++++++------ 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/src/granite_switch/composer/compose_granite_switch.py b/src/granite_switch/composer/compose_granite_switch.py index b6048ba..424a0d9 100755 --- a/src/granite_switch/composer/compose_granite_switch.py +++ b/src/granite_switch/composer/compose_granite_switch.py @@ -781,21 +781,32 @@ def build(): adapter_third_party = list(external_names) else: # Default: token-exchange. control_dims=0; every adapter needs a substitute id. - # ALoRA adapters substitute with the first token of their invocation sequence; - # LoRA/builtin adapters substitute with BOS (only required when at least one - # non-ALoRA adapter is present — ALoRA-only builds don't need BOS). + # + # Substitute choice (must mirror the token that appears right after the + # control token in the rendered chat prompt, so the swap keeps the + # residual stream in-distribution): + # - ALoRA: first token of the adapter's alora_invocation_tokens. + # - LoRA/builtin: <|start_of_role|>. The chat template places the + # LoRA control token at sequence start, immediately followed by + # <|start_of_role|>user<|end_of_role|>... — so <|start_of_role|> + # is the deterministic "what comes right after" for every LoRA + # adapter. Granite's bos_token_id is an alias for <|end_of_text|> + # (EOS), so we cannot use it here: injecting EOS mid-prompt is a + # stop-generation signal. + _LORA_SUBSTITUTE_TOKEN = "<|start_of_role|>" adapter_substitute_token_ids = [] for adapter_path, _name, technology, _source in all_discovered: if technology == "alora": sub_id = get_alora_first_invocation_token_id(adapter_path) else: - if tokenizer.bos_token_id is None: + sub_id = tokenizer.convert_tokens_to_ids(_LORA_SUBSTITUTE_TOKEN) + if sub_id is None or sub_id == tokenizer.unk_token_id: raise ValueError( - "Tokenizer has no bos_token_id; required for LoRA/builtin " - "token exchange. Pass --legacy-hiding to use the KV-hiding " - "path instead." + f"Tokenizer does not contain the LoRA substitute token " + f"{_LORA_SUBSTITUTE_TOKEN!r}; required for LoRA/builtin " + "token exchange. Pass --legacy-hiding to use the " + "KV-hiding path instead." ) - sub_id = tokenizer.bos_token_id adapter_substitute_token_ids.append(sub_id) # Token-exchange supersedes KV hiding — no hiding config needed. hiding_groups = None From 5dc99dc1f0d9d5b97ecaf2f48bcf7f0fc5a9a67b Mon Sep 17 00:00:00 2001 From: AlonMalach Date: Wed, 13 May 2026 18:33:30 +0300 Subject: [PATCH 04/18] Template: drop the first role marker after each adapter control token (#8) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The runtime embedding swap replaces each adapter control token's embedding with a substitute token's embedding — for LoRA adapters this is <|start_of_role|>, for assistant-boundary ALoRA adapters it's also <|start_of_role|> (the first token of their invocation sequence). But the chat template then emits a *real* <|start_of_role|> at the next position: the user or assistant role marker that naturally follows the control-token prefix. Result before this change: two consecutive positions carrying <|start_of_role|>'s embedding. The model has never seen that pattern during pretraining — a duplicate-embedding OOD right at the start of the decoder's residual stream. Fix: add a skip-once Jinja flag (ns.skip_next_start_of_role). Arm it when lora_prefix_insertion emits the LoRA control token, or when alora_insertion fires the fallback path for assistant-boundary ALoRAs. Wrap every <|start_of_role|> emission in the base Granite template with a skip-once block that consumes the flag. The flag is single-shot — only the very first <|start_of_role|> after the control token is suppressed; all later role markers emit normally. Not addressed in this PR: ALoRAs whose invocation text is in-message content text (, , ). The first token of these invocations is the single character '<', and the rest of the invocation text cannot be cleanly sliced at the template level without changing what 'requirements>' (or 'guardian>', etc.) tokenizes to. Those adapters retain the duplicate-embedding pattern until a runtime-level drop lands in a follow-up. Backward compatibility: old checkpoints (composed before this change) load unchanged — the template edit only runs at compose time and affects only newly-composed models. Their rendered output for LoRA and assistant-boundary ALoRA is now one token shorter than before (the suppressed <|start_of_role|>). Update the three test_chat_template tests whose assertions encoded the old contract. --- .../composer/tokenizer_setup.py | 71 ++++++++++++++- tests/composer/test_chat_template.py | 89 ++++++++++++++++--- 2 files changed, 145 insertions(+), 15 deletions(-) diff --git a/src/granite_switch/composer/tokenizer_setup.py b/src/granite_switch/composer/tokenizer_setup.py index a437af4..28e306d 100644 --- a/src/granite_switch/composer/tokenizer_setup.py +++ b/src/granite_switch/composer/tokenizer_setup.py @@ -192,9 +192,16 @@ def configure_chat_template( """ + # LoRA prefix: emit the control token at the sequence start AND arm + # skip_next_start_of_role so the template's very next <|start_of_role|> + # emission is suppressed. This avoids a duplicate-embedding OOD at runtime: + # the runtime swap replaces the control token's embedding with + # <|start_of_role|>'s embedding, and without this drop the sequence + # would carry two identical embeddings back-to-back. lora_prefix_insertion = """{#- For lora adapters: insert activation token at the very beginning -#} {%- if adapter_token and adapter_type == 'lora' %} {{- adapter_token }} +{%- set ns.skip_next_start_of_role = true %} {%- endif %} """ @@ -239,12 +246,18 @@ def configure_chat_template( # Fallback for adapters whose invocation sequence is the assistant role tokens: # Pass 1 never sets alora_target_idx >= 0 for those, so we emit here instead. + # Also arm skip_next_start_of_role so the generation-prompt <|start_of_role|> + # that would immediately follow is suppressed — mirrors the LoRA rationale: + # the runtime swap replaces the control token's embedding with the first + # invocation token's embedding (<|start_of_role|>), so without this drop the + # sequence would carry two identical embeddings back-to-back. alora_insertion = """{#- ALoRA fallback: insert activation token right before generation prompt. Only fires when Pass 1 found no user message with the invocation text (alora_target_idx == -1), meaning the adapter activates at the assistant role token boundary rather than inside a user message. -#} {%- if ns.adapter_token and ns.adapter_type == 'alora' and ns.alora_target_idx == -1 %} {{- ns.adapter_token }} +{%- set ns.skip_next_start_of_role = true %} {%- endif %} """ @@ -286,7 +299,8 @@ def configure_chat_template( "\n adapter_token=adapter_token," "\n adapter_type=adapter_type," "\n adapter_invocation_text=adapter_invocation_text," - "\n alora_target_idx=-1" + "\n alora_target_idx=-1," + "\n skip_next_start_of_role=false" "\n )" ) modified_chat_template = ( @@ -337,6 +351,61 @@ def configure_chat_template( else: modified_chat_template += "\n" + alora_insertion + # Skip-once wrapper for every <|start_of_role|> emission in the template. + # ns.skip_next_start_of_role is set to true immediately after a LoRA or + # assistant-boundary ALoRA control token is emitted; the very next role + # marker consumes the flag and is suppressed. Prevents a duplicate + # embedding at position 1 (see lora_prefix_insertion / alora_insertion + # comments). + # + # Every <|start_of_role|> in the base template appears inside a string + # literal, either merged with the following role text ('<|start_of_role|>user<|end_of_role|>') + # or standalone ('<|start_of_role|>' + message.role + ...). We split at + # the '<|start_of_role|>' boundary and route only that fragment through + # the skip-once Jinja block. + skip_once_block = ( + "{%- if ns.skip_next_start_of_role %}" + "{%- set ns.skip_next_start_of_role = false %}" + "{%- else %}" + "{{- '<|start_of_role|>' }}" + "{%- endif %}" + ) + # Case A: '<|start_of_role|>' as a standalone literal, possibly at the + # start of a concatenation ({{- '<|start_of_role|>' + expr + ... }}). + # Replace the literal emission with the skip block; the rest of the + # expression stays. Handles sites 77 and 79 directly. + modified_chat_template = re.sub( + r"\{\{-\s*'<\|start_of_role\|>'\s*\+\s*", + skip_once_block + "\n {{- ", + modified_chat_template, + ) + # Case B: '<|start_of_role|>ROLE<|end_of_role|>' merged literal (with or + # without trailing concatenation). Split the literal so only the + # '<|start_of_role|>' prefix goes through the skip block and the rest + # ('ROLE<|end_of_role|>' + anything) emits normally. + # Pattern: {{- 'literal_starting_with_start_of_role' (+ expr | ) }} + def _split_merged(match: "re.Match") -> str: + remainder = match.group(1) # text after <|start_of_role|> up to end of literal + tail = match.group(2) # trailing + expr or empty + return ( + skip_once_block + + "\n {{- '" + + remainder + + "'" + + tail + + " }}" + ) + + # Merged literal like '<|start_of_role|>system<|end_of_role|>' followed by + # optional " + expr + ...". The first group captures everything inside the + # literal after <|start_of_role|>; the second captures any trailing + # concatenation up to the closing }}. + modified_chat_template = re.sub( + r"\{\{-\s*'<\|start_of_role\|>([^']*)'((?:\s*\+\s*[^}]+?)?)\s*\}\}", + _split_merged, + modified_chat_template, + ) + tokenizer.chat_template = modified_chat_template print( f"Chat template configured with {len(adapter_mapping)} adapter mappings:" diff --git a/tests/composer/test_chat_template.py b/tests/composer/test_chat_template.py index 52f600e..bcbe69d 100644 --- a/tests/composer/test_chat_template.py +++ b/tests/composer/test_chat_template.py @@ -49,7 +49,15 @@ def _render(tokenizer, **kwargs): class TestConfigureChatTemplate: def test_lora_prefix_path(self): - """LoRA: activation token emitted at the very start of the sequence.""" + """LoRA: activation token emitted at the very start of the sequence. + + The skip-once flag set by lora_prefix_insertion suppresses the very + next <|start_of_role|>, so the rendered output is + '<|ctx_rel|>user<|end_of_role|>...', not + '<|ctx_rel|><|start_of_role|>user<|end_of_role|>...'. This keeps the + runtime embedding-swap from producing two identical consecutive + embeddings (see tokenizer_setup.py lora_prefix_insertion comment). + """ tokenizer = _make_tokenizer() configure_chat_template(tokenizer, [("/path/a", "ctx_rel", "lora")]) @@ -59,7 +67,12 @@ def test_lora_prefix_path(self): add_generation_prompt=True, adapter_name="ctx_rel", ) - assert result.startswith("<|ctx_rel|>") + assert result.startswith("<|ctx_rel|>user<|end_of_role|>"), ( + f"expected <|ctx_rel|> followed by 'user<|end_of_role|>' " + f"(skip-once suppressed <|start_of_role|>), got {result[:80]!r}" + ) + # Exactly one <|start_of_role|> should survive: the assistant turn. + assert result.count("<|start_of_role|>") == 1 def test_alora_pass1_pass2_path(self): """ALoRA Pass 1+2: token inserted in last user message before invocation text. @@ -109,12 +122,22 @@ def test_alora_fallback_path(self): adapter_name="answerability", ) assert "<|answerability|>" in result - # Token appears immediately before the generation prompt + # Token appears immediately before what would have been the generation + # prompt's <|start_of_role|>. The skip-once flag set by alora_insertion + # suppresses that <|start_of_role|>, so the rendered output has + # "<|answerability|>assistant<|end_of_role|>" — no role marker between + # the control token and the role name. Prevents a duplicate-embedding + # OOD at position 1 after the runtime swap (see tokenizer_setup.py + # alora_insertion comment). token = "<|answerability|>" - gen_prompt = "<|start_of_role|>assistant<|end_of_role|>" token_pos = result.index(token) - gen_pos = result.index(gen_prompt, token_pos) - assert result[token_pos + len(token):gen_pos].strip() == "" + after = result[token_pos + len(token):] + assert after.startswith("assistant<|end_of_role|>"), ( + f"expected 'assistant<|end_of_role|>' immediately after " + f"{token!r}, got {after[:60]!r}" + ) + # Only one <|start_of_role|> should survive: the one before the user turn. + assert result.count("<|start_of_role|>") == 1 def test_alora_pass1_pass2_iterable_content(self): """ALoRA Pass 1+2: token inserted correctly when message content is a list of parts. @@ -153,6 +176,33 @@ def test_alora_pass1_pass2_iterable_content(self): last_gen_pos = result.rindex(gen_prompt) assert result[last_gen_pos - len("<|req_check|>"):last_gen_pos] != "<|req_check|>" + def test_skip_once_is_single_shot(self): + """Skip-once flag consumes itself: only the first <|start_of_role|> + after a LoRA control token is suppressed; later role markers emit.""" + tokenizer = _make_tokenizer() + configure_chat_template(tokenizer, [("/path/a", "my_lora", "lora")]) + + # Two user turns so the template emits <|start_of_role|> three times: + # once per user turn + once for the generation prompt. Only the very + # first one should be suppressed. + result = _render( + tokenizer, + messages=[ + {"role": "user", "content": "first"}, + {"role": "assistant", "content": "reply"}, + {"role": "user", "content": "second"}, + ], + add_generation_prompt=True, + adapter_name="my_lora", + ) + assert result.startswith("<|my_lora|>user<|end_of_role|>"), ( + f"first <|start_of_role|> should be suppressed; got {result[:80]!r}" + ) + # Four role markers would be emitted normally (first user, assistant, + # second user, assistant-generation-prompt). Skip-once removes the + # first → exactly three survive. + assert result.count("<|start_of_role|>") == 3 + def test_no_adapter_no_tokens(self): """Without adapter_name the rendered output is identical to the original template.""" messages = [{"role": "user", "content": "Hello"}] @@ -210,13 +260,19 @@ def test_alora_fallback_from_adapter_config(self): add_generation_prompt=True, adapter_name="answerability", ) - # Fallback: token immediately before generation prompt + # Fallback: token immediately before generation prompt, with the + # generation-prompt <|start_of_role|> suppressed by the skip-once flag + # armed in alora_insertion. Output is "<|answerability|>assistant<|end_of_role|>". token = "<|answerability|>" - gen_prompt = "<|start_of_role|>assistant<|end_of_role|>" assert token in result token_pos = result.index(token) - gen_pos = result.index(gen_prompt, token_pos) - assert result[token_pos + len(token):gen_pos].strip() == "" + after = result[token_pos + len(token):] + assert after.startswith("assistant<|end_of_role|>"), ( + f"expected 'assistant<|end_of_role|>' immediately after " + f"{token!r}, got {after[:60]!r}" + ) + # Only the user-turn <|start_of_role|> should survive. + assert result.count("<|start_of_role|>") == 1 def test_alora_invocation_at_start_of_user_message(self): """ALoRA: invocation text is the first thing in the user message.""" @@ -300,6 +356,11 @@ def test_lora_prefix_from_adapter_config(self): adapter_name="summarization", ) assert result.startswith("<|summarization|>") + # Skip-once suppresses the user-turn <|start_of_role|>: output is + # "<|summarization|>user<|end_of_role|>...", not + # "<|summarization|><|start_of_role|>user...". Keeps the adapter + # substitute token from duplicating at runtime. + assert result.startswith("<|summarization|>user<|end_of_role|>") def test_mixed_adapters_from_adapter_config(self): """All three adapter types composed together, each activated independently.""" @@ -322,16 +383,16 @@ def test_mixed_adapters_from_adapter_config(self): ) assert "<|context_relevance|>" in result - # Activate answerability → fallback + # Activate answerability → fallback (skip-once suppresses the + # generation-prompt <|start_of_role|>). result = _render( tokenizer, messages=messages, add_generation_prompt=True, adapter_name="answerability", ) token = "<|answerability|>" - gen_prompt = "<|start_of_role|>assistant<|end_of_role|>" token_pos = result.index(token) - gen_pos = result.index(gen_prompt, token_pos) - assert result[token_pos + len(token):gen_pos].strip() == "" + after = result[token_pos + len(token):] + assert after.startswith("assistant<|end_of_role|>") # Activate summarization → prefix result = _render( From 9cdf630d799aa5d6f1d677beb3e82bb6e1f68d17 Mon Sep 17 00:00:00 2001 From: AlonMalach Date: Wed, 13 May 2026 19:05:27 +0300 Subject: [PATCH 05/18] Template: drop first char of in-message ALoRA invocation (#8) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes the remaining duplicate-embedding OOD at the swap site. Complements the skip-once <|start_of_role|> edit from the previous commit by extending the same principle to ALoRA adapters whose invocation text lives inside a user message (, , , , etc.). Change: in alora_pass2, after inserting the control token before the invocation text, also drop the first CHARACTER of the invocation text. Example: "Please <|req_check|>" becomes "Please <|req_check|>requirements>". At runtime the embedding-swap replaces the control token's embedding with the first invocation token's embedding — the embedding of '<'. The decoder then sees [<|req_check|>→e_<, requirements, >] — exactly what "" tokenizes to in isolation, with no duplicate. Why this is safe on the Granite tokenizer: verified empirically via a new property test (test_first_char_drop_equals_first_token_drop). For every ALoRA invocation in the standard library, tokenizing the full invocation and dropping the first token ID yields the same sequence as tokenizing the string with its first character removed. BPE's greedy- merge would break this property if the second-byte merges depended on the leading '<'; it doesn't, because '<' tokenizes as its own single- character token in every case. The accompanying test test_first_token_is_single_character asserts the complementary invariant: the first token of each invocation decodes to exactly one character. If a future invocation text starts with a multi-character first token, that test catches it — the Jinja edit (invocation_text[1:] drops one character) would otherwise silently produce a wrong-length drop. Combined with the previous commit (skip-once <|start_of_role|>), the duplicate-embedding pattern is now eliminated across all adapter types in the Granite adapter library: LoRA, assistant-boundary ALoRA, and in-message ALoRA. --- .../composer/tokenizer_setup.py | 21 +++- tests/composer/test_chat_template.py | 109 +++++++++++++++--- 2 files changed, 110 insertions(+), 20 deletions(-) diff --git a/src/granite_switch/composer/tokenizer_setup.py b/src/granite_switch/composer/tokenizer_setup.py index 28e306d..5f0a118 100644 --- a/src/granite_switch/composer/tokenizer_setup.py +++ b/src/granite_switch/composer/tokenizer_setup.py @@ -235,11 +235,28 @@ def configure_chat_template( # Pass 2: runs inside the main message loop after content.val is assembled. # rsplit(..., 1) splits on the last occurrence so the token lands in the # right place when the invocation text appears more than once in the message. - alora_pass2 = """ {#- ALoRA Pass 2: inject activation token before invocation text in the target message -#} + # + # Token drop (mirrors the <|start_of_role|> skip-once flag used for LoRA / + # assistant-boundary ALoRA): we also omit the FIRST CHARACTER of the + # invocation text. The runtime embedding swap replaces the control-token + # embedding with the first-invocation-token's embedding; writing the full + # invocation text after the control token would then produce two copies + # of that first-invocation-token back to back — an OOD pattern at the + # swap site. + # + # For every ALoRA invocation text in the standard Granite adapter library + # (, , , , etc.) the first + # character is a single '<' that the tokenizer emits as its own token, + # and the tail of the string retokenizes identically to the tail of the + # full string. So dropping the first character on the string side is + # equivalent to dropping exactly the first token on the tokenized side — + # no re-merging, no change to what follows. + alora_pass2 = """ {#- ALoRA Pass 2: inject activation token AND drop the first char of + the invocation text so the runtime-swapped embedding doesn't duplicate. -#} {%- if loop.index0 == ns.alora_target_idx %} {%- set _parts = content.val.rsplit(ns.adapter_invocation_text, 1) %} {%- if _parts | length > 1 %} - {%- set content.val = _parts[0] + ns.adapter_token + ns.adapter_invocation_text + _parts[1] %} + {%- set content.val = _parts[0] + ns.adapter_token + ns.adapter_invocation_text[1:] + _parts[1] %} {%- endif %} {%- endif %} """ diff --git a/tests/composer/test_chat_template.py b/tests/composer/test_chat_template.py index bcbe69d..e363afd 100644 --- a/tests/composer/test_chat_template.py +++ b/tests/composer/test_chat_template.py @@ -75,11 +75,15 @@ def test_lora_prefix_path(self): assert result.count("<|start_of_role|>") == 1 def test_alora_pass1_pass2_path(self): - """ALoRA Pass 1+2: token inserted in last user message before invocation text. + """ALoRA Pass 1+2: token inserted in last user message, first char of + invocation text dropped. Pass 1 finds the user message containing '' and sets - ns.alora_target_idx. Pass 2 splits content.val on '' - and rejoins with the control token before the last occurrence. + ns.alora_target_idx. Pass 2 splits content.val on '' + and rejoins with the control token followed by the invocation text + MINUS its first character ('<' is dropped). The runtime swap + replaces the control token's embedding with '<'s embedding, so the + sequence tokenizes the same as '' with no duplicate. The fallback block does NOT fire (alora_target_idx >= 0). """ with patch(_PATCH_TARGET, return_value=""): @@ -94,9 +98,13 @@ def test_alora_pass1_pass2_path(self): add_generation_prompt=True, adapter_name="req_check", ) - # Token immediately precedes the invocation text inside the user turn + # Token immediately precedes the invocation text (minus first char) + # inside the user turn: "<|req_check|>requirements>" (no '<'). user_turn_header = "<|start_of_role|>user<|end_of_role|>" - assert user_turn_header + "<|req_check|>" in result + assert user_turn_header + "<|req_check|>requirements>" in result + # And the literal "<|req_check|>" should NOT appear — + # the leading '<' must have been dropped. + assert "<|req_check|>" not in result # Fallback did not fire: token is not immediately before generation prompt gen_prompt = "<|start_of_role|>assistant<|end_of_role|>" last_gen_pos = result.rindex(gen_prompt) @@ -168,8 +176,10 @@ def test_alora_pass1_pass2_iterable_content(self): add_generation_prompt=True, adapter_name="req_check", ) - # Token must appear immediately before invocation text inside the user turn - assert "<|req_check|>" in result + # Token appears before the invocation text, and the invocation + # text's first character ('<') has been dropped. + assert "<|req_check|>requirements>" in result + assert "<|req_check|>" not in result assert result.index("<|req_check|>") > result.index("<|start_of_role|>user<|end_of_role|>") # Fallback must NOT also fire gen_prompt = "<|start_of_role|>assistant<|end_of_role|>" @@ -219,6 +229,55 @@ def test_no_adapter_no_tokens(self): assert modified == original +class TestInvocationFirstCharDropProperty: + """Standalone property test on a real Granite tokenizer: dropping the first + character of an ALoRA invocation text yields the same tail-token sequence + as tokenizing the full invocation text and dropping its first token. This + is the BPE-level invariant the Pass-2 edit relies on — if a future + tokenizer change breaks it, the template-level drop would silently corrupt + the tail of the invocation. + """ + + _INVOCATIONS = [ + "", + "", + "", + "", + ] + + def _get_tokenizer(self): + from transformers import AutoTokenizer + try: + return AutoTokenizer.from_pretrained("ibm-granite/granite-4.1-3b") + except Exception as e: + import pytest + pytest.skip(f"could not fetch Granite tokenizer: {e}") + + def test_first_char_drop_equals_first_token_drop(self): + tok = self._get_tokenizer() + for invocation in self._INVOCATIONS: + full_ids = tok(invocation, add_special_tokens=False).input_ids + dropped_ids = tok(invocation[1:], add_special_tokens=False).input_ids + assert full_ids[1:] == dropped_ids, ( + f"invocation {invocation!r}: dropping first char of the " + f"string produced tokens {dropped_ids} but the tail of the " + f"full tokenization is {full_ids[1:]}" + ) + + def test_first_token_is_single_character(self): + """Sanity: the first token of each invocation must be exactly one + character (the leading '<'). Otherwise dropping invocation_text[1:] + in Jinja would drop the wrong number of characters.""" + tok = self._get_tokenizer() + for invocation in self._INVOCATIONS: + first_id = tok(invocation, add_special_tokens=False).input_ids[0] + first_str = tok.decode([first_id]) + assert first_str == invocation[0], ( + f"invocation {invocation!r}: first token decodes to " + f"{first_str!r}, expected {invocation[0]!r}" + ) + + class _FixtureTokenizer: """Tokenizer with a decode map for fixture adapter token IDs.""" @@ -275,7 +334,12 @@ def test_alora_fallback_from_adapter_config(self): assert result.count("<|start_of_role|>") == 1 def test_alora_invocation_at_start_of_user_message(self): - """ALoRA: invocation text is the first thing in the user message.""" + """ALoRA: invocation text is the first thing in the user message. + + Pass 2 drops the first character of the invocation text after + inserting the control token, so "" becomes + "<|context_relevance|>context>" in the rendered output. + """ tokenizer = self._make_tokenizer({(27,): ""}) configure_chat_template(tokenizer, [ (self._CONTEXT_REL, "context_relevance", "alora"), @@ -287,16 +351,21 @@ def test_alora_invocation_at_start_of_user_message(self): add_generation_prompt=True, adapter_name="context_relevance", ) - # Token injected right after the user role header, before + # Token injected right after the user role header; the '<' of + # the invocation text is dropped. user_header = "<|start_of_role|>user<|end_of_role|>" - assert user_header + "<|context_relevance|>" in result + assert user_header + "<|context_relevance|>context>" in result + assert "<|context_relevance|>" not in result # Fallback must NOT fire gen_prompt = "<|start_of_role|>assistant<|end_of_role|>" last_gen_pos = result.rindex(gen_prompt) assert result[last_gen_pos - len("<|context_relevance|>"):last_gen_pos] != "<|context_relevance|>" def test_alora_invocation_mid_user_message(self): - """ALoRA: invocation text appears in the middle of the user message.""" + """ALoRA: invocation text appears in the middle of the user message. + + Same first-character drop as the start-of-message case. + """ tokenizer = self._make_tokenizer({(27,): ""}) configure_chat_template(tokenizer, [ (self._CONTEXT_REL, "context_relevance", "alora"), @@ -308,8 +377,9 @@ def test_alora_invocation_mid_user_message(self): add_generation_prompt=True, adapter_name="context_relevance", ) - # Token injected mid-message, before - assert "Please review: <|context_relevance|>" in result + # Token injected mid-message, invocation text's '<' dropped. + assert "Please review: <|context_relevance|>context>" in result + assert "<|context_relevance|>" not in result user_header = "<|start_of_role|>user<|end_of_role|>" assert result.index("<|context_relevance|>") > result.index(user_header) # Fallback must NOT fire @@ -321,7 +391,8 @@ def test_alora_multiple_occurrences_targets_last(self): """ALoRA: invocation text appears twice — token injected before the last occurrence. rsplit(..., 1) splits on the last occurrence, so the control token must - land before the second , not the first. + land before the second , not the first. First occurrence + remains intact with its '<'; only the second has its '<' dropped. """ tokenizer = self._make_tokenizer({(27,): ""}) configure_chat_template(tokenizer, [ @@ -337,8 +408,9 @@ def test_alora_multiple_occurrences_targets_last(self): add_generation_prompt=True, adapter_name="context_relevance", ) - # The first must NOT have the control token before it - assert "first batch Also check <|context_relevance|>second batch" in result + # First untouched; second one has the control token + # inserted with its '<' dropped. + assert "first batch Also check <|context_relevance|>context>second batch" in result # Only one control token in the entire output assert result.count("<|context_relevance|>") == 1 @@ -376,12 +448,13 @@ def test_mixed_adapters_from_adapter_config(self): messages = [{"role": "user", "content": "docs"}] - # Activate context_relevance → Pass 1+2 + # Activate context_relevance → Pass 1+2 (drops first char of invocation). result = _render( tokenizer, messages=messages, add_generation_prompt=True, adapter_name="context_relevance", ) - assert "<|context_relevance|>" in result + assert "<|context_relevance|>context>" in result + assert "<|context_relevance|>" not in result # Activate answerability → fallback (skip-once suppresses the # generation-prompt <|start_of_role|>). From 1b51f59197c8e06cb29eb3fd72eacb06a7276be6 Mon Sep 17 00:00:00 2001 From: AlonMalach Date: Wed, 13 May 2026 20:08:19 +0300 Subject: [PATCH 06/18] Derive LoRA substitute from the tokenizer's chat template (#8) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously the composer hardcoded _LORA_SUBSTITUTE_TOKEN = "<|start_of_role|>". That's the right answer for Granite 4.x but it ties the default-path composer to a Granite-specific token name. Any base model with a different chat template (different role marker, different turn-open convention) would silently get the wrong substitute — a token the base model knows, but not the one sitting at position 1 of its rendered prompt. Replace the hardcode with a compose-time probe: render a minimal no-adapter user turn through tokenizer.apply_chat_template, tokenize, and read input_ids[0]. That's by construction whatever the template emits at the start of a normal turn, which is exactly what sits at position 1 after a LoRA-prepended control token. The substitute and the template's own behavior are now derived from the same source of truth. Verified: the probe returns 100264 (<|start_of_role|>) on granite-4.1-3b, granite-4.0-micro, and granite-switch-4.1-3b-preview — identical to the previous hardcoded value. Behavior on Granite is unchanged; the door is open for non-Granite base models. Error paths give actionable messages: - Tokenizer has no chat_template → suggest --legacy-hiding - Template render fails → report the Jinja error, suggest --legacy-hiding - First token is → report that the template emits something outside the vocab - Probe returns an empty id list → same Tests: - tests/composer/test_lora_substitute_probe.py (7 cases): * Real tokenizer round-trip on granite-4.1-3b and 4.0-micro * Synthetic tokenizer with a non-Granite template returns the custom template's first-token id * All four error paths raise ValueError with matching messages --- .../composer/compose_granite_switch.py | 77 +++++++--- tests/composer/test_lora_substitute_probe.py | 134 ++++++++++++++++++ 2 files changed, 195 insertions(+), 16 deletions(-) create mode 100644 tests/composer/test_lora_substitute_probe.py diff --git a/src/granite_switch/composer/compose_granite_switch.py b/src/granite_switch/composer/compose_granite_switch.py index 424a0d9..41b7c02 100755 --- a/src/granite_switch/composer/compose_granite_switch.py +++ b/src/granite_switch/composer/compose_granite_switch.py @@ -77,6 +77,60 @@ def _load_tokenizer(model_name_or_path): return AutoTokenizer.from_pretrained(model_name_or_path) +def _probe_lora_substitute_token_id(tokenizer) -> int: + """Ask the tokenizer which token naturally appears at the start of a + rendered no-adapter chat. + + The LoRA prefix insertion prepends the adapter control token at the very + beginning of the rendered output, so whatever the template emits first + for a normal user turn is exactly what sits at position 1 after the + control token — and therefore the right substitute whose embedding + should land at the swap site. + + By deriving this from the tokenizer's own chat template at compose + time, we avoid hard-coding a Granite-4.1-specific token string + (<|start_of_role|>). Other base models with different chat templates + get the correct substitute for their template by construction. + + Raises ``ValueError`` if the template cannot be rendered or the first + tokenized id cannot be determined. Callers should suggest + ``--legacy-hiding`` as the fallback. + """ + if tokenizer.chat_template is None: + raise ValueError( + "Tokenizer has no chat_template; cannot probe the LoRA " + "substitute token. Pass --legacy-hiding to use the KV-hiding " + "path instead." + ) + try: + probe_text = tokenizer.apply_chat_template( + [{"role": "user", "content": "probe"}], + tokenize=False, + add_generation_prompt=False, + ) + except Exception as e: + raise ValueError( + "Failed to render a probe chat via tokenizer.apply_chat_template " + f"while detecting the LoRA substitute token: {e!r}. " + "Pass --legacy-hiding to use the KV-hiding path instead." + ) from e + ids = tokenizer(probe_text, add_special_tokens=False).input_ids + if not ids: + raise ValueError( + "Probe chat tokenized to an empty id list; cannot determine the " + "LoRA substitute token. Pass --legacy-hiding to use the " + "KV-hiding path instead." + ) + sub_id = ids[0] + if sub_id == tokenizer.unk_token_id: + raise ValueError( + "First token of the rendered probe chat is ; the template " + "appears to emit content outside the tokenizer's vocabulary. " + "Pass --legacy-hiding to use the KV-hiding path instead." + ) + return sub_id + + def _get_directory_size(directory): """Return ``(total_size in GBs, file_count)`` for *directory*.""" if Path(directory).exists(): @@ -786,27 +840,18 @@ def build(): # control token in the rendered chat prompt, so the swap keeps the # residual stream in-distribution): # - ALoRA: first token of the adapter's alora_invocation_tokens. - # - LoRA/builtin: <|start_of_role|>. The chat template places the - # LoRA control token at sequence start, immediately followed by - # <|start_of_role|>user<|end_of_role|>... — so <|start_of_role|> - # is the deterministic "what comes right after" for every LoRA - # adapter. Granite's bos_token_id is an alias for <|end_of_text|> - # (EOS), so we cannot use it here: injecting EOS mid-prompt is a - # stop-generation signal. - _LORA_SUBSTITUTE_TOKEN = "<|start_of_role|>" + # - LoRA/builtin: whatever the tokenizer's chat template emits at + # the very start of a no-adapter user turn. For Granite 4.x + # that's <|start_of_role|>; the probe derives this from the + # template at compose time so other base models work by + # construction. + lora_sub_id = _probe_lora_substitute_token_id(tokenizer) adapter_substitute_token_ids = [] for adapter_path, _name, technology, _source in all_discovered: if technology == "alora": sub_id = get_alora_first_invocation_token_id(adapter_path) else: - sub_id = tokenizer.convert_tokens_to_ids(_LORA_SUBSTITUTE_TOKEN) - if sub_id is None or sub_id == tokenizer.unk_token_id: - raise ValueError( - f"Tokenizer does not contain the LoRA substitute token " - f"{_LORA_SUBSTITUTE_TOKEN!r}; required for LoRA/builtin " - "token exchange. Pass --legacy-hiding to use the " - "KV-hiding path instead." - ) + sub_id = lora_sub_id adapter_substitute_token_ids.append(sub_id) # Token-exchange supersedes KV hiding — no hiding config needed. hiding_groups = None diff --git a/tests/composer/test_lora_substitute_probe.py b/tests/composer/test_lora_substitute_probe.py new file mode 100644 index 0000000..f5c90f1 --- /dev/null +++ b/tests/composer/test_lora_substitute_probe.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for _probe_lora_substitute_token_id. + +The probe derives the LoRA substitute token from the tokenizer's chat +template rather than hardcoding a Granite-4.x-specific token string. These +tests verify: + +1. On real Granite tokenizers, the probe returns <|start_of_role|> (id + 100264) — the token the LoRA prefix insertion places immediately after + the control token in the rendered prompt. +2. On a synthetic tokenizer with a different template, the probe returns + whatever that template emits first for a user turn. +3. The probe raises a clear error when the template is missing, fails to + render, or emits an unknown token. +""" + +from types import SimpleNamespace + +import pytest + +from granite_switch.composer.compose_granite_switch import ( + _probe_lora_substitute_token_id, +) + + +class TestOnRealGraniteTokenizer: + """Exercise the probe on actual Granite tokenizers. Network-dependent; + skips cleanly if the model can't be fetched.""" + + def _tok(self, name): + from transformers import AutoTokenizer + try: + return AutoTokenizer.from_pretrained(name) + except Exception as e: + pytest.skip(f"could not fetch tokenizer {name!r}: {e}") + + def test_granite_4_1_3b(self): + tok = self._tok("ibm-granite/granite-4.1-3b") + sub_id = _probe_lora_substitute_token_id(tok) + assert sub_id == 100264 + assert tok.convert_ids_to_tokens([sub_id])[0] == "<|start_of_role|>" + + def test_granite_4_0_micro(self): + tok = self._tok("ibm-granite/granite-4.0-micro") + sub_id = _probe_lora_substitute_token_id(tok) + assert sub_id == 100264 + assert tok.convert_ids_to_tokens([sub_id])[0] == "<|start_of_role|>" + + +class TestOnSyntheticTokenizer: + """Verify the probe is generic — it returns whatever the template emits, + not a Granite-specific hardcoded token.""" + + def test_custom_template_gives_custom_token(self): + """A template whose first emission is a different marker produces + the id of that different marker.""" + + class _FakeTokenizer: + chat_template = "" + unk_token_id = 0 + + def apply_chat_template( + self, messages, tokenize, add_generation_prompt + ): + assert tokenize is False + assert add_generation_prompt is False + return "hello" + + def __call__(self, text, **kwargs): + # Pretend tokenizes as [42], "hello" as [7, 8, 9, 10, 11]. + assert kwargs.get("add_special_tokens") is False + assert text == "hello" + return SimpleNamespace(input_ids=[42, 7, 8, 9, 10, 11]) + + assert _probe_lora_substitute_token_id(_FakeTokenizer()) == 42 + + +class TestErrorPaths: + + def _minimal_tokenizer_without_template(self): + class _T: + chat_template = None + unk_token_id = 0 + def apply_chat_template(self, *a, **kw): + raise AssertionError("should not be called") + def __call__(self, text, **kw): + raise AssertionError("should not be called") + return _T() + + def _tokenizer_whose_template_fails(self): + class _T: + chat_template = "" + unk_token_id = 0 + def apply_chat_template(self, *a, **kw): + raise RuntimeError("template exploded") + def __call__(self, text, **kw): + raise AssertionError("unreachable") + return _T() + + def _tokenizer_emitting_unk(self): + class _T: + chat_template = "" + unk_token_id = 777 + def apply_chat_template(self, messages, tokenize, add_generation_prompt): + return "mystery" + def __call__(self, text, **kw): + return SimpleNamespace(input_ids=[777]) + return _T() + + def _tokenizer_emitting_empty(self): + class _T: + chat_template = "" + unk_token_id = 0 + def apply_chat_template(self, messages, tokenize, add_generation_prompt): + return "" + def __call__(self, text, **kw): + return SimpleNamespace(input_ids=[]) + return _T() + + def test_missing_chat_template_raises(self): + with pytest.raises(ValueError, match="no chat_template"): + _probe_lora_substitute_token_id(self._minimal_tokenizer_without_template()) + + def test_template_render_failure_raises(self): + with pytest.raises(ValueError, match="Failed to render a probe chat"): + _probe_lora_substitute_token_id(self._tokenizer_whose_template_fails()) + + def test_unk_first_token_raises(self): + with pytest.raises(ValueError, match=""): + _probe_lora_substitute_token_id(self._tokenizer_emitting_unk()) + + def test_empty_tokenization_raises(self): + with pytest.raises(ValueError, match="empty id list"): + _probe_lora_substitute_token_id(self._tokenizer_emitting_empty()) From df9512c1c0549c89c528509fcdd8a2169c9e8fd4 Mon Sep 17 00:00:00 2001 From: AlonMalach Date: Thu, 14 May 2026 11:37:56 +0300 Subject: [PATCH 07/18] Move token-exchange rewrite into the switch (#8) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Refactor: the runtime substitution LUT and the embedding-swap step move out of each backend's decoder and into SingleSwitch (HF + vLLM). The switch now performs both halves of token-exchange: 1. Adapter selection — read input_ids, detect control tokens via input_ids == adapter_token_ids, emit per-token adapter_indices (unchanged). 2. Token rewrite — replace each control token's id in input_ids with its substitute id (from a switch-owned LUT). New. The switch's forward signature changes from -> adapter_indices to -> (adapter_indices, modified_input_ids) The decoder consumes both: adapter_indices feeds the LoRA layers as before, modified_input_ids feeds embed_tokens / get_input_embeddings exactly once. There is no longer a decoder-side LUT, no scatter, no clone-guard, no use_token_exchange branch in the embedding path. Why this is cleaner: - Single source of truth for the substitution. The switch already knows which positions are control tokens; rewriting input_ids at those positions is a natural extension of "decide which adapter is active." The decoder is genuinely token-exchange-agnostic — it just embeds whatever input_ids it receives. - HF and vLLM converge to the same control flow. Both backends now call switch(...), unpack two outputs, embed once. Previously each backend had a near-identical but layout-specific (B,S,H vs T,H) embedding-swap block + clone-guard that needed to be maintained separately. - Smaller diff for any future change to the substitution logic. Whether to ship a different substitute strategy (e.g. learned embedding, per-adapter rules) becomes a one-place change in the switch instead of a two-place change across both decoders. HF model forward also reorders slightly: switch runs before embed_tokens, so we embed exactly once on modified_input_ids. create_causal_mask now receives a stub embedding tensor of the right shape and dtype (it only uses the tensor for batch/query/dtype inference per the upstream docstring), since the real embedding hasn't been computed yet. Tests: - tests/hf/test_single_switch.py: _run helper unpacks the new tuple return; TestBatchProcessing similarly. - tests/hf/test_token_exchange.py: LUT presence assertion now reads model.model.switch.control_to_substitute_lut instead of model.model.control_to_substitute_lut. No behavior change verified by 756 passing tests (= same count as before the refactor; +0 -0 after fixture updates). --- .../hf/modeling_granite_switch.py | 90 ++++++++----------- src/granite_switch/hf/switch/single.py | 65 ++++++++++++-- .../vllm/granite_switch_model.py | 48 +++------- src/granite_switch/vllm/switch/single.py | 56 ++++++++++-- tests/hf/test_single_switch.py | 16 ++-- tests/hf/test_token_exchange.py | 5 +- 6 files changed, 169 insertions(+), 111 deletions(-) diff --git a/src/granite_switch/hf/modeling_granite_switch.py b/src/granite_switch/hf/modeling_granite_switch.py index 518bdc0..2f064c0 100644 --- a/src/granite_switch/hf/modeling_granite_switch.py +++ b/src/granite_switch/hf/modeling_granite_switch.py @@ -188,26 +188,10 @@ def __init__(self, config: GraniteSwitchConfig): torch.zeros(config.num_adapters, dtype=torch.long), ) - # --- Token-exchange buffers --- - # control_to_substitute_lut: [vocab_size], -1 for non-control token ids - # and the substitute token id at each control slot. Lets the embedding - # swap run as a single gather + masked scatter without allocating an - # [B, S, num_adapters] intermediate. - if config.use_token_exchange: - sub_ids = config.adapter_substitute_token_ids - self.register_buffer( - "adapter_substitute_token_ids", - torch.tensor(sub_ids, dtype=torch.long), - ) - max_ctrl_id = max(config.adapter_token_ids) - lut_size = max(config.vocab_size, max_ctrl_id + 1) - lut = torch.full((lut_size,), -1, dtype=torch.long) - for ctrl_id, sub_id in zip(config.adapter_token_ids, sub_ids): - lut[ctrl_id] = sub_id - self.register_buffer("control_to_substitute_lut", lut) - else: - self.adapter_substitute_token_ids = None - self.control_to_substitute_lut = None + # Token-exchange LUT lives on the switch module (see hf/switch/ + # single.py); the switch rewrites input_ids in-place during its + # forward pass, so this model class no longer needs a decoder- + # side substitute table. # --- Hiding group buffers --- # token_to_group_mask: [vocab_size, num_groups] lookup table. @@ -245,8 +229,6 @@ def __init__(self, config: GraniteSwitchConfig): else: self.switch = None self.adapter_token_ids = None - self.adapter_substitute_token_ids = None - self.control_to_substitute_lut = None self.token_to_group_mask = None self.adapter_hiding_matrix = None @@ -310,55 +292,53 @@ def forward( ) use_cache = False - inputs_embeds_owned = inputs_embeds is None - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - # Token exchange: replace each control token's embedding with the - # substitute token's embedding before the embedding multiplier so - # both paths receive the same scaling. The LUT maps control token - # ids to substitute ids; non-control positions produce -1. - if self.config.use_token_exchange and input_ids is not None: - sub_id_per_pos = self.control_to_substitute_lut[input_ids] - is_control = sub_id_per_pos >= 0 - if is_control.any(): - flat_sub_ids = sub_id_per_pos[is_control] - sub_embeds = self.embed_tokens(flat_sub_ids) - # Only clone when the caller owns the tensor; if we just - # allocated it via embed_tokens(input_ids), mutating is safe. - if not inputs_embeds_owned: - inputs_embeds = inputs_embeds.clone() - inputs_embeds[is_control] = sub_embeds - - inputs_embeds = inputs_embeds * self.embedding_multiplier - # Initialize cache if use_cache and past_key_values is None: past_key_values = DynamicCache(config=self.config) + # Determine sequence shape and device. With input_ids we get them + # directly; with pre-supplied inputs_embeds we read from the tensor. + if input_ids is not None: + batch_size, seq_length = input_ids.shape + device = input_ids.device + else: + batch_size, seq_length = inputs_embeds.shape[:2] + device = inputs_embeds.device + if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + past_seen_tokens, past_seen_tokens + seq_length, device=device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) - # Causal mask (4D for attention layers) + # Causal mask (4D for attention layers). create_causal_mask only + # uses the embedding tensor for batch/query/dtype inference; we + # haven't embedded yet (the switch call below may rewrite input_ids + # first), so pass a stub of the right shape/dtype. + embed_dtype = self.embed_tokens.weight.dtype + mask_shape_proxy = inputs_embeds if inputs_embeds is not None else torch.empty( + batch_size, seq_length, 1, device=device, dtype=embed_dtype + ) causal_mask = create_causal_mask( config=self.config, - input_embeds=inputs_embeds, + input_embeds=mask_shape_proxy, attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, position_ids=position_ids, ) - # Compute adapter_indices using switch (BEFORE RoPE for position correction) + # Compute adapter_indices using switch (BEFORE RoPE for position correction). + # The switch also returns modified_input_ids: input_ids with each + # control token rewritten to its substitute id, so the decoder can + # embed once without any token-exchange awareness. hidden_count = None + modified_input_ids = input_ids if self.switch is not None: - adapter_indices = self.switch( + adapter_indices, modified_input_ids = self.switch( input_ids=input_ids, adapter_token_ids=self.adapter_token_ids, attention_mask=causal_mask, @@ -384,15 +364,19 @@ def forward( if hidden_count is None and not self.config.use_token_exchange: hidden_count = (adapter_indices > 0).long() else: - batch_size, seq_length = inputs_embeds.shape[:2] adapter_indices = torch.zeros( - (batch_size, seq_length), - dtype=torch.long, - device=inputs_embeds.device + (batch_size, seq_length), dtype=torch.long, device=device, ) token_group_membership = None query_group_suppression = None + # Embed once, on the (possibly-rewritten) input_ids. The decoder is + # token-exchange-agnostic — it just embeds whatever the switch + # passed through. + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(modified_input_ids) + inputs_embeds = inputs_embeds * self.embedding_multiplier + # Expose adapter_indices for tests and debugging. self._last_adapter_indices = adapter_indices diff --git a/src/granite_switch/hf/switch/single.py b/src/granite_switch/hf/switch/single.py index 7a26a29..e35ac91 100644 --- a/src/granite_switch/hf/switch/single.py +++ b/src/granite_switch/hf/switch/single.py @@ -14,7 +14,7 @@ import torch import torch.nn as nn -from typing import Optional +from typing import Optional, Tuple from transformers.cache_utils import Cache from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.models.granite.modeling_granite import eager_attention_forward @@ -78,6 +78,28 @@ def __init__( # Switch is layer 0, decoder layers are 1 to num_hidden_layers self.layer_idx = layer_idx + # control_to_substitute_lut: [vocab_size_or_higher], -1 at non-control + # ids and the substitute id at each control slot. The switch performs + # the runtime token-exchange: it rewrites input_ids in-place so that + # control-token positions carry the substitute id by the time the + # decoder embeds them. The decoder is then oblivious — it just calls + # embed_tokens(input_ids) and gets the right result by construction. + if ( + config is not None + and getattr(config, "adapter_token_ids", None) is not None + and getattr(config, "adapter_substitute_token_ids", None) is not None + ): + ctrl_ids = config.adapter_token_ids + sub_ids = config.adapter_substitute_token_ids + max_ctrl_id = max(ctrl_ids) + lut_size = max(getattr(config, "vocab_size", 0), max_ctrl_id + 1) + lut = torch.full((lut_size,), -1, dtype=torch.long) + for ctrl_id, sub_id in zip(ctrl_ids, sub_ids): + lut[ctrl_id] = sub_id + self.register_buffer("control_to_substitute_lut", lut) + else: + self.control_to_substitute_lut = None + @property def num_cache_layers(self) -> int: """Number of cache slots used.""" @@ -90,13 +112,19 @@ def forward( attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Compute adapter indices using single-head attention mechanism. + Compute adapter indices and rewrite control tokens via the LUT. + + The switch performs both halves of token-exchange: + 1. Adapter selection: read input_ids, detect control tokens via + input_ids == adapter_token_ids, emit per-token adapter_indices. + 2. Token rewrite: replace each control token's id in input_ids + with its substitute id (from control_to_substitute_lut). - The switch uses the same head_dim as decoder layers to share the model's Cache object, - ensuring standard HuggingFace behavior where past_key_values is exposed and managed - by the caller. + Returning the rewritten input_ids means the decoder is oblivious to + the swap — it simply embeds whatever it's given. There's no + decoder-side LUT, no per-forward scatter, no clone-guard. Args: input_ids: Input token IDs [batch, seq_len] @@ -110,7 +138,10 @@ def forward( cache_position: Position indices for caching [seq_len] Returns: - adapter_indices: [batch, seq_len] where 0 = base, 1+ = adapters + (adapter_indices, modified_input_ids): + adapter_indices: [batch, seq_len] where 0 = base, 1+ = adapters. + modified_input_ids: [batch, seq_len] with each control-token + id replaced by its substitute id. """ bsz, q_len = input_ids.shape device = input_ids.device @@ -195,4 +226,22 @@ def forward( f"adapter_indices shape {adapter_indices.shape} must match input_ids shape {input_ids.shape}" ) - return adapter_indices + # Token-exchange rewrite: replace each control token's id with its + # substitute id via the LUT. Done here (rather than in the decoder) + # so the decoder sees a clean, unified input_ids and never has to + # know about substitutes. Skipped only when the LUT was not built + # (no substitute ids configured — e.g. a non-token-exchange test + # fixture). + if self.control_to_substitute_lut is not None: + sub_id_per_pos = self.control_to_substitute_lut[input_ids] + is_control = sub_id_per_pos >= 0 + if is_control.any(): + modified_input_ids = torch.where( + is_control, sub_id_per_pos, input_ids + ) + else: + modified_input_ids = input_ids + else: + modified_input_ids = input_ids + + return adapter_indices, modified_input_ids diff --git a/src/granite_switch/vllm/granite_switch_model.py b/src/granite_switch/vllm/granite_switch_model.py index 7a8f7b6..9646883 100644 --- a/src/granite_switch/vllm/granite_switch_model.py +++ b/src/granite_switch/vllm/granite_switch_model.py @@ -155,24 +155,10 @@ def __init__( torch.zeros(num_adapters, dtype=torch.long), ) - # --- Token-exchange LUT --- - # See the HF model for the shared rationale. -1 indicates "not a - # control token"; other positions map control id → substitute id. - if config.use_token_exchange: - sub_ids = config.adapter_substitute_token_ids - self.register_buffer( - "adapter_substitute_token_ids", - torch.tensor(sub_ids, dtype=torch.long), - ) - max_ctrl_id = max(config.adapter_token_ids) - lut_size = max(config.vocab_size, max_ctrl_id + 1) - lut = torch.full((lut_size,), -1, dtype=torch.long) - for ctrl_id, sub_id in zip(config.adapter_token_ids, sub_ids): - lut[ctrl_id] = sub_id - self.register_buffer("control_to_substitute_lut", lut) - else: - self.adapter_substitute_token_ids = None - self.control_to_substitute_lut = None + # Token-exchange LUT lives on the switch module + # (see vllm/switch/single.py); the switch rewrites input_ids + # in-place during its forward pass, so this model class no + # longer needs a decoder-side substitute table. # Initialize compile-friendly LoRA metadata handler # This replaces vLLM's LoRAKernelMeta with a torch.compile-compatible version @@ -212,8 +198,6 @@ def __init__( else: self.switch = None self.adapter_token_ids = None - self.adapter_substitute_token_ids = None - self.control_to_substitute_lut = None self.lora_meta = None self.token_to_group_mask = None self.adapter_hiding_matrix = None @@ -353,7 +337,7 @@ def forward( hidden_count = None if get_pp_group().is_first_rank: if self.switch is not None: - adapter_indices = self.switch( + adapter_indices, modified_input_ids = self.switch( input_ids=input_ids, adapter_token_ids=self.adapter_token_ids, ) @@ -365,6 +349,7 @@ def forward( dtype=torch.long, device=input_ids.device ) + modified_input_ids = input_ids # Step 2: Compute group-based hiding masks. if self.token_to_group_mask is not None: @@ -443,23 +428,12 @@ def forward( if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds - hidden_states_owned = False else: - hidden_states = self.get_input_embeddings(input_ids) - hidden_states_owned = True - - # Token exchange: mirror of the HF path. vLLM tensors are flat - # [num_tokens, hidden]; the gather + masked scatter runs pre- - # multiplier so both raw and substitute embeddings are scaled once. - if self.config.use_token_exchange and input_ids is not None: - sub_id_per_pos = self.control_to_substitute_lut[input_ids] - is_control = sub_id_per_pos >= 0 - if is_control.any(): - flat_sub_ids = sub_id_per_pos[is_control] - sub_embeds = self.get_input_embeddings(flat_sub_ids) - if not hidden_states_owned: - hidden_states = hidden_states.clone() - hidden_states[is_control] = sub_embeds + # Embed the (possibly-rewritten) input_ids the switch returned. + # The switch already performed the token-exchange rewrite, so + # this single lookup produces the correct embeddings for both + # control positions (substitute id) and content positions. + hidden_states = self.get_input_embeddings(modified_input_ids) hidden_states *= self.config.embedding_multiplier residual = None diff --git a/src/granite_switch/vllm/switch/single.py b/src/granite_switch/vllm/switch/single.py index 1171466..da85801 100644 --- a/src/granite_switch/vllm/switch/single.py +++ b/src/granite_switch/vllm/switch/single.py @@ -18,7 +18,7 @@ import torch import torch.nn as nn -from typing import Optional +from typing import Optional, Tuple from vllm.model_executor.layers.attention.attention import Attention from vllm.config import VllmConfig @@ -92,6 +92,29 @@ def __init__( prefix="switch.layers.0", ) + # control_to_substitute_lut: [vocab_size_or_higher], -1 at non-control + # ids and the substitute id at each control slot. The switch performs + # the runtime token-exchange: it rewrites input_ids in-place so that + # control-token positions carry the substitute id by the time the + # decoder embeds them. The decoder is then oblivious — it just calls + # get_input_embeddings(input_ids) and gets the right result by + # construction. + if ( + config is not None + and getattr(config, "adapter_token_ids", None) is not None + and getattr(config, "adapter_substitute_token_ids", None) is not None + ): + ctrl_ids = config.adapter_token_ids + sub_ids = config.adapter_substitute_token_ids + max_ctrl_id = max(ctrl_ids) + lut_size = max(getattr(config, "vocab_size", 0), max_ctrl_id + 1) + lut = torch.full((lut_size,), -1, dtype=torch.long) + for ctrl_id, sub_id in zip(ctrl_ids, sub_ids): + lut[ctrl_id] = sub_id + self.register_buffer("control_to_substitute_lut", lut) + else: + self.control_to_substitute_lut = None + @property def num_cache_layers(self) -> int: """Number of KV cache slots used by this switch (1 Attention layer).""" @@ -101,9 +124,13 @@ def forward( self, input_ids: torch.Tensor, adapter_token_ids: torch.Tensor, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Compute adapter indices using replicated one-hot attention. + Compute adapter indices and rewrite control tokens via the LUT. + + See the HF SingleSwitch docstring for the full rationale. In short: + the switch performs both adapter selection and token-exchange + rewrite, so the decoder is agnostic to substitution. Args: input_ids: Input token IDs [total_tokens] - flattened by vLLM scheduler @@ -114,7 +141,10 @@ def forward( to transition back to base mid-sequence. Returns: - adapter_indices: [total_tokens] where 0 = base, 1+ = adapters + (adapter_indices, modified_input_ids): + adapter_indices: [total_tokens] where 0 = base, 1+ = adapters. + modified_input_ids: [total_tokens] with each control-token id + replaced by its substitute id. """ total_tokens = input_ids.shape[0] device = input_ids.device @@ -162,8 +192,22 @@ def forward( # Round to get integer adapter indices adapter_indices = torch.round(attn_output).long() - + # Clamp to valid range [0, num_adapters] adapter_indices = torch.clamp(adapter_indices, 0, self.num_adapters) - return adapter_indices + # Token-exchange rewrite: see the HF switch for the rationale. + # Skipped only when no LUT was built (no substitute ids configured). + if self.control_to_substitute_lut is not None: + sub_id_per_pos = self.control_to_substitute_lut[input_ids] + is_control = sub_id_per_pos >= 0 + if is_control.any(): + modified_input_ids = torch.where( + is_control, sub_id_per_pos, input_ids + ) + else: + modified_input_ids = input_ids + else: + modified_input_ids = input_ids + + return adapter_indices, modified_input_ids diff --git a/tests/hf/test_single_switch.py b/tests/hf/test_single_switch.py index 234d9c8..5d186b4 100644 --- a/tests/hf/test_single_switch.py +++ b/tests/hf/test_single_switch.py @@ -135,8 +135,12 @@ def _run(self, seq, num_adapters=NUM_ADAPTERS, control_token_gain=15.0): switch = _make_switch(self._backend, num_adapters, control_token_gain) token_ids = torch.tensor(ADAPTER_TOKEN_IDS_LIST[:num_adapters]) input_ids = torch.tensor([seq]) - result = switch.forward(input_ids=input_ids, adapter_token_ids=token_ids) - return result[0].tolist() + # Switch returns (adapter_indices, modified_input_ids); these tests + # only check adapter selection so we drop the rewritten ids here. + adapter_indices, _modified = switch.forward( + input_ids=input_ids, adapter_token_ids=token_ids, + ) + return adapter_indices[0].tolist() # ── Shared test classes (from mixin) ──────────────────────────────── @@ -177,6 +181,8 @@ def test_batch_independence(self, backend): [TEXT_TOKEN, ADAPTER_TOKEN_IDS_LIST[0], TEXT_TOKEN, TEXT_TOKEN, TEXT_TOKEN], [TEXT_TOKEN, ADAPTER_TOKEN_IDS_LIST[3], TEXT_TOKEN, TEXT_TOKEN, TEXT_TOKEN], ]) - result = switch.forward(input_ids=input_ids, adapter_token_ids=token_ids) - assert (result[0, 2:] == 1).all() - assert (result[1, 2:] == 4).all() + adapter_indices, _modified = switch.forward( + input_ids=input_ids, adapter_token_ids=token_ids, + ) + assert (adapter_indices[0, 2:] == 1).all() + assert (adapter_indices[1, 2:] == 4).all() diff --git a/tests/hf/test_token_exchange.py b/tests/hf/test_token_exchange.py index e877bb7..3f9e306 100644 --- a/tests/hf/test_token_exchange.py +++ b/tests/hf/test_token_exchange.py @@ -60,8 +60,9 @@ def test_swap_picks_substitute_embedding(self): config, torch.tensor([[10, 20, 100, 40]], dtype=torch.long), # adapter 0 control at pos 2 ) - # The LUT maps control id 100 → substitute 5. - lut = model.model.control_to_substitute_lut + # The LUT lives on the switch (it performs the rewrite during its + # forward); maps control id 100 → substitute 5. + lut = model.model.switch.control_to_substitute_lut assert lut[100].item() == 5 assert lut[101].item() == 7 # Positions without control tokens map to -1. From 5dcbf257a10c479b5859f06dcbe9eca1b58a66c0 Mon Sep 17 00:00:00 2001 From: AlonMalach Date: Thu, 14 May 2026 13:44:48 +0300 Subject: [PATCH 08/18] Remove the legacy KV-hiding code path (#8) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Token-exchange has been the default for several commits. This change deletes the dead-but-still-callable KV-hiding code path entirely: Config: - Drop control_dims, hiding_groups, hiding_policy, adapter_third_party parameters and the corresponding state. - Drop expanded_head_dim, num_hiding_groups, hiding_group_names, use_token_exchange properties (token-exchange is now always on when num_adapters > 0). - Drop get_hiding_group_token_ids, get_third_party_adapter_mask, get_adapter_hiding_policy_matrix methods. - adapter_substitute_token_ids becomes required when num_adapters > 0. - Net: -150 LoC (config.py 345 → 195). Models: - HF and vLLM both drop token_to_group_mask / adapter_hiding_matrix buffers, hidden_count / adjusted_position_ids logic, and the token_group_membership / query_group_suppression plumbing through decoder layers. - The HF decoder layer's forward signature drops two kwargs. Attention layers (hf/core/lora.py, vllm/core/decoder.py): - Drop expand_control_dims / control_dims / expanded_head_dim fields. - Delete _expand_with_control_dimensions method entirely (~85 LoC each). - Delete the expansion / trim-back branches in forward. - vllm/core/decoder.py: attn_head_dim is unconditionally head_dim. Switches: - Drop config.expanded_head_dim references; head_dim is config.projection_head_dim everywhere. vllm/__init__.py: - ModelArchConfigConvertor.get_head_size() returns config.projection_head_dim (no expansion logic). Composer: - compose_granite_switch.py: drop --control-dims and --legacy-hiding CLI flags. Delete the legacy-hiding branch in build(); always token-exchange. - compose_utils.py: drop hiding_groups / hiding_policy / adapter_third_party kwargs. - model_card.py: drop control_dims / legacy_hiding / use_token_exchange reporting fields. Tests deleted entirely: - tests/unit/test_hiding_constant.py - tests/hf/test_kv_hiding_gap_equivalence.py - tests/vllm/test_kv_hiding_gap_equivalence.py - tests/vllm/_kv_hiding_gap_tests.py - tests/hf/test_position_zero_nan.py - tests/vllm/_position_zero_nan_tests.py - tests/integration/test_token_exchange_parity.py (compared old vs new modes; with no old mode, nothing to compare). - tests/composer/test_built_in_adapters.py (entire file tested removed Mode A / Mode B distinction). Tests rewritten: - tests/conftest.py, tests/unit/test_config{,_edge_cases}.py, tests/unit/test_token_exchange.py, tests/hf/test_model_forward.py, tests/hf/test_token_exchange.py, tests/hf/test_qk_norm.py, tests/shared/granite4_equivalence.py, tests/shared/generation_models.py: fixtures and assertions updated for the simpler config surface. Net diff: ~3000 LoC deleted, ~200 LoC added (test rewrites). 643 tests pass on CPU after the refactor (was 756; the difference is parameterized hiding-equivalence tests + the parity harness, all deleted). Breaking change for any externally-composed checkpoint that was using control_dims > 0: those checkpoints are unloadable under this version. The token-exchange path has been the documented default since #8 and the only path that received the chat-template drops, so any in-flight build should already be on it. --- .../composer/compose_granite_switch.py | 104 +-- src/granite_switch/composer/compose_utils.py | 7 - .../composer/reporting/model_card.py | 3 - src/granite_switch/config.py | 205 +----- src/granite_switch/hf/core/lora.py | 108 ---- .../hf/modeling_granite_switch.py | 78 +-- src/granite_switch/hf/switch/single.py | 13 +- src/granite_switch/vllm/__init__.py | 20 +- src/granite_switch/vllm/core/decoder.py | 95 +-- .../vllm/granite_switch_model.py | 114 +--- src/granite_switch/vllm/switch/single.py | 4 +- tests/composer/test_built_in_adapters.py | 272 -------- tests/conftest.py | 9 +- tests/hf/test_kv_hiding_gap_equivalence.py | 215 ------- tests/hf/test_model_forward.py | 159 +---- tests/hf/test_position_zero_nan.py | 188 ------ tests/hf/test_qk_norm.py | 4 - tests/hf/test_token_exchange.py | 43 +- .../integration/test_token_exchange_parity.py | 602 ------------------ tests/shared/generation_models.py | 16 +- tests/shared/granite4_equivalence.py | 14 +- tests/unit/test_config.py | 171 ++--- tests/unit/test_config_edge_cases.py | 212 +----- tests/unit/test_hiding_constant.py | 118 ---- tests/unit/test_token_exchange.py | 74 +-- tests/vllm/_kv_hiding_gap_tests.py | 71 --- tests/vllm/_position_zero_nan_tests.py | 479 -------------- tests/vllm/test_kv_hiding_gap_equivalence.py | 39 -- 28 files changed, 200 insertions(+), 3237 deletions(-) delete mode 100644 tests/composer/test_built_in_adapters.py delete mode 100644 tests/hf/test_kv_hiding_gap_equivalence.py delete mode 100644 tests/hf/test_position_zero_nan.py delete mode 100644 tests/integration/test_token_exchange_parity.py delete mode 100644 tests/unit/test_hiding_constant.py delete mode 100644 tests/vllm/_kv_hiding_gap_tests.py delete mode 100644 tests/vllm/_position_zero_nan_tests.py delete mode 100644 tests/vllm/test_kv_hiding_gap_equivalence.py diff --git a/src/granite_switch/composer/compose_granite_switch.py b/src/granite_switch/composer/compose_granite_switch.py index 41b7c02..0f90b2a 100755 --- a/src/granite_switch/composer/compose_granite_switch.py +++ b/src/granite_switch/composer/compose_granite_switch.py @@ -88,19 +88,17 @@ def _probe_lora_substitute_token_id(tokenizer) -> int: should land at the swap site. By deriving this from the tokenizer's own chat template at compose - time, we avoid hard-coding a Granite-4.1-specific token string + time, we avoid hard-coding a Granite-4.x-specific token string (<|start_of_role|>). Other base models with different chat templates get the correct substitute for their template by construction. - Raises ``ValueError`` if the template cannot be rendered or the first - tokenized id cannot be determined. Callers should suggest - ``--legacy-hiding`` as the fallback. + Raises ``ValueError`` if the template is missing, fails to render, or + emits an unknown token. """ if tokenizer.chat_template is None: raise ValueError( "Tokenizer has no chat_template; cannot probe the LoRA " - "substitute token. Pass --legacy-hiding to use the KV-hiding " - "path instead." + "substitute token." ) try: probe_text = tokenizer.apply_chat_template( @@ -111,22 +109,19 @@ def _probe_lora_substitute_token_id(tokenizer) -> int: except Exception as e: raise ValueError( "Failed to render a probe chat via tokenizer.apply_chat_template " - f"while detecting the LoRA substitute token: {e!r}. " - "Pass --legacy-hiding to use the KV-hiding path instead." + f"while detecting the LoRA substitute token: {e!r}." ) from e ids = tokenizer(probe_text, add_special_tokens=False).input_ids if not ids: raise ValueError( "Probe chat tokenized to an empty id list; cannot determine the " - "LoRA substitute token. Pass --legacy-hiding to use the " - "KV-hiding path instead." + "LoRA substitute token." ) sub_id = ids[0] if sub_id == tokenizer.unk_token_id: raise ValueError( "First token of the rendered probe chat is ; the template " - "appears to emit content outside the tokenizer's vocabulary. " - "Pass --legacy-hiding to use the KV-hiding path instead." + "appears to emit content outside the tokenizer's vocabulary." ) return sub_id @@ -504,22 +499,6 @@ def _compose_argparser(): default=None, help="Dimension of Q/K/V vectors in switch attention", ) - parser.add_argument( - "--control-dims", - type=int, - default=None, - help="Extra dims for K/V to mask control tokens in decoder layers. " - "Default: 0 (token-exchange mode). Set to >=1 and pass --legacy-hiding " - "only if a specific adapter regresses under token exchange.", - ) - parser.add_argument( - "--legacy-hiding", - action="store_true", - default=False, - help="Use the legacy KV-hiding path (control_dims=32, no embedding " - "substitution). Escape hatch for adapters that regress under the " - "default token-exchange mode.", - ) parser.add_argument( "--built-in-adapters", type=str, @@ -743,9 +722,8 @@ def build(): has_external = len(external_discovered) > 0 has_built_in = len(built_in_discovered) > 0 - # Mode detection: - # Mode A (native): built-in only → no hiding, control_dims=0 - # Mode B (third-party): externals present → full hiding + # Mode detection (informational only — token-exchange handles both + # native and third-party adapter builds uniformly). if has_built_in and not has_external: build_mode = "native" elif has_external: @@ -757,7 +735,6 @@ def build(): # Extract fields from 4-tuples (path, name, tech, source) adapter_paths = [t[0] for t in all_discovered if t[0] is not None] adapter_names = [t[1] for t in all_discovered] - external_names = [t[1] for t in external_discovered] built_in_names = [name for name in (args.built_in_adapters or [])] print(f"\nBuild mode: {build_mode}") @@ -812,51 +789,23 @@ def build(): optional_kwargs = {} if args.switch_head_dim is not None: optional_kwargs["switch_head_dim"] = args.switch_head_dim - if args.control_dims is not None: - optional_kwargs["control_dims"] = args.control_dims - - # Control-token handling mode: token-exchange by default, legacy hiding on opt-in. - if args.legacy_hiding: - # Legacy: keep today's KV-hiding scheme. control_dims must be > 0. - if optional_kwargs.get("control_dims", 0) == 0: - optional_kwargs["control_dims"] = 32 - adapter_substitute_token_ids = None - # Hiding groups only apply in third-party mode; native mode still - # has no hiding and no substitution, but --legacy-hiding forces - # control_dims > 0 so the config validator accepts it. - if build_mode == "native": - hiding_groups = None - hiding_policy = None - adapter_third_party = None + + # Token-exchange substitute choice (must mirror the token that appears + # right after the control token in the rendered chat prompt, so the + # swap keeps the residual stream in-distribution): + # - ALoRA: first token of the adapter's alora_invocation_tokens. + # - LoRA/builtin: whatever the tokenizer's chat template emits at + # the very start of a no-adapter user turn. For Granite 4.x that's + # <|start_of_role|>; the probe derives this from the template at + # compose time so other base models work by construction. + lora_sub_id = _probe_lora_substitute_token_id(tokenizer) + adapter_substitute_token_ids = [] + for adapter_path, _name, technology, _source in all_discovered: + if technology == "alora": + sub_id = get_alora_first_invocation_token_id(adapter_path) else: - hiding_groups = {"all_controls": list(adapter_names)} - hiding_policy = {name: ["all_controls"] for name in adapter_names} - hiding_policy["base"] = ["all_controls"] - adapter_third_party = list(external_names) - else: - # Default: token-exchange. control_dims=0; every adapter needs a substitute id. - # - # Substitute choice (must mirror the token that appears right after the - # control token in the rendered chat prompt, so the swap keeps the - # residual stream in-distribution): - # - ALoRA: first token of the adapter's alora_invocation_tokens. - # - LoRA/builtin: whatever the tokenizer's chat template emits at - # the very start of a no-adapter user turn. For Granite 4.x - # that's <|start_of_role|>; the probe derives this from the - # template at compose time so other base models work by - # construction. - lora_sub_id = _probe_lora_substitute_token_id(tokenizer) - adapter_substitute_token_ids = [] - for adapter_path, _name, technology, _source in all_discovered: - if technology == "alora": - sub_id = get_alora_first_invocation_token_id(adapter_path) - else: - sub_id = lora_sub_id - adapter_substitute_token_ids.append(sub_id) - # Token-exchange supersedes KV hiding — no hiding config needed. - hiding_groups = None - hiding_policy = None - adapter_third_party = None + sub_id = lora_sub_id + adapter_substitute_token_ids.append(sub_id) model = GraniteSwitchComposer.from_base_and_adapters( base_model_name_or_path=base_model_local_path, @@ -864,9 +813,6 @@ def build(): adapter_token_ids=adapter_token_ids, adapter_substitute_token_ids=adapter_substitute_token_ids, adapter_names=adapter_names, - hiding_groups=hiding_groups, - hiding_policy=hiding_policy, - adapter_third_party=adapter_third_party, built_in_adapter_names=built_in_names, built_in_lora_rank=args.lora_rank, built_in_lora_alpha=args.lora_alpha if args.lora_alpha is not None else float(args.lora_rank), diff --git a/src/granite_switch/composer/compose_utils.py b/src/granite_switch/composer/compose_utils.py index 2690a2f..66704b2 100644 --- a/src/granite_switch/composer/compose_utils.py +++ b/src/granite_switch/composer/compose_utils.py @@ -117,10 +117,6 @@ def from_base_and_adapters( source_analysis = {} # --- Step 4: Build switch config from arch descriptor --- - hiding_groups = kwargs.pop("hiding_groups", None) - hiding_policy = kwargs.pop("hiding_policy", None) - adapter_third_party = kwargs.pop("adapter_third_party", None) - # Copy config fields driven by architecture descriptor config_kwargs: Dict = {} @@ -158,9 +154,6 @@ def from_base_and_adapters( "adapter_token_ids": adapter_token_ids, "adapter_substitute_token_ids": adapter_substitute_token_ids, "adapter_names": adapter_names, - "hiding_groups": hiding_groups, - "hiding_policy": hiding_policy, - "adapter_third_party": adapter_third_party, "max_lora_rank": lora_rank, "adapter_ranks": adapter_ranks, "lora_target_modules": lora_target_modules, diff --git a/src/granite_switch/composer/reporting/model_card.py b/src/granite_switch/composer/reporting/model_card.py index 4bb2cd1..8c3505f 100644 --- a/src/granite_switch/composer/reporting/model_card.py +++ b/src/granite_switch/composer/reporting/model_card.py @@ -391,9 +391,6 @@ def _short_source(source): "lora_rank": getattr(args, "lora_rank", None) if built_in else None, "lora_alpha": getattr(args, "lora_alpha", None) if built_in else None, "switch_head_dim": getattr(args, "switch_head_dim", None), - "control_dims": getattr(args, "control_dims", None), - "legacy_hiding": getattr(args, "legacy_hiding", False), - "use_token_exchange": getattr(model.config, "use_token_exchange", False), "adapter_substitute_token_ids": getattr( model.config, "adapter_substitute_token_ids", None ), diff --git a/src/granite_switch/config.py b/src/granite_switch/config.py index d1380a6..7824002 100644 --- a/src/granite_switch/config.py +++ b/src/granite_switch/config.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """Configuration for Granite model with adapter switching.""" -from typing import List, Optional, Dict +from typing import List, Optional from transformers import GraniteMoeHybridConfig @@ -9,11 +9,12 @@ class GraniteSwitchConfig(GraniteMoeHybridConfig): """Configuration class for GraniteSwitch model. - Extends the Granite base config with parameters for adapter switching using - the SingleSwitch mechanism. - - Inherits from GraniteMoeHybridConfig (the transformers base class for - Granite 4 models) and adds adapter routing parameters. + Extends the Granite base config with parameters for adapter switching + using the SingleSwitch mechanism. Control tokens are handled exclusively + via token exchange: the switch reads ``input_ids``, decides the active + adapter, and rewrites each control token to its substitute id (from + ``adapter_substitute_token_ids``) before the decoder embeds the + sequence. The decoder is unaware of the substitution. Args: num_adapters (int): Number of LoRA adapters available. Default: 0 (no adapters). @@ -23,30 +24,15 @@ class GraniteSwitchConfig(GraniteMoeHybridConfig): adapter_token_ids[i] activates adapter i+1 (1-indexed output). Output 0 = base (implicit default, no token needed to return to base). NOTE: SingleSwitch cannot transition back to base mid-sequence. - adapter_substitute_token_ids (List[int]): Token IDs whose embeddings replace - the control-token embeddings before the decoder runs (token-exchange mode). - Length: num_adapters. When provided together with control_dims=0, the model - uses token exchange instead of KV hiding. + adapter_substitute_token_ids (List[int]): Token IDs whose embeddings + replace the control-token embeddings before the decoder runs. + Length: num_adapters. Required when num_adapters > 0. SingleSwitch parameters: control_token_gain (float): Attention gain for control/non-control separation. Default: 15.0. switch_head_dim (int): Dimension of Q/K/V vectors in switch attention. Default: 32. - control_dims (int): Extra dimensions for K/V to mask control tokens. Must be >= 0. - Default: 0 (token-exchange mode — adapter_substitute_token_ids must be provided). - Set >= 1 to enable the legacy KV-hiding path. adapter_names (List[str]): Ordered adapter names for name-to-index mapping. - Used by hiding_groups and hiding_policy to resolve names to indices. - hiding_groups (Dict[str, List[str]]): Hiding group definitions. - Maps group_name → list of adapter names whose control tokens belong to this group. - Each group uses one control dimension. Requires control_dims >= len(hiding_groups). - hiding_policy (Dict[str, List[str]]): Per-adapter hiding policy. - Maps adapter_name → list of group names that adapter hides. Use "base" for the - base adapter (adapter_index 0). - adapter_third_party (List[str]): Adapter names that are third-party (externally trained). - Third-party adapters were not trained with control tokens in their vocabulary, - which affects KV hiding policy. - max_lora_rank (int): Maximum rank across all LoRA adapters (for allocation). Default: 8. adapter_ranks (List[int]): Per-adapter ranks. Must have length equal to num_adapters. lora_target_modules (List[str]): List of module GROUP names to apply LoRA to. @@ -65,13 +51,8 @@ def __init__( # SingleSwitch parameters control_token_gain: float = 15.0, switch_head_dim: int = 32, - control_dims: int = 0, - # Hiding groups and policy - adapter_names: Optional[List[str]] = None, - hiding_groups: Optional[Dict[str, List[str]]] = None, - hiding_policy: Optional[Dict[str, List[str]]] = None, - adapter_third_party: Optional[List[str]] = None, # Adapter parameters + adapter_names: Optional[List[str]] = None, max_lora_rank: int = 8, adapter_ranks: List[int] = None, lora_target_modules: Optional[List[str]] = None, @@ -124,12 +105,19 @@ def __init__( ) self.adapter_token_ids = adapter_token_ids - # Validate adapter_substitute_token_ids if provided - if num_adapters > 0 and adapter_substitute_token_ids is not None: + # Validate adapter_substitute_token_ids — required when num_adapters > 0. + if num_adapters > 0: + if adapter_substitute_token_ids is None: + raise ValueError( + "adapter_substitute_token_ids is required when num_adapters > 0. " + "Every adapter needs a substitute token id whose embedding replaces " + "the control-token embedding before the decoder runs." + ) if len(adapter_substitute_token_ids) != num_adapters: raise ValueError( - f"adapter_substitute_token_ids length ({len(adapter_substitute_token_ids)}) " - f"must equal num_adapters ({num_adapters})." + f"adapter_substitute_token_ids length " + f"({len(adapter_substitute_token_ids)}) must equal num_adapters " + f"({num_adapters})." ) if any(sid < 0 for sid in adapter_substitute_token_ids): raise ValueError( @@ -139,56 +127,22 @@ def __init__( if adapter_token_ids is None: raise ValueError( "adapter_token_ids is required when adapter_substitute_token_ids " - "is provided (token-exchange mode maps control ids to substitute ids)." + "is provided (token-exchange maps control ids to substitute ids)." ) self.adapter_substitute_token_ids = adapter_substitute_token_ids # SingleSwitch parameters self.control_token_gain = control_token_gain self.switch_head_dim = switch_head_dim - if control_dims < 0: - raise ValueError( - f"control_dims must be >= 0 (got {control_dims}). " - "Use control_dims=0 for native mode (no KV hiding). " - "Use control_dims >= 1 for third-party mode (KV cache masking)." - ) - self.control_dims = control_dims self.fused_add_norm = fused_add_norm - # Control tokens need one of two handling paths when adapters are present: - # legacy KV hiding (control_dims > 0) or token exchange (substitute ids present). - # The combination of neither would leak raw control-token embeddings into attention. - if ( - num_adapters > 0 - and control_dims == 0 - and adapter_substitute_token_ids is None - ): - raise ValueError( - "When num_adapters > 0, either control_dims > 0 (legacy KV hiding) " - "or adapter_substitute_token_ids (token exchange) must be provided. " - "Neither is set, which would leave control tokens unhandled." - ) - - # Hiding groups and policy + # Adapter names self.adapter_names = adapter_names - self.hiding_groups = hiding_groups - self.hiding_policy = hiding_policy - self.adapter_third_party = adapter_third_party - # Validate control_dims >= num_hiding_groups - if hiding_groups is not None and len(hiding_groups) > control_dims: - raise ValueError( - f"control_dims ({control_dims}) must be >= number of hiding groups " - f"({len(hiding_groups)}). Each hiding group uses one control dimension." - ) - - # KV cache head dimension vs. projection dimension. + # Projection head dimension. # The QKV projection outputs vectors of size projection_head_dim - # (= hidden_size / num_attention_heads). The KV cache stores larger - # vectors (projection_head_dim + control_dims) for exact attention - # masking of control tokens. The expanded size is communicated to - # vLLM via a custom ModelArchConfigConvertor (registered in vllm/__init__.py) - # so that hybrid page-size calculations use the correct value. + # (= hidden_size / num_attention_heads). The KV cache stores native- + # head_dim tensors — no expansion under token exchange. # We do NOT set head_dim here because HF's RoPE also reads it. # Use explicit head_dim from kwargs when available (some models have # head_dim != hidden_size // num_attention_heads). @@ -238,108 +192,3 @@ def __init__( ]) self.lora_target_modules = lora_target_modules - - @property - def expanded_head_dim(self) -> int: - """KV cache head dimension: projection_head_dim + control_dims when adapters are active.""" - if self.num_adapters > 0 and self.control_dims > 0: - return self.projection_head_dim + self.control_dims - return self.projection_head_dim - - @property - def use_token_exchange(self) -> bool: - """True when control tokens are replaced with substitute embeddings (vs. KV hiding).""" - return ( - self.num_adapters > 0 - and self.control_dims == 0 - and self.adapter_substitute_token_ids is not None - ) - - @property - def num_hiding_groups(self) -> int: - """Number of hiding groups (each uses one control dimension).""" - if self.hiding_groups is None: - return 0 - return len(self.hiding_groups) - - @property - def hiding_group_names(self) -> List[str]: - """Ordered list of hiding group names (determines control dim indices).""" - if self.hiding_groups is None: - return [] - return list(self.hiding_groups.keys()) - - def get_hiding_group_token_ids(self) -> Dict[int, List[int]]: - """Map group index → list of token IDs in that group. - - Resolves adapter names to their activating token IDs using - adapter_names and adapter_token_ids. - - Returns empty dict if no hiding groups configured. - """ - if self.hiding_groups is None or self.adapter_names is None: - return {} - if self.adapter_token_ids is None: - return {} - - # Build name → token ID mapping (no offset for SingleSwitch) - name_to_token_id = {} - for i, name in enumerate(self.adapter_names): - name_to_token_id[name] = self.adapter_token_ids[i] - - result = {} - for group_idx, group_name in enumerate(self.hiding_group_names): - adapter_names_in_group = self.hiding_groups[group_name] - token_ids = [] - for name in adapter_names_in_group: - if name in name_to_token_id: - token_ids.append(name_to_token_id[name]) - result[group_idx] = token_ids - return result - - def get_third_party_adapter_mask(self) -> List[bool]: - """Return per-adapter-slot boolean: True if the adapter is third-party. - - Index 0 = base (never third-party). Index 1+ = real adapters. - Length = num_adapters + 1 (one slot per adapter index including base). - - Third-party adapters were not trained with control tokens in their - vocabulary, which affects KV hiding policy. - - Returns all-False list if adapter_third_party is not configured. - """ - num_slots = self.num_adapters + 1 # base + adapters - if not self.adapter_third_party or not self.adapter_names: - return [False] * num_slots - - tp_set = set(self.adapter_third_party) - # Index 0 = base (never third-party) - mask = [False] - for name in self.adapter_names: - mask.append(name in tp_set) - return mask - - def get_adapter_hiding_policy_matrix(self) -> List[List[bool]]: - """Build adapter hiding policy matrix: [num_adapter_slots][num_groups]. - - Index 0 = base adapter. Index 1+ = real adapters (matching adapter_names order). - Each entry is True if that adapter hides that group. - - Returns empty list if no hiding policy configured. - """ - if self.hiding_policy is None or self.adapter_names is None: - return [] - - num_groups = self.num_hiding_groups - group_names = self.hiding_group_names - - # Build ordered adapter list: [base, adapter_0, adapter_1, ...] - all_adapter_names = ["base"] + list(self.adapter_names) - num_slots = len(all_adapter_names) - - matrix = [] - for adapter_name in all_adapter_names: - groups_to_hide = self.hiding_policy.get(adapter_name, []) - row = [gn in groups_to_hide for gn in group_names] - matrix.append(row) - return matrix diff --git a/src/granite_switch/hf/core/lora.py b/src/granite_switch/hf/core/lora.py index 97e97ce..648abc3 100644 --- a/src/granite_switch/hf/core/lora.py +++ b/src/granite_switch/hf/core/lora.py @@ -371,13 +371,6 @@ def __init__(self, config: GraniteSwitchConfig, layer_idx: int): self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps) - # Control dimension expansion for KV cache masking. - # Expand only when adapters present AND control_dims > 0. - # control_dims=0 means native mode: no KV hiding, no expansion. - self.expand_control_dims = config.num_adapters > 0 and config.control_dims > 0 - self.control_dims = config.control_dims - self.expanded_head_dim = self.head_dim + self.control_dims - # Fused QKV projection - conditionally add LoRA based on config q_size = self.num_heads * self.head_dim kv_size = self.num_key_value_heads * self.head_dim @@ -425,94 +418,10 @@ def __init__(self, config: GraniteSwitchConfig, layer_idx: int): ) self.has_o_lora = False - def _expand_with_control_dimensions( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - token_group_membership: Optional[torch.Tensor], - query_group_suppression: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Expand Q, K, V with control dimensions for group-based KV cache hiding. - - Always called when num_adapters > 0 (static shape decision). - Each hiding group g uses one control dimension: - - K-side: finfo(dtype).min for tokens that are members of group g - - Q-side: 1.0 for queries whose adapter suppresses group g, - except for tokens that are themselves in group g - - When both tensors are None, all control dims are zero (no masking effect). - - Args: - q: Query tensor [batch, seq_len, num_heads, head_dim] - k: Key tensor [batch, seq_len, num_kv_heads, head_dim] - v: Value tensor [batch, seq_len, num_kv_heads, head_dim] - token_group_membership: [batch, seq_len, num_groups] — True if token is in group g - query_group_suppression: [batch, seq_len, num_groups] — True if token's adapter suppresses group g - - Returns: - Expanded Q, K, V tensors with control_dims added to head_dim - """ - batch_size, seq_len = q.shape[:2] - device = q.device - dtype = q.dtype - - # Allocate control dimensions (initialized to zero) - q_control = torch.zeros( - batch_size, seq_len, self.num_heads, self.control_dims, - device=device, dtype=dtype - ) - k_control = torch.zeros( - batch_size, seq_len, self.num_key_value_heads, self.control_dims, - device=device, dtype=dtype - ) - v_control = torch.zeros( - batch_size, seq_len, self.num_key_value_heads, self.control_dims, - device=device, dtype=dtype - ) - - # K-side: brand each group-member token's key with finfo.min in its group's - # control dim so that suppressing queries score it as −∞. - # token_group_membership: [batch, seq, num_groups] - # → expand to [batch, seq, num_kv_heads, num_groups] - if token_group_membership is not None: - num_groups = token_group_membership.shape[-1] - hiding_constant = torch.finfo(dtype).min - k_control[:, :, :, :num_groups] = ( - token_group_membership.unsqueeze(2) - .expand(-1, -1, self.num_key_value_heads, -1) - .to(dtype) * hiding_constant - ) - - # Q-side: set control dim g to 1.0 for queries whose adapter suppresses group g. - # query_group_suppression: [batch, seq, num_groups] - # → expand to [batch, seq, num_heads, num_groups] - # Tokens that are themselves in group g are excluded: when the control token - # sits at position 0 it has no other causal key to attend to, so suppressing - # its own key yields softmax([−∞]) = NaN. - if query_group_suppression is not None: - num_groups = query_group_suppression.shape[-1] - q_hide = query_group_suppression.to(dtype) - if token_group_membership is not None: - q_hide = q_hide * (1 - token_group_membership.to(dtype)) - q_control[:, :, :, :num_groups] = ( - q_hide.unsqueeze(2) - .expand(-1, -1, self.num_heads, -1) - ) - - # Concatenate original dims + control dims - q = torch.cat([q, q_control], dim=-1) # [batch, seq_len, num_heads, head_dim + control_dims] - k = torch.cat([k, k_control], dim=-1) # [batch, seq_len, num_kv_heads, head_dim + control_dims] - v = torch.cat([v, v_control], dim=-1) # [batch, seq_len, num_kv_heads, head_dim + control_dims] - - return q, k, v - def forward( self, hidden_states: torch.Tensor, adapter_indices: torch.Tensor, - token_group_membership: Optional[torch.Tensor], - query_group_suppression: Optional[torch.Tensor], position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Cache] = None, @@ -525,8 +434,6 @@ def forward( Args: hidden_states: Input tensor [batch, seq_len, hidden_size] adapter_indices: Per-token adapter selection [batch, seq_len] - token_group_membership: [batch, seq_len, num_groups] — True if token is in group g, or None - query_group_suppression: [batch, seq_len, num_groups] — True if token's adapter suppresses group g, or None position_embeddings: Precomputed (cos, sin) for RoPE attention_mask: Attention mask past_key_values: Cache object for KV caching @@ -573,15 +480,6 @@ def forward( query_states = query_states_t.transpose(1, 2) key_states = key_states_t.transpose(1, 2) - # Control dimension expansion: always when adapters are present. - # Group masks control which tokens/groups get K=finfo.min masking - # (can be None if no hiding groups, but expansion still happens). - if self.expand_control_dims: - query_states, key_states, value_states = self._expand_with_control_dimensions( - query_states, key_states, value_states, - token_group_membership, query_group_suppression, - ) - # Belief that both cache and attention expect [batch, heads, seq, dim] key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) @@ -612,12 +510,6 @@ def forward( sliding_window=getattr(self.config, "sliding_window", None), ) - # Trim control dimensions from output - if self.expand_control_dims: - # attn_output shape: [batch, num_heads, seq_len, expanded_head_dim] - # Trim to original head_dim - attn_output = attn_output[..., :self.head_dim] - # Reshape and project output - conditionally use LoRA attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) if self.has_o_lora: diff --git a/src/granite_switch/hf/modeling_granite_switch.py b/src/granite_switch/hf/modeling_granite_switch.py index 2f064c0..dd23110 100644 --- a/src/granite_switch/hf/modeling_granite_switch.py +++ b/src/granite_switch/hf/modeling_granite_switch.py @@ -82,8 +82,6 @@ def forward( cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, adapter_indices: Optional[torch.Tensor] = None, - token_group_membership: Optional[torch.Tensor] = None, - query_group_suppression: Optional[torch.Tensor] = None, **kwargs, ) -> tuple: residual = hidden_states @@ -93,8 +91,6 @@ def forward( hidden_states, self_attn_weights, present_key_values = self.self_attn( hidden_states=hidden_states, adapter_indices=adapter_indices, - token_group_membership=token_group_membership, - query_group_suppression=query_group_suppression, position_embeddings=position_embeddings, attention_mask=attention_mask, past_key_values=past_key_values, @@ -193,44 +189,9 @@ def __init__(self, config: GraniteSwitchConfig): # forward pass, so this model class no longer needs a decoder- # side substitute table. - # --- Hiding group buffers --- - # token_to_group_mask: [vocab_size, num_groups] lookup table. - # For each token ID, True at group g if that token belongs to group g. - # Enables O(1) per-token group membership via: mask = table[input_ids] - num_groups = config.num_hiding_groups - if num_groups > 0: - group_token_ids = config.get_hiding_group_token_ids() - # Size must cover all token IDs including added control tokens - # which may have IDs >= config.vocab_size. - all_known_ids = [tid for tids in group_token_ids.values() for tid in tids] - if config.adapter_token_ids: - all_known_ids.extend(config.adapter_token_ids) - max_tid = max(all_known_ids) if all_known_ids else -1 - table_size = max(config.vocab_size, max_tid + 1) - token_to_group_mask = torch.zeros( - table_size, num_groups, dtype=torch.bool - ) - for g, tids in group_token_ids.items(): - for tid in tids: - token_to_group_mask[tid, g] = True - self.register_buffer("token_to_group_mask", token_to_group_mask) - - # adapter_hiding_matrix: [num_adapter_slots, num_groups] boolean. - # Index 0 = base, 1+ = adapters. True if adapter hides group g. - policy_matrix = config.get_adapter_hiding_policy_matrix() - self.register_buffer( - "adapter_hiding_matrix", - torch.tensor(policy_matrix, dtype=torch.bool), - ) - else: - self.token_to_group_mask = None - self.adapter_hiding_matrix = None - else: self.switch = None self.adapter_token_ids = None - self.token_to_group_mask = None - self.adapter_hiding_matrix = None # Decoder layers if config.num_adapters > 0: @@ -335,7 +296,6 @@ def forward( # The switch also returns modified_input_ids: input_ids with each # control token rewritten to its substitute id, so the decoder can # embed once without any token-exchange awareness. - hidden_count = None modified_input_ids = input_ids if self.switch is not None: adapter_indices, modified_input_ids = self.switch( @@ -345,30 +305,10 @@ def forward( past_key_values=past_key_values, cache_position=cache_position, ) - - # Compute group-based hiding masks from lookup tables. - if self.token_to_group_mask is not None: - # token_group_membership: True at [b, i, g] if token i is a member of group g - token_group_membership = self.token_to_group_mask[input_ids] - # query_group_suppression: True at [b, i, g] if token i's adapter suppresses group g - query_group_suppression = self.adapter_hiding_matrix[adapter_indices] - else: - token_group_membership = None - query_group_suppression = None - - # Compute hidden_count for position correction (SingleSwitch). - # SingleSwitch fires once: hidden_count is 0 before the control - # token and 1 at/after it, which is exactly (adapter_indices > 0). - # In token-exchange mode control tokens become real positions, so - # the correction is a no-op — skip it rather than subtract zeros. - if hidden_count is None and not self.config.use_token_exchange: - hidden_count = (adapter_indices > 0).long() else: adapter_indices = torch.zeros( (batch_size, seq_length), dtype=torch.long, device=device, ) - token_group_membership = None - query_group_suppression = None # Embed once, on the (possibly-rewritten) input_ids. The decoder is # token-exchange-agnostic — it just embeds whatever the switch @@ -380,20 +320,12 @@ def forward( # Expose adapter_indices for tests and debugging. self._last_adapter_indices = adapter_indices - # Position correction: adjust position_ids to close gaps from hidden tokens. - # Clamp to >= 0: pre-init tokens have no hidden tokens in their causal - # past, but the counting mechanism returns capacity-1 when all attention - # keys are masked, which would produce negative positions and OOB RoPE - # cache indices. - if hidden_count is not None: - adjusted_position_ids = torch.clamp(position_ids - hidden_count, min=0) - else: - adjusted_position_ids = position_ids - - # Position embeddings (only if RoPE is configured) + # Position embeddings (only if RoPE is configured). Control tokens + # in token-exchange mode count as real positions, so position_ids + # is used directly — no hidden_count subtraction. position_embeddings = None if self.rotary_emb is not None: - position_embeddings = self.rotary_emb(inputs_embeds, position_ids=adjusted_position_ids) + position_embeddings = self.rotary_emb(inputs_embeds, position_ids=position_ids) # Decoder layers hidden_states = inputs_embeds @@ -414,8 +346,6 @@ def forward( cache_position=cache_position, position_embeddings=position_embeddings, adapter_indices=adapter_indices, - token_group_membership=token_group_membership, - query_group_suppression=query_group_suppression, **kwargs, ) diff --git a/src/granite_switch/hf/switch/single.py b/src/granite_switch/hf/switch/single.py index e35ac91..ea2508b 100644 --- a/src/granite_switch/hf/switch/single.py +++ b/src/granite_switch/hf/switch/single.py @@ -57,11 +57,14 @@ def __init__( self.control_token_gain = control_token_gain self.config = config - # Use expanded_head_dim to align with decoder layers across both backends. - if config is not None and hasattr(config, 'expanded_head_dim') and getattr(config, 'num_adapters', 0) > 0: - self.head_dim = config.expanded_head_dim - elif config is not None: - self.head_dim = config.hidden_size // config.num_attention_heads + # Align with the decoder's native head_dim. (Under token exchange the + # KV cache no longer carries any expansion, so this is just the + # base-model projection_head_dim.) + if config is not None: + self.head_dim = getattr( + config, "projection_head_dim", + config.hidden_size // config.num_attention_heads, + ) else: self.head_dim = switch_head_dim diff --git a/src/granite_switch/vllm/__init__.py b/src/granite_switch/vllm/__init__.py index ba52afb..eb6401b 100644 --- a/src/granite_switch/vllm/__init__.py +++ b/src/granite_switch/vllm/__init__.py @@ -52,10 +52,12 @@ def register(): except Exception: pass - # Register custom ModelArchConfigConvertor so vLLM sees the correct - # KV cache head size. When adapters use control_dims, the decoder - # attention stores expanded vectors (projection_head_dim + control_dims) - # in the KV cache. + # Register custom ModelArchConfigConvertor so vLLM sees: + # 1. The correct decoder layer count (excluding the switch's KV-cache + # placeholder slot). + # 2. The native KV cache head size (projection_head_dim). Token + # exchange does not expand the head dim, so this is just the base + # model's head_dim. try: from vllm.transformers_utils.model_arch_config_convertor import ( MODEL_ARCH_CONFIG_CONVERTORS, @@ -76,15 +78,7 @@ def get_num_hidden_layers(self) -> int: def get_head_size(self) -> int: cfg = self.hf_text_config - if hasattr(cfg, 'expanded_head_dim'): - return cfg.expanded_head_dim - # Fallback for configs without the property - base = super().get_head_size() - num_adapters = getattr(cfg, "num_adapters", 0) - control_dims = getattr(cfg, "control_dims", 32) - if num_adapters > 0 and control_dims > 0: - return base + control_dims - return base + return getattr(cfg, "projection_head_dim", super().get_head_size()) MODEL_ARCH_CONFIG_CONVERTORS["granite_switch"] = ( _GraniteSwitchArchConfigConvertor diff --git a/src/granite_switch/vllm/core/decoder.py b/src/granite_switch/vllm/core/decoder.py index 66090d0..565d1cc 100644 --- a/src/granite_switch/vllm/core/decoder.py +++ b/src/granite_switch/vllm/core/decoder.py @@ -78,12 +78,6 @@ def __init__( self.kv_size = self.num_kv_heads * self.head_dim self.scaling = config.attention_multiplier - # Control dimension expansion: expand only when adapters present AND - # control_dims > 0. control_dims=0 means native mode (no KV hiding). - self.expand_control_dims = num_adapters > 0 and config.control_dims > 0 - self.control_dims = config.control_dims - self.expanded_head_dim = self.head_dim + self.control_dims - # QKV projection - conditionally add LoRA based on config base_qkv_proj = QKVParallelLinear( self.hidden_size, @@ -143,11 +137,10 @@ def __init__( else: self.rotary_emb = None - # Attention layer — use expanded head dim only when expansion is active - self.attn_head_dim = self.expanded_head_dim if self.expand_control_dims else self.head_dim + # Attention layer — head_dim is the native projection_head_dim. self.attn = Attention( self.num_heads, - self.attn_head_dim, + self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, @@ -155,81 +148,12 @@ def __init__( prefix=f"{prefix}.attn", ) - def _expand_with_control_dimensions( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - token_group_membership: Optional[torch.Tensor], - query_group_suppression: Optional[torch.Tensor], - ) -> tuple: - """Expand Q, K, V with control dimensions for group-based KV cache hiding. - - Always called when num_adapters > 0 (static shape decision). - Each hiding group g uses one control dimension: - - K-side: finfo(dtype).min for tokens that are members of group g - - Q-side: 1.0 for queries whose adapter suppresses group g, - except for tokens that are themselves in group g - - When both tensors are None, all control dims are zero (no masking effect). - """ - num_tokens = q.size(0) - device = q.device - dtype = q.dtype - - q = q.view(num_tokens, self.num_heads, self.head_dim) - k = k.view(num_tokens, self.num_kv_heads, self.head_dim) - v = v.view(num_tokens, self.num_kv_heads, self.head_dim) - - q_control = torch.zeros(num_tokens, self.num_heads, self.control_dims, device=device, dtype=dtype) - k_control = torch.zeros(num_tokens, self.num_kv_heads, self.control_dims, device=device, dtype=dtype) - v_control = torch.zeros(num_tokens, self.num_kv_heads, self.control_dims, device=device, dtype=dtype) - - # K-side: brand each group-member token's key with finfo.min in its group's - # control dim so that suppressing queries score it as −∞. - # token_group_membership: [num_tokens, num_groups] — True if token is in group g - if token_group_membership is not None: - num_groups = token_group_membership.shape[-1] - hiding_constant = torch.finfo(dtype).min - k_control[:, :, :num_groups] = ( - token_group_membership.unsqueeze(1) - .expand(-1, self.num_kv_heads, -1) - .to(dtype) * hiding_constant - ) - - # Q-side: set control dim g to 1.0 for queries whose adapter suppresses group g. - # query_group_suppression: [num_tokens, num_groups] — True if this token's - # adapter suppresses group g. - # Tokens that are themselves in group g are excluded: when the control token - # sits at position 0 it has no other causal key to attend to, so suppressing - # its own key yields softmax([−∞]) = NaN. - if query_group_suppression is not None: - num_groups = query_group_suppression.shape[-1] - q_hide = query_group_suppression.to(dtype) - if token_group_membership is not None: - q_hide = q_hide * (1 - token_group_membership.to(dtype)) - q_control[:, :, :num_groups] = ( - q_hide.unsqueeze(1) - .expand(-1, self.num_heads, -1) - ) - - q = torch.cat([q, q_control], dim=-1) - k = torch.cat([k, k_control], dim=-1) - v = torch.cat([v, v_control], dim=-1) - - q = q.view(num_tokens, self.num_heads * self.expanded_head_dim) - k = k.view(num_tokens, self.num_kv_heads * self.expanded_head_dim) - v = v.view(num_tokens, self.num_kv_heads * self.expanded_head_dim) - - return q, k, v - def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: - # SwitchedLoRALinear reads LoRA metadata from shared LoRAContext; - # hiding group masks for control dims also come from the context. + # SwitchedLoRALinear reads LoRA metadata from the shared LoRAContext. qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -242,20 +166,7 @@ def forward( if self.rotary_emb is not None: q, k = self.rotary_emb(positions, q, k) - if self.expand_control_dims: - token_group_membership = self._lora_ctx.token_group_membership if self._lora_ctx is not None else None - query_group_suppression = self._lora_ctx.query_group_suppression if self._lora_ctx is not None else None - q, k, v = self._expand_with_control_dimensions( - q, k, v, token_group_membership, query_group_suppression, - ) - attn_output = self.attn(q, k, v) - - if self.expand_control_dims: - attn_output = attn_output.view(-1, self.num_heads, self.expanded_head_dim)[ - ..., :self.head_dim - ].reshape(-1, self.num_heads * self.head_dim) - output, _ = self.o_proj(attn_output) return output diff --git a/src/granite_switch/vllm/granite_switch_model.py b/src/granite_switch/vllm/granite_switch_model.py index 9646883..3f90278 100644 --- a/src/granite_switch/vllm/granite_switch_model.py +++ b/src/granite_switch/vllm/granite_switch_model.py @@ -81,17 +81,10 @@ class GraniteSwitchModel(nn.Module): 3. Base transformer layers with LoRA 4. LM head - The switch detects special tokens and selects the appropriate adapter. - Adapter indices are passed as arguments to LoRA layers. - - To mitigate the contribution of control tokens in the base model and adapter computations: - - Each layer's k and v values are augmented with a control dimension set to - k=-inf for control tokens and k=0 otherwise (v=0 throughout), prior to attention - calculation. After softmax attention is computed, the value is reduced to its - original dimension. - - The logits for control tokens are set to -inf in compute_logits() to prevent - the sampler from generating control tokens - - Position correction via hidden_count closes RoPE gaps from KV-hidden control tokens + The switch detects special tokens, selects the appropriate adapter, and + rewrites each control token's id to its substitute id (token exchange). + The decoder embeds the rewritten ids and is otherwise oblivious to the + substitution. Adapter indices are passed as arguments to LoRA layers. """ def __init__( @@ -169,38 +162,10 @@ def __init__( dtype=torch.bfloat16, ) - # --- Hiding group buffers --- - num_groups = config.num_hiding_groups - if num_groups > 0: - group_token_ids = config.get_hiding_group_token_ids() - all_known_ids = [tid for tids in group_token_ids.values() for tid in tids] - if config.adapter_token_ids: - all_known_ids.extend(config.adapter_token_ids) - max_tid = max(all_known_ids) if all_known_ids else -1 - table_size = max(config.vocab_size, max_tid + 1) - token_to_group_mask = torch.zeros( - table_size, num_groups, dtype=torch.bool - ) - for g, tids in group_token_ids.items(): - for tid in tids: - token_to_group_mask[tid, g] = True - self.register_buffer("token_to_group_mask", token_to_group_mask) - - policy_matrix = config.get_adapter_hiding_policy_matrix() - self.register_buffer( - "adapter_hiding_matrix", - torch.tensor(policy_matrix, dtype=torch.bool), - ) - else: - self.token_to_group_mask = None - self.adapter_hiding_matrix = None - else: self.switch = None self.adapter_token_ids = None self.lora_meta = None - self.token_to_group_mask = None - self.adapter_hiding_matrix = None # 3. Base transformer layers with custom LoRA # @@ -290,19 +255,6 @@ def make_empty_intermediate_tensors( ), } - num_groups = self.config.num_hiding_groups - if num_groups > 0: - tensors["token_group_membership"] = torch.zeros( - (batch_size, num_groups), - dtype=torch.bool, - device=device, - ) - tensors["query_group_suppression"] = torch.zeros( - (batch_size, num_groups), - dtype=torch.bool, - device=device, - ) - return IntermediateTensors(tensors) def forward( @@ -332,9 +284,8 @@ def forward( # COMPILED: Switch + Metadata preparation # ═══════════════════════════════════════════════════════════════ - # Step 1: Switch - determine adapter for each token via switch - # Switch only runs on first rank - hidden_count = None + # Step 1: Switch — determine adapter for each token and rewrite + # control tokens via token-exchange. Only runs on first rank. if get_pp_group().is_first_rank: if self.switch is not None: adapter_indices, modified_input_ids = self.switch( @@ -342,56 +293,24 @@ def forward( adapter_token_ids=self.adapter_token_ids, ) else: - # No switch - all tokens use base model (adapter_id = 0) + # No switch — all tokens use base model (adapter_id = 0). num_tokens = input_ids.shape[0] adapter_indices = torch.zeros( - num_tokens, - dtype=torch.long, - device=input_ids.device + num_tokens, dtype=torch.long, device=input_ids.device, ) modified_input_ids = input_ids - # Step 2: Compute group-based hiding masks. - if self.token_to_group_mask is not None: - # token_group_membership: True at [i, g] if token i is a member of group g - token_group_membership = self.token_to_group_mask[input_ids] # [num_tokens, num_groups] - # query_group_suppression: True at [i, g] if token i's adapter suppresses group g - query_group_suppression = self.adapter_hiding_matrix[adapter_indices] # [num_tokens, num_groups] - else: - token_group_membership = None - query_group_suppression = None - - # Compute hidden_count for position correction (SingleSwitch). - # In token-exchange mode control tokens are real positions, so - # skip the correction entirely rather than subtract zeros. - if hidden_count is None and not self.config.use_token_exchange: - hidden_count = (adapter_indices > 0).long() - - # Position correction: adjust positions to close gaps from hidden tokens. - # Clamp to >= 0: pre-init tokens have no hidden tokens in their causal - # past, but the counting mechanism returns capacity-1 when all attention - # keys are masked, which would produce negative positions and OOB RoPE - # cache indices. - if hidden_count is not None: - positions = torch.clamp(positions - hidden_count, min=0) - - # Step 3: Prepare LoRA metadata ONCE for all decoder layers. + # Step 2: Prepare LoRA metadata ONCE for all decoder layers. # Stored on the shared LoRAContext — every SwitchedLoRALinear reads from it. if self.lora_meta is not None and self.lora_ctx is not None: # Convert to Punica convention: 0=base -> -1=base punica_indices = adapter_indices - 1 self.lora_meta.prepare_and_store(punica_indices, self.lora_ctx) - self.lora_ctx.token_group_membership = token_group_membership - self.lora_ctx.query_group_suppression = query_group_suppression - # Store metadata in intermediate_tensors for pipeline parallelism + # Store metadata in intermediate_tensors for pipeline parallelism. if intermediate_tensors is None: intermediate_tensors = IntermediateTensors({}) intermediate_tensors["adapter_indices"] = adapter_indices - if token_group_membership is not None: - intermediate_tensors["token_group_membership"] = token_group_membership - if query_group_suppression is not None: - intermediate_tensors["query_group_suppression"] = query_group_suppression else: # Subsequent ranks: recompute fixed-size LoRA metadata from # token-leading adapter_indices received through PP. @@ -400,19 +319,6 @@ def forward( if self.lora_ctx is not None: punica_indices = adapter_indices - 1 self.lora_meta.prepare_and_store(punica_indices, self.lora_ctx) - self.lora_ctx.token_group_membership = ( - _get_intermediate_tensor( - intermediate_tensors, "token_group_membership", - ) - ) - self.lora_ctx.query_group_suppression = ( - _get_intermediate_tensor( - intermediate_tensors, "query_group_suppression", - ) - ) - if not self.config.use_token_exchange: - hidden_count = (adapter_indices > 0).long() - positions = torch.clamp(positions - hidden_count, min=0) else: # Fallback: no metadata available (should not happen in normal operation) num_tokens = input_ids.shape[0] if input_ids is not None else 0 diff --git a/src/granite_switch/vllm/switch/single.py b/src/granite_switch/vllm/switch/single.py index da85801..a250270 100644 --- a/src/granite_switch/vllm/switch/single.py +++ b/src/granite_switch/vllm/switch/single.py @@ -2,7 +2,7 @@ """SingleSwitch using replicated one-hot attention for adapter selection. This switch uses the backbone's full head geometry (num_attention_heads, -num_key_value_heads, expanded_head_dim, attention_multiplier) so that all +num_key_value_heads, projection_head_dim, attention_multiplier) so that all attention layers share one FlashAttentionMetadataBuilder configuration. The same one-hot dim-0 pattern is replicated identically across every head: @@ -65,7 +65,7 @@ def __init__( self.num_kv_heads = total_kv // tp_size else: self.num_kv_heads = max(1, total_kv // tp_size) - self.head_dim = config.expanded_head_dim + self.head_dim = config.projection_head_dim self.scaling = config.attention_multiplier self.effective_gain = control_token_gain / self.scaling else: diff --git a/tests/composer/test_built_in_adapters.py b/tests/composer/test_built_in_adapters.py deleted file mode 100644 index 813dc25..0000000 --- a/tests/composer/test_built_in_adapters.py +++ /dev/null @@ -1,272 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Tests for built-in adapter support (Mode A native / Mode B mixed). - -Tests config-level behavior: -- Mode A: control_dims=0, no hiding -- Mode B: mixed built-in + external → control_dims>0, third_party = external only -- SSM rejection for mixed mode -- Model construction with control_dims=0 -""" - -import pytest -import torch - -from granite_switch.config import GraniteSwitchConfig -from granite_switch.hf import GraniteSwitchForCausalLM - - -# ── Fixtures ────────────────────────────────────────────────────────── - - -@pytest.fixture -def mode_a_config(): - """Mode A (native): built-in adapters only, control_dims=0, token-exchange.""" - return GraniteSwitchConfig( - vocab_size=300, - hidden_size=64, - intermediate_size=128, - num_hidden_layers=3, # 1 switch + 2 decoder - num_attention_heads=4, - num_key_value_heads=4, - num_adapters=2, - adapter_token_ids=[250, 251], - # Built-in adapters: substitute = BOS (arbitrary id 1 for tests). - adapter_substitute_token_ids=[1, 1], - adapter_names=["router", "planner"], - max_lora_rank=4, - adapter_ranks=[4, 4], - switch_head_dim=16, - control_dims=0, - # No hiding groups - hiding_groups=None, - hiding_policy=None, - adapter_third_party=None, - ) - - -@pytest.fixture -def mode_b_config(): - """Mode B (mixed): 1 external + 1 built-in adapter, control_dims=8.""" - return GraniteSwitchConfig( - vocab_size=300, - hidden_size=64, - intermediate_size=128, - num_hidden_layers=3, # 1 switch + 2 decoder - num_attention_heads=4, - num_key_value_heads=4, - num_adapters=2, - adapter_token_ids=[250, 251], - adapter_names=["external_rag", "router"], - hiding_groups={"all_controls": ["external_rag", "router"]}, - hiding_policy={ - "base": ["all_controls"], - "external_rag": ["all_controls"], - "router": ["all_controls"], - }, - adapter_third_party=["external_rag"], # Only external is third-party - max_lora_rank=4, - adapter_ranks=[4, 4], - switch_head_dim=16, - control_dims=8, - ) - - -# ── Mode A Config Tests ────────────────────────────────────────────── - - -class TestModeAConfig: - """Config-level checks for Mode A (native, control_dims=0).""" - - def test_control_dims_zero_allowed(self, mode_a_config): - """control_dims=0 should be accepted by config validation.""" - assert mode_a_config.control_dims == 0 - - def test_no_hiding_groups(self, mode_a_config): - """Mode A has no hiding groups.""" - assert mode_a_config.num_hiding_groups == 0 - assert mode_a_config.hiding_group_names == [] - assert mode_a_config.get_hiding_group_token_ids() == {} - - def test_no_third_party(self, mode_a_config): - """Mode A has no third-party adapters.""" - assert mode_a_config.adapter_third_party is None - mask = mode_a_config.get_third_party_adapter_mask() - assert all(v is False for v in mask) - - def test_adapters_present(self, mode_a_config): - """Mode A still has adapters with LoRA.""" - assert mode_a_config.num_adapters == 2 - assert mode_a_config.adapter_ranks == [4, 4] - - -# ── Mode A Model Tests ─────────────────────────────────────────────── - - -class TestModeAModel: - """Model construction and forward pass with control_dims=0.""" - - def test_model_creates_successfully(self, mode_a_config): - """GraniteSwitchForCausalLM should construct with control_dims=0.""" - model = GraniteSwitchForCausalLM(mode_a_config) - assert model is not None - assert model.config.control_dims == 0 - - def test_attention_no_expansion(self, mode_a_config): - """Decoder attention layers should NOT expand control dims.""" - model = GraniteSwitchForCausalLM(mode_a_config) - for layer in model.model.layers: - attn = layer.self_attn - assert not attn.expand_control_dims, ( - "expand_control_dims should be False when control_dims=0" - ) - assert attn.expanded_head_dim == attn.head_dim, ( - "expanded_head_dim should equal head_dim when control_dims=0" - ) - - def test_forward_pass(self, mode_a_config): - """Forward pass should work with control_dims=0.""" - model = GraniteSwitchForCausalLM(mode_a_config).eval() - model.model.adapter_token_ids.data = torch.tensor( - mode_a_config.adapter_token_ids, dtype=torch.long - ) - - input_ids = torch.tensor([[10, 250, 20, 30, 40]]) - with torch.no_grad(): - output = model(input_ids=input_ids) - assert output.logits.shape == (1, 5, mode_a_config.vocab_size) - - def test_no_hiding_buffers(self, mode_a_config): - """Model should have no hiding-related buffers when control_dims=0.""" - model = GraniteSwitchForCausalLM(mode_a_config) - assert model.model.token_to_group_mask is None - assert model.model.adapter_hiding_matrix is None - - def test_lora_shapes_correct(self, mode_a_config): - """LoRA weight shapes should reflect num_adapters.""" - model = GraniteSwitchForCausalLM(mode_a_config) - layer = model.model.layers[0] # First decoder layer - attn = layer.self_attn - # QKV has LoRA with 2 adapters, rank 4 - if hasattr(attn.qkv_proj, "lora_A_slices"): - for lora_a in attn.qkv_proj.lora_A_slices: - assert lora_a.shape[0] == 2, "num_adapters should be 2" - assert lora_a.shape[2] == 4, "max_lora_rank should be 4" - - def test_adapter_routing_works(self, mode_a_config): - """Adapter routing should still work with control_dims=0.""" - model = GraniteSwitchForCausalLM(mode_a_config).eval() - model.model.adapter_token_ids.data = torch.tensor( - mode_a_config.adapter_token_ids, dtype=torch.long - ) - - # Set non-zero lora_B to make adapter effect visible - with torch.no_grad(): - for layer in model.model.layers: - if hasattr(layer.self_attn.o_proj, "lora_B"): - layer.self_attn.o_proj.lora_B.data = ( - torch.randn_like(layer.self_attn.o_proj.lora_B) * 0.1 - ) - - # All base tokens - base_ids = torch.tensor([[10, 20, 30, 40, 50]]) - # With adapter control token - adapter_ids = torch.tensor([[250, 20, 30, 40, 50]]) - - with torch.no_grad(): - out_base = model(input_ids=base_ids) - out_adapter = model(input_ids=adapter_ids) - - # Logits should differ when adapter is active - # (tokens after control token see different LoRA) - diff = (out_base.logits[0, -1] - out_adapter.logits[0, -1]).abs().max() - assert diff > 1e-6, "Adapter should produce different logits than base" - - def test_control_token_logits_finite(self, mode_a_config): - """Control token logits should be finite.""" - model = GraniteSwitchForCausalLM(mode_a_config).eval() - model.model.adapter_token_ids.data = torch.tensor( - mode_a_config.adapter_token_ids, dtype=torch.long - ) - - input_ids = torch.tensor([[250, 20, 30]]) - with torch.no_grad(): - output = model(input_ids=input_ids) - - control_token_logits = output.logits[:, :, mode_a_config.adapter_token_ids] - assert torch.isfinite(control_token_logits).all(), ( - "Control token logits should be finite" - ) - - -# ── Mode B Config Tests ────────────────────────────────────────────── - - -class TestModeBConfig: - """Config-level checks for Mode B (mixed, control_dims>0).""" - - def test_control_dims_positive(self, mode_b_config): - """Mode B should have control_dims > 0.""" - assert mode_b_config.control_dims == 8 - - def test_only_external_is_third_party(self, mode_b_config): - """Only external adapter should be third-party.""" - assert mode_b_config.adapter_third_party == ["external_rag"] - mask = mode_b_config.get_third_party_adapter_mask() - # [base=False, external_rag=True, router=False] - assert mask == [False, True, False] - - def test_hiding_groups_present(self, mode_b_config): - """Mode B should have hiding groups.""" - assert mode_b_config.num_hiding_groups == 1 - - def test_third_party_mask(self, mode_b_config): - """Third-party mask marks only external adapter.""" - mask = mode_b_config.get_third_party_adapter_mask() - assert mask == [False, True, False] - - -# ── Negative Tests ──────────────────────────────────────────────────── - - -class TestNegative: - """Validation errors that should be raised.""" - - def test_control_dims_negative_rejected(self): - """control_dims < 0 should still be rejected.""" - with pytest.raises(ValueError, match="control_dims must be >= 0"): - GraniteSwitchConfig( - vocab_size=256, - hidden_size=64, - intermediate_size=128, - num_hidden_layers=2, - num_attention_heads=4, - num_key_value_heads=4, - num_adapters=0, - control_dims=-1, - ) - - def test_hiding_groups_require_control_dims(self): - """Hiding groups require control_dims >= num_hiding_groups. - - A build with 1 hiding group and control_dims=1 works; with 2 groups and - control_dims=1 it must fail. Substitute ids are supplied only to get - past the newer "no-hiding-and-no-exchange" validator; the assertion is - specifically about the hiding-vs-control_dims arithmetic. - """ - with pytest.raises(ValueError, match="control_dims.*must be >= number of hiding groups"): - GraniteSwitchConfig( - vocab_size=300, - hidden_size=64, - intermediate_size=128, - num_hidden_layers=3, - num_attention_heads=4, - num_key_value_heads=4, - num_adapters=2, - adapter_token_ids=[250, 251], - adapter_substitute_token_ids=[1, 2], - adapter_names=["a", "b"], - hiding_groups={"g1": ["a"], "g2": ["b"]}, # 2 groups > control_dims=1 - max_lora_rank=4, - adapter_ranks=[4, 4], - control_dims=1, - ) diff --git a/tests/conftest.py b/tests/conftest.py index 7fca1da..e261922 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -54,18 +54,11 @@ def tiny_config(): num_key_value_heads=4, num_adapters=2, adapter_token_ids=[250, 251], + adapter_substitute_token_ids=[1, 1], adapter_names=["adapter_a", "adapter_b"], - hiding_groups={"all_controls": ["adapter_a", "adapter_b"]}, - hiding_policy={ - "base": ["all_controls"], - "adapter_a": ["all_controls"], - "adapter_b": ["all_controls"], - }, - adapter_third_party=["adapter_a", "adapter_b"], max_lora_rank=4, adapter_ranks=[4, 4], switch_head_dim=16, - control_dims=8, ) diff --git a/tests/hf/test_kv_hiding_gap_equivalence.py b/tests/hf/test_kv_hiding_gap_equivalence.py deleted file mode 100644 index 7fc857a..0000000 --- a/tests/hf/test_kv_hiding_gap_equivalence.py +++ /dev/null @@ -1,215 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Verify that a hidden control token creates a transparent gap in attention. - -The upstream model processes a contiguous N-token sequence. The switch model -processes the same N content tokens with a hidden control token inserted, -giving N+1 total tokens. With zero LoRA weights and SingleSwitch's hidden_count -closing the RoPE gap, the logits at corresponding visible positions should -match within FP tolerance. - -The hiding mechanism itself is exact on CPU: -- exp(finfo.min) = 0.0 exactly → hidden token gets zero softmax weight -- Control dims: Q_ctrl * K_ctrl = 1.0 * 0.0 = 0.0 → no score change -- V_control = 0.0 → zero contribution to attention output - -The ~1e-7 tolerance comes from different softmax window sizes at -corresponding positions. Switch position k+1 computes softmax over k+2 -entries (including the ~0 hidden token), while upstream position k computes -softmax over k+1 entries. Although the hidden entry contributes exactly 0.0 -to the denominator, SDPA's fused softmax kernel processes different-length -reductions with different FP accumulation order. Positions before the -control token are bit-exact (same causal window in both models). - -Attention-only models only — Mamba layers do not support KV hiding (the hidden -control token would flow through conv1d and SSM state, corrupting subsequent -positions). Only dense (attention-only) configs from GRANITE4_MINI are tested. - -SingleSwitch: hidden_count = (adapter_indices > 0).long() — fires once, -so 0 before control token and 1 at/after (see issue #16). -""" - -import pytest -import torch -from transformers.models.granitemoehybrid.configuration_granitemoehybrid import ( - GraniteMoeHybridConfig, -) -from transformers.models.granitemoehybrid.modeling_granitemoehybrid import ( - GraniteMoeHybridForCausalLM, -) - -from granite_switch.config import GraniteSwitchConfig -from granite_switch.hf import GraniteSwitchForCausalLM - -from tests.shared.granite4_equivalence import ( - augment_cfg_with_adapters, - transfer_weights, - zero_lora_weights, - GRANITE4_MINI, -) -from tests.shared.gap_equivalence import ( - ATTN_ONLY_NAMES, - make_gapped_inputs, - extract_visible_batched, -) - -# Softmax window-size tolerance (see module docstring). -# Observed max: ~1e-7 across all configs/positions/seeds. -# Use 5e-7 (≈5x margin) to accommodate variation. -_ATOL = 5e-7 - - -# ── Helpers ──────────────────────────────────────────────────────── - - -def _make_gap_pair(cfg_dict): - """Create upstream + 1-adapter switch model pair with zero LoRA weights. - - SingleSwitch: adapter_token_ids=[101], 101 is adapter_0 (KV-hidden). - """ - torch.manual_seed(0) - upstream = GraniteMoeHybridForCausalLM( - GraniteMoeHybridConfig(**cfg_dict) - ).eval() - - switch_cfg_dict = augment_cfg_with_adapters(cfg_dict, num_adapters=1) - switch = GraniteSwitchForCausalLM( - GraniteSwitchConfig(**switch_cfg_dict) - ).eval() - - # Transfer base weights (non-strict: LoRA/switch params left unloaded) - unloaded = transfer_weights(upstream.state_dict(), switch.state_dict()) - - # Verify unloaded params are only LoRA and switch related - for name in unloaded: - assert any(k in name for k in ( - "lora_A", "lora_B", "switch", "adapter_token_ids", - "token_to_group_mask", "adapter_hiding_matrix", - )), f"Unexpected unloaded parameter: {name}" - - # Zero all LoRA weights defensively - zero_lora_weights(switch) - - return upstream, switch - - -def _assert_gap_equivalence(name, upstream, switch, seq_len, ctrl_pos, seed=42): - """Run forward pass and assert visible logits match within tolerance.""" - upstream_ids, switch_ids = make_gapped_inputs(seq_len, ctrl_pos, seed) - - with torch.no_grad(): - upstream_out = upstream(input_ids=upstream_ids, use_cache=False) - switch_out = switch(input_ids=switch_ids, use_cache=False) - - visible = extract_visible_batched(switch_out.logits, ctrl_pos) - - torch.testing.assert_close( - visible, upstream_out.logits, - atol=_ATOL, rtol=0.0, - msg=f"{name}: visible logits diverge (seq={seq_len}, ctrl={ctrl_pos})", - ) - - -# ── Test class: KV Hiding Gap Equivalence ───────────────────────── - - -class TestKVHidingGapEquivalence: - """Verify hidden control token creates a transparent gap. - - The upstream model processes N contiguous tokens. The switch model - processes the same N tokens with a hidden control token inserted (N+1 - total). Visible-position logits match within BLAS gemm tolerance. - """ - - @pytest.fixture(params=ATTN_ONLY_NAMES) - def model_pair(self, request): - model_name = request.param - upstream, switch = _make_gap_pair(GRANITE4_MINI[model_name]) - return model_name, upstream, switch - - def test_gap_short(self, model_pair): - """Short sequence (16 tokens), control token at position 2.""" - name, upstream, switch = model_pair - _assert_gap_equivalence(name, upstream, switch, seq_len=16, ctrl_pos=2) - - def test_gap_long(self, model_pair): - """Longer sequence (64 tokens), control token at position 8.""" - name, upstream, switch = model_pair - _assert_gap_equivalence(name, upstream, switch, seq_len=64, ctrl_pos=8) - - def test_ctrl_at_position_1(self, model_pair): - """Control token at position 1. - - With SingleSwitch, position 0 has no special role (no counting - anchor needed). ctrl_pos=0 is tested separately in - test_multiple_ctrl_positions as a NaN regression guard. - """ - name, upstream, switch = model_pair - _assert_gap_equivalence(name, upstream, switch, seq_len=16, ctrl_pos=1) - - def test_ctrl_near_end(self, model_pair): - """Control token near the end of the sequence (pos=seq_len-2).""" - name, upstream, switch = model_pair - _assert_gap_equivalence(name, upstream, switch, seq_len=16, ctrl_pos=14) - - @pytest.mark.parametrize("ctrl_pos", [0, 1, 2, 4, 8, 14]) - def test_multiple_ctrl_positions(self, model_pair, ctrl_pos): - """Sweep control token across multiple positions. - - ctrl_pos=0 is a regression guard for the NaN bug fixed in PR #87: - when the control token sits at position 0 with no other causal key, - softmax([-inf]) = NaN unless q_control is zeroed for group members. - """ - name, upstream, switch = model_pair - _assert_gap_equivalence(name, upstream, switch, seq_len=16, ctrl_pos=ctrl_pos) - - -# ── Test class: Adapter Indices Sanity ──────────────────────────── - - -class TestAdapterIndicesSanity: - """Verify adapter_indices correctness with a single hidden control token. - - Uses a single config (4.0-350m) to check that: - - Positions before the control token have adapter_indices=0 (base) - - Positions at and after the control token have adapter_indices=1 - """ - - @pytest.fixture - def model(self): - cfg_dict = GRANITE4_MINI["4.0-350m"] - _, switch = _make_gap_pair(cfg_dict) - return switch - - def _run(self, model, ctrl_pos, seed=42): - """Run forward pass and return adapter_indices.""" - _, switch_ids = make_gapped_inputs(seq_len=16, ctrl_pos=ctrl_pos, seed=seed) - with torch.no_grad(): - model(input_ids=switch_ids, use_cache=False) - return model.model._last_adapter_indices - - def test_adapter_indices_before_ctrl(self, model): - """Positions before control token should be base (0).""" - ctrl_pos = 4 - ai = self._run(model, ctrl_pos) - assert (ai[:, :ctrl_pos] == 0).all(), ( - f"Pre-control positions should be base, got {ai[:, :ctrl_pos]}" - ) - - def test_adapter_indices_at_and_after_ctrl(self, model): - """Positions at and after control token should be adapter_0 (1).""" - ctrl_pos = 4 - ai = self._run(model, ctrl_pos) - assert (ai[:, ctrl_pos:] == 1).all(), ( - f"Post-control positions should be adapter_0 (1), got {ai[:, ctrl_pos:]}" - ) - - def test_adapter_indices_sweep(self, model): - """Sweep ctrl_pos and verify adapter_indices boundary.""" - for ctrl_pos in [1, 2, 4, 8, 14]: - ai = self._run(model, ctrl_pos, seed=ctrl_pos) - assert (ai[:, :ctrl_pos] == 0).all(), ( - f"ctrl_pos={ctrl_pos}: pre-ctrl should be 0, got {ai[:, :ctrl_pos]}" - ) - assert (ai[:, ctrl_pos:] == 1).all(), ( - f"ctrl_pos={ctrl_pos}: post-ctrl should be 1, got {ai[:, ctrl_pos:]}" - ) diff --git a/tests/hf/test_model_forward.py b/tests/hf/test_model_forward.py index 489264f..ce065ee 100644 --- a/tests/hf/test_model_forward.py +++ b/tests/hf/test_model_forward.py @@ -47,7 +47,7 @@ def _set_nonzero_lora_B(model, scale=0.1): @pytest.fixture def tiny_single_config(): - """Minimal SingleSwitch config for CPU tests.""" + """Minimal SingleSwitch config for CPU tests (token exchange).""" return GraniteSwitchConfig( vocab_size=300, hidden_size=64, @@ -57,37 +57,11 @@ def tiny_single_config(): num_key_value_heads=4, num_adapters=2, adapter_token_ids=[250, 251], + adapter_substitute_token_ids=[1, 1], adapter_names=["adapter_1", "adapter_2"], - hiding_groups={"all_controls": ["adapter_1", "adapter_2"]}, - hiding_policy={"base": ["all_controls"], "adapter_1": ["all_controls"], "adapter_2": ["all_controls"]}, - adapter_third_party=["adapter_1", "adapter_2"], - max_lora_rank=4, - adapter_ranks=[4, 4], - switch_head_dim=16, - control_dims=8, - ) - - -@pytest.fixture -def tiny_basic_mixed_tp_config(): - """SingleSwitch config where only adapter_1 is third-party (adapter_2 is not).""" - return GraniteSwitchConfig( - vocab_size=300, - hidden_size=64, - intermediate_size=128, - num_hidden_layers=3, - num_attention_heads=4, - num_key_value_heads=4, - num_adapters=2, - adapter_token_ids=[250, 251], - adapter_names=["adapter_1", "adapter_2"], - hiding_groups={"all_controls": ["adapter_1", "adapter_2"]}, - hiding_policy={"base": ["all_controls"], "adapter_1": ["all_controls"], "adapter_2": ["all_controls"]}, - adapter_third_party=["adapter_1"], # only adapter_1 is third-party max_lora_rank=4, adapter_ranks=[4, 4], switch_head_dim=16, - control_dims=8, ) @@ -238,101 +212,7 @@ def test_different_adapters_produce_different_post_control_logits(self, tiny_con # ════════════════════════════════════════════════════════════════════ -# 6. Control token KV invisibility -# ════════════════════════════════════════════════════════════════════ - -class TestControlTokenKVInvisibility: - """Verify control_dims makes control tokens invisible in KV cache.""" - - def test_control_token_kv_invisible_to_future_positions(self, tiny_config): - """Perturbing a control token's embedding doesn't affect future positions.""" - torch.manual_seed(42) - model = GraniteSwitchForCausalLM(tiny_config).eval() - _set_adapter_token_ids(model, tiny_config.adapter_token_ids) - - # Control token 250 at position 2 - input_ids = torch.tensor([[10, 20, 250, 30, 40, 50, 60, 70]]) - - # Pass A: original embeddings - with torch.no_grad(): - out_a = model(input_ids=input_ids, output_hidden_states=True) - hidden_a = out_a.hidden_states # tuple of [1, 8, hidden_size] - - # Perturb the control token's embedding - with torch.no_grad(): - perturbation = torch.randn(tiny_config.hidden_size) * 10.0 - model.model.embed_tokens.weight.data[250] += perturbation - - # Pass B: perturbed embedding - with torch.no_grad(): - out_b = model(input_ids=input_ids, output_hidden_states=True) - hidden_b = out_b.hidden_states - - # Check each layer's hidden states - for layer_idx in range(len(hidden_a)): - ha = hidden_a[layer_idx][0] # [8, hidden_size] - hb = hidden_b[layer_idx][0] - - # Pre-control (positions 0, 1): identical (causal, can't see pos 2) - torch.testing.assert_close( - ha[:2], hb[:2], - msg=f"Layer {layer_idx}: pre-control hidden states should be identical" - ) - - # At control position (2): must differ (embedding changed) - assert not torch.allclose(ha[2], hb[2]), \ - f"Layer {layer_idx}: control token hidden state should differ after perturbation" - - # Post-control (positions 3+): identical (control token KV is invisible) - torch.testing.assert_close( - ha[3:], hb[3:], - msg=f"Layer {layer_idx}: post-control hidden states should be identical " - f"(control token KV masked by control_dims)" - ) - - -# ════════════════════════════════════════════════════════════════════ -# 7. Control token KV visibility -# ════════════════════════════════════════════════════════════════════ - -class TestControlTokenKVVisibility: - """Verify control tokens are KV-invisible (hidden from attention via control dimensions).""" - - def _make_model(self, config): - torch.manual_seed(42) - model = GraniteSwitchForCausalLM(config).eval() - _set_adapter_token_ids(model, config.adapter_token_ids) - return model - - def test_adapter_token_kv_invisible(self, tiny_single_config): - """Adapter token (250) is KV-invisible: perturbing doesn't affect future.""" - config = tiny_single_config - model = self._make_model(config) - - input_ids = torch.tensor([[10, 20, 250, 30, 40, 50, 60, 70]]) - - with torch.no_grad(): - out_a = model(input_ids=input_ids, output_hidden_states=True) - hidden_a = out_a.hidden_states - - with torch.no_grad(): - perturbation = torch.randn(config.hidden_size) * 10.0 - model.model.embed_tokens.weight.data[250] += perturbation - - with torch.no_grad(): - out_b = model(input_ids=input_ids, output_hidden_states=True) - hidden_b = out_b.hidden_states - - for layer_idx in range(len(hidden_a)): - ha = hidden_a[layer_idx][0] - hb = hidden_b[layer_idx][0] - torch.testing.assert_close( - ha[3:], hb[3:], - msg=f"Layer {layer_idx}: post-adapter-token hidden states should be identical" - ) - -# ════════════════════════════════════════════════════════════════════ -# 8. Activating tokens: switch behavior (explicit adapter_indices) +# 6. Activating tokens: switch behavior (explicit adapter_indices) # ════════════════════════════════════════════════════════════════════ class TestActivatingTokenSwitch: @@ -355,13 +235,13 @@ def test_activating_adapter_indices_nonzero(self, tiny_single_config): # ════════════════════════════════════════════════════════════════════ -# 9. Native mode: control_dims=0 (no KV hiding) +# 7. Token-exchange forward pass tests # ════════════════════════════════════════════════════════════════════ @pytest.fixture def tiny_native_config(): - """Minimal config for native mode (control_dims=0, token-exchange).""" + """Minimal config for token-exchange mode.""" return GraniteSwitchConfig( vocab_size=300, hidden_size=64, @@ -376,16 +256,11 @@ def tiny_native_config(): max_lora_rank=4, adapter_ranks=[4, 4], switch_head_dim=16, - control_dims=0, - # No hiding - hiding_groups=None, - hiding_policy=None, - adapter_third_party=None, ) class TestNativeModeForward: - """Forward pass tests with control_dims=0 (native mode).""" + """Forward pass tests with token-exchange enabled.""" def test_forward_produces_logits(self, tiny_native_config): """Basic forward pass succeeds and produces correct-shaped logits.""" @@ -400,24 +275,6 @@ def test_forward_produces_logits(self, tiny_native_config): assert output.logits.shape == (1, 5, config.vocab_size) assert torch.isfinite(output.logits).all() - def test_no_expansion_in_attention(self, tiny_native_config): - """Attention layers should not expand control dimensions.""" - config = tiny_native_config - model = GraniteSwitchForCausalLM(config) - - for layer in model.model.layers: - attn = layer.self_attn - assert not attn.expand_control_dims - assert attn.expanded_head_dim == attn.head_dim - - def test_no_hiding_buffers(self, tiny_native_config): - """Model should have no hiding group buffers.""" - config = tiny_native_config - model = GraniteSwitchForCausalLM(config) - - assert model.model.token_to_group_mask is None - assert model.model.adapter_hiding_matrix is None - def test_control_token_logits_finite(self, tiny_native_config): """Control token logits should be finite.""" config = tiny_native_config @@ -431,7 +288,7 @@ def test_control_token_logits_finite(self, tiny_native_config): # All control token logits should be finite for tid in config.adapter_token_ids: assert torch.isfinite(output.logits[:, :, tid]).all(), ( - f"Token {tid} logits should be finite in native mode" + f"Token {tid} logits should be finite" ) def test_adapter_effect_visible(self, tiny_native_config): @@ -452,7 +309,7 @@ def test_adapter_effect_visible(self, tiny_native_config): assert diff > 1e-6, "Adapter should produce different logits" def test_batch_forward(self, tiny_native_config): - """Batched forward pass works with control_dims=0.""" + """Batched forward pass works.""" config = tiny_native_config model = GraniteSwitchForCausalLM(config).eval() _set_adapter_token_ids(model, config.adapter_token_ids) diff --git a/tests/hf/test_position_zero_nan.py b/tests/hf/test_position_zero_nan.py deleted file mode 100644 index 5751d9a..0000000 --- a/tests/hf/test_position_zero_nan.py +++ /dev/null @@ -1,188 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""NaN regression tests — control token at sequence position 0 (HF backend). - -HF-specific unit tests for GraniteLoRAEmbeddedAttention._expand_with_control_dimensions -(batch/seq tensor layout: [batch, seq, heads, head_dim]) plus shared SDPANaNCases. - -Note: model-level finiteness tests are not included here — the NaN bug only manifests -in vLLM's FlashAttention path, not in HF's stable softmax. See tests/vllm/ for those. -""" - -import types - -import torch - -from granite_switch.hf.core.lora import GraniteLoRAEmbeddedAttention - -from tests.shared.position_zero_nan_cases import SDPANaNCases - - -# ── Helpers ──────────────────────────────────────────────────────── - - -def _stub(num_heads=4, num_kv_heads=1, control_dims=1): - """Minimal namespace satisfying _expand_with_control_dimensions's self usage.""" - return types.SimpleNamespace( - num_heads=num_heads, - num_key_value_heads=num_kv_heads, - control_dims=control_dims, - ) - - -def _expand(stub, q, k, v, membership, suppression): - return GraniteLoRAEmbeddedAttention._expand_with_control_dimensions( - stub, q, k, v, membership, suppression - ) - - - -# ════════════════════════════════════════════════════════════════════ -# 1. HF-specific unit tests: _expand_with_control_dimensions -# Tensor layout: [batch, seq_len, num_heads, head_dim] -# ════════════════════════════════════════════════════════════════════ - - -class TestExpandControlDimensions: - """Direct tests of _expand_with_control_dimensions (HF tensor layout). - - token_group_membership=True marks the control token itself. - query_group_suppression=True marks adapter-generated tokens that suppress - the group — these are NOT group members and must keep q_control=1. - """ - - _HEAD_DIM = 32 - - def _qkv(self, stub, seq_len): - q = torch.randn(1, seq_len, stub.num_heads, self._HEAD_DIM) - k = torch.randn(1, seq_len, stub.num_key_value_heads, self._HEAD_DIM) - v = torch.randn(1, seq_len, stub.num_key_value_heads, self._HEAD_DIM) - return q, k, v - - # ── fix: control token must have q_control = 0 ────────────────── - - def test_control_token_q_hide_zero_at_position_zero(self): - """Core fix: control token at pos 0 must not activate Q-side hiding. - - Before the fix q_control was 1.0 unconditionally, so - softmax([−∞]) = NaN when it had no other causal keys. - """ - stub = _stub() - membership = torch.ones(1, 1, 1, dtype=torch.bool) - suppression = torch.ones(1, 1, 1, dtype=torch.bool) - q, k, v = self._qkv(stub, seq_len=1) - - q_exp, _, _ = _expand(stub, q, k, v, membership, suppression) - - q_ctrl = q_exp[0, 0, :, self._HEAD_DIM:] - assert q_ctrl.eq(0).all(), f"Control token at pos 0: q_control must be 0, got {q_ctrl}" - - def test_control_token_q_hide_zero_at_later_position(self): - """Control token q_control is 0 regardless of its sequence position.""" - stub = _stub() - membership = torch.zeros(1, 5, 1, dtype=torch.bool) - membership[0, 3, 0] = True - suppression = torch.ones(1, 5, 1, dtype=torch.bool) - q, k, v = self._qkv(stub, seq_len=5) - - q_exp, _, _ = _expand(stub, q, k, v, membership, suppression) - - assert q_exp[0, 3, :, self._HEAD_DIM:].eq(0).all(), "Control token at pos 3: q_control must be 0" - - # ── adapter-generated tokens must still suppress the control token ── - - def test_adapter_generated_tokens_q_hide_one(self): - """Adapter-generated tokens (non-members) keep q_control=1 to hide the control token.""" - stub = _stub() - membership = torch.zeros(1, 5, 1, dtype=torch.bool) - membership[0, 0, 0] = True # control token at pos 0 - suppression = torch.ones(1, 5, 1, dtype=torch.bool) - q, k, v = self._qkv(stub, seq_len=5) - - q_exp, _, _ = _expand(stub, q, k, v, membership, suppression) - - assert q_exp[0, 0, :, self._HEAD_DIM:].eq(0).all(), "Control token: q_control must be 0" - for pos in range(1, 5): - assert q_exp[0, pos, :, self._HEAD_DIM:].eq(1).all(), ( - f"Adapter-generated token at pos {pos}: q_control must be 1" - ) - - # ── k-side unchanged by fix ────────────────────────────────────── - - def test_k_side_finfo_min_for_control_token(self): - """K-side branding is unaffected by the fix — control token gets finfo.min.""" - stub = _stub() - membership = torch.ones(1, 1, 1, dtype=torch.bool) - q, k, v = self._qkv(stub, seq_len=1) - - _, k_exp, _ = _expand(stub, q, k, v, membership, None) - - expected_min = torch.finfo(k.dtype).min - k_ctrl = k_exp[0, 0, :, self._HEAD_DIM:] - torch.testing.assert_close(k_ctrl, torch.full_like(k_ctrl, expected_min)) - - def test_k_side_zero_for_adapter_generated_tokens(self): - """Adapter-generated tokens have k_control=0.""" - stub = _stub() - q, k, v = self._qkv(stub, seq_len=3) - - _, k_exp, _ = _expand(stub, q, k, v, torch.zeros(1, 3, 1, dtype=torch.bool), None) - - assert k_exp[:, :, :, self._HEAD_DIM:].eq(0).all() - - # ── v-side and no-mask baseline ────────────────────────────────── - - def test_v_control_always_zero(self): - """V control dimensions are always zero.""" - stub = _stub() - q, k, v = self._qkv(stub, seq_len=3) - _, _, v_exp = _expand( - stub, q, k, v, - torch.ones(1, 3, 1, dtype=torch.bool), - torch.ones(1, 3, 1, dtype=torch.bool), - ) - assert v_exp[:, :, :, self._HEAD_DIM:].eq(0).all() - - def test_both_none_leaves_all_control_dims_zero(self): - """With both tensors None, all control dims remain zero.""" - stub = _stub(control_dims=2) - q, k, v = self._qkv(stub, seq_len=4) - q_exp, k_exp, v_exp = _expand(stub, q, k, v, None, None) - assert q_exp[..., self._HEAD_DIM:].eq(0).all() - assert k_exp[..., self._HEAD_DIM:].eq(0).all() - assert v_exp[..., self._HEAD_DIM:].eq(0).all() - - # ── multiple groups ────────────────────────────────────────────── - - def test_multiple_groups_independent(self): - """Control token of group 0 only zeroes q_control for group 0.""" - stub = _stub(control_dims=2) - membership = torch.zeros(1, 1, 2, dtype=torch.bool) - membership[0, 0, 0] = True - suppression = torch.ones(1, 1, 2, dtype=torch.bool) - q, k, v = self._qkv(stub, seq_len=1) - - q_exp, _, _ = _expand(stub, q, k, v, membership, suppression) - q_ctrl = q_exp[0, 0, :, self._HEAD_DIM:] - - assert q_ctrl[:, 0].eq(0).all(), "Group 0 dim must be 0 (control token is a member)" - assert q_ctrl[:, 1].eq(1).all(), "Group 1 dim must be 1 (control token is not a member)" - - def test_original_qkv_dimensions_preserved(self): - """Original Q/K/V dimensions are unchanged; only control dims appended.""" - stub = _stub(control_dims=3) - q, k, v = self._qkv(stub, seq_len=5) - q_exp, k_exp, v_exp = _expand(stub, q, k, v, None, None) - torch.testing.assert_close(q_exp[..., : self._HEAD_DIM], q) - torch.testing.assert_close(k_exp[..., : self._HEAD_DIM], k) - torch.testing.assert_close(v_exp[..., : self._HEAD_DIM], v) - - -# ════════════════════════════════════════════════════════════════════ -# 2. Shared SDPA cases -# ════════════════════════════════════════════════════════════════════ - - -class TestSDPANaN(SDPANaNCases): - pass - - diff --git a/tests/hf/test_qk_norm.py b/tests/hf/test_qk_norm.py index a603919..557d599 100644 --- a/tests/hf/test_qk_norm.py +++ b/tests/hf/test_qk_norm.py @@ -33,7 +33,6 @@ def _make_config(qk_norm: bool, num_adapters: int = 0) -> GraniteSwitchConfig: adapter_token_ids=[], adapter_names=[], adapter_ranks=[], - control_dims=0, qk_norm=qk_norm, ) config._attn_implementation = "sdpa" @@ -110,12 +109,10 @@ def test_output_differs_with_qk_norm(self): with torch.no_grad(): out_off, _, _ = attn_off( hidden, adapter_indices, - token_group_membership=None, query_group_suppression=None, position_embeddings=pos_emb, ) out_on, _, _ = attn_on( hidden, adapter_indices, - token_group_membership=None, query_group_suppression=None, position_embeddings=pos_emb, ) @@ -137,7 +134,6 @@ def test_output_shape_preserved(self): with torch.no_grad(): out, _, _ = attn( hidden, adapter_indices, - token_group_membership=None, query_group_suppression=None, position_embeddings=pos_emb, ) diff --git a/tests/hf/test_token_exchange.py b/tests/hf/test_token_exchange.py index 3f9e306..11818e1 100644 --- a/tests/hf/test_token_exchange.py +++ b/tests/hf/test_token_exchange.py @@ -1,16 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 """HF backend tests for token-exchange mode. -Three properties under test: +Two properties under test: 1. The embedding at each control-token position equals the embedding of its substitute token (scaled by embedding_multiplier), not the original control-token embedding. -2. The KV cache head_dim is the native projection_head_dim (not expanded) — - this is the load-bearing correctness property that proves control_dims=0 - actually eliminates KV-cache overhead. -3. Legacy hiding mode is untouched: same inputs, but the embedding at the - control position matches the raw control-token embedding and the KV cache - head_dim is projection_head_dim + control_dims. +2. The KV cache head_dim is the native projection_head_dim — token-exchange + does not expand the KV cache. """ import pytest @@ -20,8 +16,8 @@ from granite_switch.hf import GraniteSwitchForCausalLM -def _build(num_adapters=2, control_dims=0, substitute_ids=(1, 7)): - kwargs = dict( +def _build(num_adapters=2, substitute_ids=(1, 7)): + return GraniteSwitchConfig( vocab_size=200, hidden_size=32, num_attention_heads=4, @@ -37,12 +33,9 @@ def _build(num_adapters=2, control_dims=0, substitute_ids=(1, 7)): max_lora_rank=4, adapter_token_ids=[100, 101][:num_adapters], adapter_names=["a", "b"][:num_adapters], - control_dims=control_dims, + adapter_substitute_token_ids=list(substitute_ids[:num_adapters]), torch_dtype=torch.float32, ) - if control_dims == 0: - kwargs["adapter_substitute_token_ids"] = list(substitute_ids[:num_adapters]) - return GraniteSwitchConfig(**kwargs) @torch.no_grad() @@ -55,7 +48,7 @@ class TestTokenExchangeEmbeddingSwap: """The control position's residual-stream input is the substitute embedding.""" def test_swap_picks_substitute_embedding(self): - config = _build(control_dims=0, substitute_ids=(5, 7)) + config = _build(substitute_ids=(5, 7)) model, _ = _forward( config, torch.tensor([[10, 20, 100, 40]], dtype=torch.long), # adapter 0 control at pos 2 @@ -70,7 +63,7 @@ def test_swap_picks_substitute_embedding(self): assert lut[40].item() == -1 def test_swap_is_not_applied_on_non_control_positions(self): - config = _build(control_dims=0, substitute_ids=(5, 7)) + config = _build(substitute_ids=(5, 7)) model = GraniteSwitchForCausalLM(config).eval() # Run once through the model with a control token and once without; # verify the non-control embedding rows are identical. @@ -83,10 +76,11 @@ def test_swap_is_not_applied_on_non_control_positions(self): class TestKVCacheHeadDim: - """The load-bearing correctness property: control_dims=0 collapses KV head_dim.""" + """The load-bearing correctness property: KV cache head_dim equals + the native projection_head_dim — no expansion.""" def test_token_exchange_native_head_dim(self): - config = _build(control_dims=0, substitute_ids=(5, 7)) + config = _build(substitute_ids=(5, 7)) _, out = _forward( config, torch.tensor([[10, 20, 100, 40]], dtype=torch.long), @@ -94,24 +88,13 @@ def test_token_exchange_native_head_dim(self): # layers[0] is the switch; layers[1] is the first decoder layer. decoder_key = out.past_key_values.layers[1].keys assert decoder_key.shape[-1] == config.projection_head_dim - assert config.use_token_exchange is True - - def test_legacy_hiding_expanded_head_dim(self): - config = _build(control_dims=32) - _, out = _forward( - config, - torch.tensor([[10, 20, 100, 40]], dtype=torch.long), - ) - decoder_key = out.past_key_values.layers[1].keys - assert decoder_key.shape[-1] == config.projection_head_dim + 32 - assert config.use_token_exchange is False class TestSwitchStillDetectsAdapter: """Swap must happen AFTER the switch reads input_ids, so detection is unaffected.""" def test_adapter_indices_still_activate(self): - config = _build(control_dims=0, substitute_ids=(5, 7)) + config = _build(substitute_ids=(5, 7)) model, _ = _forward( config, torch.tensor([[10, 20, 100, 40, 50]], dtype=torch.long), @@ -131,7 +114,7 @@ class TestPositionCorrectionSkipped: def test_no_position_shift_in_te_mode(self): """RoPE positions should equal the input positions (no hidden_count subtraction).""" - config = _build(control_dims=0, substitute_ids=(5, 7)) + config = _build(substitute_ids=(5, 7)) model = GraniteSwitchForCausalLM(config).eval() input_ids = torch.tensor([[10, 100, 20, 30]], dtype=torch.long) # Forward runs without error; the guarded branch would otherwise fire diff --git a/tests/integration/test_token_exchange_parity.py b/tests/integration/test_token_exchange_parity.py deleted file mode 100644 index e9d5ab1..0000000 --- a/tests/integration/test_token_exchange_parity.py +++ /dev/null @@ -1,602 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Parity eval: legacy KV hiding vs. token exchange. - -Measures four metrics per token position, teacher-forced, across a list of -prompts: - -1. KL(p_old || p_new) — full-distribution divergence -2. Top-1 agreement — headline sanity metric -3. Nucleus (top-p=0.9) Jaccard — do sampling sets agree? -4. Mass under old nucleus — does new model put probability on tokens old - model considered plausible? - -Two modes: - -**Synthetic mode (default, CPU-friendly):** builds two HF models with -identical base weights, one in legacy KV-hiding mode and one in token- -exchange mode. Measures only the effect of control-token handling on -logits — *not* trained-adapter behavior. Useful as a plumbing sanity -check and a regression guard. - -**Real-model mode (GPU, opt-in):** set -``GRANITE_SWITCH_PARITY_MODELS='{"old":"/path","new":"/path"}'`` (JSON with -two paths) and pytest will load actual composed checkpoints. This is the -pre-merge gate described in docs/KV_CACHE_OVERHEAD_REMOVAL.md §4. - -Run directly:: - - python -m tests.integration.test_token_exchange_parity - -Run as test:: - - pytest tests/integration/test_token_exchange_parity.py -v -s --tb=short -""" - -from __future__ import annotations - -import argparse -import json -import os -import statistics -from dataclasses import dataclass, field -from typing import Dict, List, Optional, Tuple - -import pytest -import torch -import torch.nn.functional as F - -from granite_switch.config import GraniteSwitchConfig -from granite_switch.hf import GraniteSwitchForCausalLM - - -# ──────────────────────────────────────────────────────────────────── -# Metric primitives -# ──────────────────────────────────────────────────────────────────── - - -def _kl_from_logits(logits_p: torch.Tensor, logits_q: torch.Tensor) -> float: - """KL(p || q) in nats, computed from logits to avoid softmax underflow. - - Equivalent to ``sum_i p_i * (log p_i - log q_i)`` but evaluated via - log_softmax so that very small tail probabilities don't underflow to - zero before the log. logits_{p,q}: 1-D [vocab]. - """ - log_p = F.log_softmax(logits_p, dim=-1) - log_q = F.log_softmax(logits_q, dim=-1) - p = log_p.exp() - return float((p * (log_p - log_q)).sum()) - - -def _nucleus_indices(p: torch.Tensor, top_p: float) -> torch.Tensor: - """Smallest descending-sorted prefix whose cumulative sum >= top_p. - - The nucleus always contains at least one token (the argmax). k is the - index (1-based count) of the first element where cumsum >= top_p. - """ - sorted_p, sorted_idx = torch.sort(p, descending=True) - cumsum = torch.cumsum(sorted_p, dim=0) - # Smallest index where cumsum >= top_p. If cumsum never reaches top_p - # (floating-point edge), keep everything. - ge = cumsum >= top_p - if ge.any(): - k = int(torch.argmax(ge.int()).item()) + 1 # +1: argmax is 0-indexed, we want count - else: - k = sorted_idx.numel() - k = max(1, min(k, sorted_idx.numel())) - return sorted_idx[:k] - - -def _jaccard(a: torch.Tensor, b: torch.Tensor) -> float: - a_set = set(a.tolist()) - b_set = set(b.tolist()) - if not a_set and not b_set: - return 1.0 - return len(a_set & b_set) / len(a_set | b_set) - - -# ──────────────────────────────────────────────────────────────────── -# Core eval -# ──────────────────────────────────────────────────────────────────── - - -@dataclass -class PositionResult: - kl: float - top1_agree: bool - nucleus_jaccard: float - mass_under_old_nucleus: float - old_nucleus_size: int - new_nucleus_size: int - # Partition flag: True if this position is at or after a control token - # in the causal past (adapter-activated); False if it's in the base path. - adapter_active: bool = False - - -@dataclass -class AggregateResult: - n_positions: int - kl_mean: float - kl_median: float - kl_p95: float - kl_max: float - top1_agree_rate: float - nucleus_jaccard_mean: float - nucleus_jaccard_exact_match_rate: float - mass_under_old_nucleus_mean: float - mass_under_old_nucleus_p05: float - # Fraction of positions where the new model places less than 0.80 of its - # probability mass on the old model's nucleus — the actionable "how bad" - # signal for real-model runs. - frac_mass_under_80: float - # Nucleus size distribution: useful for judging whether top-1 agreement - # is a meaningful metric (confident models: median 1-5; noise: ~vocab/2). - old_nucleus_size_median: float - new_nucleus_size_median: float - old_nucleus_size_p05: float - old_nucleus_size_p95: float - - def render(self, heading: str = "") -> str: - header = f"── {heading} ──" if heading else "" - trusted = "(trusted)" if self.old_nucleus_size_median < 50 else "(noisy: wide nucleus)" - lines = [ - header, - f"n_positions = {self.n_positions}", - "", - "KL(p_old || p_new) per position:", - f" mean = {self.kl_mean:.6f}", - f" median = {self.kl_median:.6f}", - f" p95 = {self.kl_p95:.6f}", - f" max = {self.kl_max:.6f}", - "", - f"Top-1 agreement rate = {self.top1_agree_rate:.4f} {trusted}", - "", - "Nucleus (top-p=0.9):", - f" size — old p05/med/p95 = {self.old_nucleus_size_p05:g} / {self.old_nucleus_size_median:g} / {self.old_nucleus_size_p95:g}", - f" size — new median = {self.new_nucleus_size_median:g}", - f" Jaccard mean = {self.nucleus_jaccard_mean:.4f}", - f" exact-match rate = {self.nucleus_jaccard_exact_match_rate:.4f}", - "", - "Mass under old nucleus (new model):", - f" mean = {self.mass_under_old_nucleus_mean:.4f}", - f" p05 (worst 5% of positions) = {self.mass_under_old_nucleus_p05:.4f}", - f" frac positions < 0.80 = {self.frac_mass_under_80:.4f}", - ] - return "\n".join(line for line in lines if line is not None) - - def as_dict(self) -> Dict[str, float]: - return { - "n_positions": self.n_positions, - "kl_mean": self.kl_mean, - "kl_median": self.kl_median, - "kl_p95": self.kl_p95, - "kl_max": self.kl_max, - "top1_agree_rate": self.top1_agree_rate, - "nucleus_jaccard_mean": self.nucleus_jaccard_mean, - "nucleus_jaccard_exact_match_rate": self.nucleus_jaccard_exact_match_rate, - "mass_under_old_nucleus_mean": self.mass_under_old_nucleus_mean, - "mass_under_old_nucleus_p05": self.mass_under_old_nucleus_p05, - "frac_mass_under_80": self.frac_mass_under_80, - "old_nucleus_size_median": self.old_nucleus_size_median, - "new_nucleus_size_median": self.new_nucleus_size_median, - "old_nucleus_size_p05": self.old_nucleus_size_p05, - "old_nucleus_size_p95": self.old_nucleus_size_p95, - } - - -def _adapter_active_mask(input_ids: torch.Tensor, adapter_token_ids: List[int]) -> torch.Tensor: - """[seq_len] bool: True at position s if any control token appears at - positions <= s. Token at position s itself counts — the swap happens - before that position's hidden state enters the decoder.""" - ctrl_set = set(adapter_token_ids) - is_ctrl = torch.tensor( - [int(t.item()) in ctrl_set for t in input_ids], dtype=torch.bool - ) - # Cumulative OR along the sequence. - return torch.cummax(is_ctrl.int(), dim=0).values.bool() - - -def _per_position_metrics( - logits_old: torch.Tensor, - logits_new: torch.Tensor, - top_p: float, - adapter_active: Optional[torch.Tensor] = None, -) -> List[PositionResult]: - """logits_{old,new}: [seq_len, vocab_size]. Returns one result per position.""" - assert logits_old.shape == logits_new.shape - results: List[PositionResult] = [] - # Promote to float32 for metric stability. - logits_old = logits_old.to(torch.float32) - logits_new = logits_new.to(torch.float32) - p_old_all = F.softmax(logits_old, dim=-1) - p_new_all = F.softmax(logits_new, dim=-1) - for s in range(logits_old.shape[0]): - p_old = p_old_all[s] - p_new = p_new_all[s] - nuc_old = _nucleus_indices(p_old, top_p) - nuc_new = _nucleus_indices(p_new, top_p) - results.append( - PositionResult( - kl=_kl_from_logits(logits_old[s], logits_new[s]), - top1_agree=bool(p_old.argmax() == p_new.argmax()), - nucleus_jaccard=_jaccard(nuc_old, nuc_new), - mass_under_old_nucleus=float(p_new[nuc_old].sum()), - old_nucleus_size=int(nuc_old.numel()), - new_nucleus_size=int(nuc_new.numel()), - adapter_active=bool(adapter_active[s]) if adapter_active is not None else False, - ) - ) - return results - - -def _aggregate(results: List[PositionResult]) -> AggregateResult: - if not results: - raise ValueError("No positions measured") - kls = sorted(r.kl for r in results) - jaccards = [r.nucleus_jaccard for r in results] - mass = sorted(r.mass_under_old_nucleus for r in results) - old_sizes = sorted(r.old_nucleus_size for r in results) - n = len(results) - p05_idx = max(0, int(n * 0.05) - 1) - p95_idx = min(n - 1, int(n * 0.95)) - return AggregateResult( - n_positions=n, - kl_mean=statistics.mean(kls), - kl_median=statistics.median(kls), - kl_p95=kls[p95_idx], - kl_max=kls[-1], - top1_agree_rate=sum(r.top1_agree for r in results) / n, - nucleus_jaccard_mean=statistics.mean(jaccards), - nucleus_jaccard_exact_match_rate=sum(1 for j in jaccards if j == 1.0) / n, - mass_under_old_nucleus_mean=statistics.mean(mass), - mass_under_old_nucleus_p05=mass[p05_idx], - frac_mass_under_80=sum(1 for m in mass if m < 0.80) / n, - old_nucleus_size_median=statistics.median(r.old_nucleus_size for r in results), - new_nucleus_size_median=statistics.median(r.new_nucleus_size for r in results), - old_nucleus_size_p05=old_sizes[p05_idx], - old_nucleus_size_p95=old_sizes[p95_idx], - ) - - -# ──────────────────────────────────────────────────────────────────── -# Synthetic model builder (CPU-friendly, weight-sharing pair) -# ──────────────────────────────────────────────────────────────────── - - -_SYNTHETIC_BASE_KWARGS = dict( - vocab_size=512, - hidden_size=64, - num_attention_heads=4, - num_key_value_heads=2, - num_hidden_layers=4, - intermediate_size=128, - shared_intermediate_size=128, - max_position_embeddings=128, - mamba_n_heads=1, - mamba_expand=1, - torch_dtype=torch.float32, -) - - -def _build_synthetic_pair( - num_adapters: int = 2, - seed: int = 0, -) -> Tuple[GraniteSwitchForCausalLM, GraniteSwitchForCausalLM]: - """Build two models with identical base weights: old=hiding, new=exchange. - - Any logit difference between them is therefore purely from control-token - handling, not from weight initialization. - """ - adapter_token_ids = [100, 101][:num_adapters] - substitute_token_ids = [5, 7][:num_adapters] - - torch.manual_seed(seed) - old_config = GraniteSwitchConfig( - **_SYNTHETIC_BASE_KWARGS, - num_adapters=num_adapters, - adapter_ranks=[4] * num_adapters, - max_lora_rank=4, - adapter_token_ids=adapter_token_ids, - adapter_names=[f"a{i}" for i in range(num_adapters)], - control_dims=32, - ) - old_model = GraniteSwitchForCausalLM(old_config).eval() - - new_config = GraniteSwitchConfig( - **_SYNTHETIC_BASE_KWARGS, - num_adapters=num_adapters, - adapter_ranks=[4] * num_adapters, - max_lora_rank=4, - adapter_token_ids=adapter_token_ids, - adapter_substitute_token_ids=substitute_token_ids, - adapter_names=[f"a{i}" for i in range(num_adapters)], - control_dims=0, - ) - new_model = GraniteSwitchForCausalLM(new_config).eval() - - # Share weights where the two configs have matching parameter shapes. - # Non-shared: tensors whose shape depends on control_dims (e.g. switch - # head_dim in the legacy path differs from the new native-head_dim path). - old_sd = old_model.state_dict() - new_sd = new_model.state_dict() - transferred = 0 - skipped: List[str] = [] - for name, new_tensor in new_sd.items(): - if name in old_sd and old_sd[name].shape == new_tensor.shape: - new_tensor.copy_(old_sd[name]) - transferred += 1 - else: - skipped.append(name) - assert transferred > 0, "no weights transferred; synthetic pair would be meaningless" - return old_model, new_model - - -def _synthetic_prompts( - num_adapters: int, - adapter_token_ids: List[int], - vocab_size: int, -) -> List[torch.Tensor]: - """A small, deterministic set of prompt sequences. - - Mix of: - - Prompts with no control token (base-path sanity). - - Prompts with a control token at different positions (adapter-activated). - """ - torch.manual_seed(42) - prompts: List[torch.Tensor] = [] - seq_len = 24 - # Fill tokens are drawn from the vocab excluding control-token ids. - safe_vocab = [t for t in range(1, vocab_size) if t not in adapter_token_ids] - - def _rand_seq() -> List[int]: - return [safe_vocab[int(torch.randint(0, len(safe_vocab), (1,)))] for _ in range(seq_len)] - - # Base-path prompts (no control tokens). - for _ in range(4): - prompts.append(torch.tensor([_rand_seq()], dtype=torch.long)) - # Adapter-activated prompts (one control token at varying positions). - for pos in (0, 2, 5, 10): - for ctrl_id in adapter_token_ids: - seq = _rand_seq() - seq[pos] = ctrl_id - prompts.append(torch.tensor([seq], dtype=torch.long)) - return prompts - - -def _demo_prompts(tokenizer, adapter_names: List[str]) -> List[torch.Tensor]: - """Realistic parity prompts: render every demo from tutorials/scripts - through the composed model's chat template, then tokenize. - - Each returned tensor is shape [1, seq_len]. Shape varies per prompt — - the parity eval loops one at a time, so no padding is needed. - """ - from tutorials.scripts.run_adapter_generation_direct import build_demo_prompts - - prompts: List[torch.Tensor] = [] - pairs = build_demo_prompts(tokenizer, available_adapters=set(adapter_names)) - for _demo_key, prompt_text in pairs: - ids = tokenizer(prompt_text, return_tensors="pt").input_ids - prompts.append(ids) - return prompts - - -# ──────────────────────────────────────────────────────────────────── -# Runner -# ──────────────────────────────────────────────────────────────────── - - -@dataclass -class ParityReport: - overall: AggregateResult - pre_control: Optional[AggregateResult] # positions before any control token - adapter_active: Optional[AggregateResult] # positions at / after control token - - def render(self) -> str: - parts = [self.overall.render("overall")] - if self.pre_control is not None: - parts.append("") - parts.append(self.pre_control.render("pre-control (base path)")) - if self.adapter_active is not None: - parts.append("") - parts.append(self.adapter_active.render("adapter-active")) - return "\n".join(parts) - - def as_dict(self) -> Dict: - d = {"overall": self.overall.as_dict()} - if self.pre_control is not None: - d["pre_control"] = self.pre_control.as_dict() - if self.adapter_active is not None: - d["adapter_active"] = self.adapter_active.as_dict() - return d - - -def run_parity_eval( - old_model: GraniteSwitchForCausalLM, - new_model: GraniteSwitchForCausalLM, - prompts: List[torch.Tensor], - adapter_token_ids: List[int], - top_p: float = 0.9, -) -> ParityReport: - all_results: List[PositionResult] = [] - for prompt in prompts: - with torch.no_grad(): - out_old = old_model(input_ids=prompt) - out_new = new_model(input_ids=prompt) - logits_old = out_old.logits[0] # [seq_len, vocab] - logits_new = out_new.logits[0] - mask = _adapter_active_mask(prompt[0], adapter_token_ids) - all_results.extend( - _per_position_metrics(logits_old, logits_new, top_p, adapter_active=mask) - ) - - overall = _aggregate(all_results) - pre = [r for r in all_results if not r.adapter_active] - active = [r for r in all_results if r.adapter_active] - return ParityReport( - overall=overall, - pre_control=_aggregate(pre) if pre else None, - adapter_active=_aggregate(active) if active else None, - ) - - -# ──────────────────────────────────────────────────────────────────── -# pytest entry points -# ──────────────────────────────────────────────────────────────────── - - -@pytest.mark.integration -def test_synthetic_parity_cpu(): - """Plumbing sanity check: legacy vs. token-exchange on a synthetic pair. - - With random weights, the two paths produce different logits (the swap IS - the difference), but the *structure* of the comparison should hold: base - positions (no control token) should agree perfectly; adapter-activated - positions will differ and we report how much. - """ - old_model, new_model = _build_synthetic_pair() - prompts = _synthetic_prompts( - num_adapters=2, - adapter_token_ids=[100, 101], - vocab_size=_SYNTHETIC_BASE_KWARGS["vocab_size"], - ) - report = run_parity_eval( - old_model, new_model, prompts, adapter_token_ids=[100, 101] - ) - print("\n" + report.render()) - assert report.overall.n_positions > 0 - assert report.overall.kl_max >= 0.0 - assert 0.0 <= report.overall.top1_agree_rate <= 1.0 - # Pre-control positions MUST agree bit-for-bit (both paths process them - # identically — no substitution, no hiding). Any disagreement here is a - # bug in the swap gating, not a mode trade-off. - if report.pre_control is not None: - assert report.pre_control.kl_max < 1e-6, ( - f"Pre-control KL max {report.pre_control.kl_max} should be ~0" - ) - assert report.pre_control.top1_agree_rate == 1.0 - - -@pytest.mark.slow -@pytest.mark.requires_model -def test_real_model_parity(): - """Gate for real composed checkpoints. Opt-in via env var. - - Set ``GRANITE_SWITCH_PARITY_MODELS`` to a JSON object with two paths: - '{"old": "/path/to/control_dims=32_build", "new": "/path/to/token_exchange_build"}' - - Both must be composed from the same base + adapter pair, differing only - in --legacy-hiding. Acceptance thresholds are documented per-metric; the - test fails if any is exceeded. - """ - spec = os.environ.get("GRANITE_SWITCH_PARITY_MODELS") - if spec is None: - pytest.skip("GRANITE_SWITCH_PARITY_MODELS env var not set") - paths = json.loads(spec) - old_path, new_path = paths["old"], paths["new"] - - old_model = GraniteSwitchForCausalLM.from_pretrained(old_path).eval() - new_model = GraniteSwitchForCausalLM.from_pretrained(new_path).eval() - - # Prompt set priority: - # 1. GRANITE_SWITCH_PARITY_PROMPTS env var (JSON array of int lists). - # 2. Rendered demo prompts from tutorials/scripts/run_adapter_generation_direct - # via the composed tokenizer — realistic adapter inputs. - # 3. Synthetic fallback (only useful when demo prompts fail for some reason). - prompts_spec = os.environ.get("GRANITE_SWITCH_PARITY_PROMPTS") - if prompts_spec: - prompt_lists = json.loads(prompts_spec) - prompts = [torch.tensor([p], dtype=torch.long) for p in prompt_lists] - else: - from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(old_path) - adapter_names = list(old_model.config.adapter_names or []) - prompts = _demo_prompts(tokenizer, adapter_names) - if not prompts: - prompts = _synthetic_prompts( - num_adapters=old_model.config.num_adapters, - adapter_token_ids=list(old_model.config.adapter_token_ids), - vocab_size=old_model.config.vocab_size, - ) - - report = run_parity_eval( - old_model, - new_model, - prompts, - adapter_token_ids=list(old_model.config.adapter_token_ids), - ) - print("\n" + report.render()) - - # Acceptance thresholds on the adapter-active partition (the comparison - # we actually care about). See docs/KV_CACHE_OVERHEAD_REMOVAL.md §4. - # Initial guesses — calibrate against a control_dims=32 vs control_dims=1 - # baseline run first and tighten to 2-3x the observed noise floor. - active = report.adapter_active if report.adapter_active else report.overall - assert active.top1_agree_rate >= 0.95, ( - f"top-1 agreement {active.top1_agree_rate:.4f} below 0.95 threshold" - ) - assert active.kl_mean <= 0.02, ( - f"mean KL {active.kl_mean:.5f} above 0.02 threshold" - ) - assert active.mass_under_old_nucleus_mean >= 0.88, ( - f"mean mass under old nucleus {active.mass_under_old_nucleus_mean:.4f} " - f"below 0.88 threshold" - ) - - -# ──────────────────────────────────────────────────────────────────── -# CLI entry point -# ──────────────────────────────────────────────────────────────────── - - -def _cli(): - p = argparse.ArgumentParser(description="Token-exchange parity eval.") - p.add_argument("--old", type=str, default=None, help="Path to legacy-hiding model") - p.add_argument("--new", type=str, default=None, help="Path to token-exchange model") - p.add_argument("--top-p", type=float, default=0.9) - p.add_argument("--json-out", type=str, default=None, help="Optional JSON report path") - args = p.parse_args() - - if args.old and args.new: - print(f"Loading old model from {args.old}...") - old_model = GraniteSwitchForCausalLM.from_pretrained(args.old).eval() - print(f"Loading new model from {args.new}...") - new_model = GraniteSwitchForCausalLM.from_pretrained(args.new).eval() - from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(args.old) - adapter_names = list(old_model.config.adapter_names or []) - print(f"Building demo prompts for adapters: {adapter_names}") - prompts = _demo_prompts(tokenizer, adapter_names) - if not prompts: - print("No demo prompts matched; falling back to synthetic.") - prompts = _synthetic_prompts( - num_adapters=old_model.config.num_adapters, - adapter_token_ids=list(old_model.config.adapter_token_ids), - vocab_size=old_model.config.vocab_size, - ) - print(f"Collected {len(prompts)} prompts.") - else: - print("Running synthetic parity (no --old/--new paths given)...") - old_model, new_model = _build_synthetic_pair() - prompts = _synthetic_prompts( - num_adapters=2, - adapter_token_ids=[100, 101], - vocab_size=_SYNTHETIC_BASE_KWARGS["vocab_size"], - ) - - if args.old and args.new: - adapter_token_ids = list(old_model.config.adapter_token_ids) - else: - adapter_token_ids = [100, 101] - report = run_parity_eval( - old_model, new_model, prompts, adapter_token_ids=adapter_token_ids, top_p=args.top_p, - ) - print() - print(report.render()) - if args.json_out: - import json as _json - with open(args.json_out, "w") as f: - _json.dump(report.as_dict(), f, indent=2) - print(f"\nWrote JSON report to {args.json_out}") - - -if __name__ == "__main__": - _cli() diff --git a/tests/shared/generation_models.py b/tests/shared/generation_models.py index 727ca71..6c5e477 100644 --- a/tests/shared/generation_models.py +++ b/tests/shared/generation_models.py @@ -49,26 +49,14 @@ def single_overrides(base_cfg): - """SingleSwitch overrides for the given base config. - - Pinned to the legacy KV-hiding path (control_dims=32) so existing - generation tests exercise hiding semantics even after the default - flipped to token-exchange. - """ + """SingleSwitch overrides for the given base config (token exchange).""" base_layers = base_cfg["layer_types"] return { "num_adapters": NUM_ADAPTERS, "adapter_ranks": [ADAPTER_RANK] * NUM_ADAPTERS, "adapter_token_ids": [250, 251], + "adapter_substitute_token_ids": [1, 1], "adapter_names": ["adapter_0", "adapter_1"], - "hiding_groups": {"all_controls": ["adapter_0", "adapter_1"]}, - "hiding_policy": { - "base": ["all_controls"], - "adapter_0": ["all_controls"], - "adapter_1": ["all_controls"], - }, - "adapter_third_party": ["adapter_0", "adapter_1"], - "control_dims": 32, "num_hidden_layers": len(base_layers) + 1, "layer_types": ["attention"] + base_layers, } diff --git a/tests/shared/granite4_equivalence.py b/tests/shared/granite4_equivalence.py index f2606dd..68f58ba 100644 --- a/tests/shared/granite4_equivalence.py +++ b/tests/shared/granite4_equivalence.py @@ -186,16 +186,10 @@ def augment_cfg_with_adapters(cfg_dict, num_adapters=2, rank=8): cfg["adapter_token_ids"] = [ _ADAPTER_TOKEN_BASE + i for i in range(num_adapters) ] - - # Default hiding config: all adapters in a single group, all hide it. - cfg["hiding_groups"] = {"all_controls": list(adapter_names)} - cfg["hiding_policy"] = { - name: ["all_controls"] for name in ["base"] + list(adapter_names) - } - cfg["adapter_third_party"] = list(adapter_names) - # These equivalence tests specifically exercise the legacy KV-hiding path. - # Pin control_dims=32 so they keep running after the default flipped to 0. - cfg.setdefault("control_dims", 32) + # Token-exchange substitute ids — use a benign shared id (the BOS-or- + # equivalent doesn't matter for these synthetic equivalence tests since + # all LoRA weights are zero, so the embedding is what feeds the decoder). + cfg["adapter_substitute_token_ids"] = [1] * num_adapters return cfg diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 92b53f3..7225958 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 """Config validation tests for GraniteSwitchConfig. -Tests every ValueError path in __init__, default values, and derived properties. +Covers the validators in __init__, default values, and the config +fields that survived the legacy-hiding removal. """ import pytest @@ -11,13 +12,9 @@ # ── Helper ──────────────────────────────────────────────────────────── -def _valid_kwargs(num_adapters=2, **overrides): - """Return kwargs for a valid SingleSwitch config, with optional overrides. - Default mode: legacy KV hiding (control_dims=32) so existing shape/hiding - tests keep their original contract. Token-exchange tests override - control_dims=0 and pass adapter_substitute_token_ids. - """ +def _valid_kwargs(num_adapters=2, **overrides): + """Return kwargs for a valid token-exchange config.""" adapter_names = [f"adapter_{i}" for i in range(num_adapters)] base = dict( vocab_size=300, @@ -28,10 +25,10 @@ def _valid_kwargs(num_adapters=2, **overrides): num_key_value_heads=4, num_adapters=num_adapters, adapter_token_ids=list(range(500, 500 + num_adapters)), + adapter_substitute_token_ids=[1] * num_adapters, adapter_names=adapter_names, max_lora_rank=8, adapter_ranks=[8] * num_adapters, - control_dims=32, ) base.update(overrides) return base @@ -41,6 +38,7 @@ def _valid_kwargs(num_adapters=2, **overrides): # 1. Config validation — every ValueError path # ════════════════════════════════════════════════════════════════════ + class TestConfigValidation: def test_negative_num_adapters_raises(self): @@ -49,137 +47,56 @@ def test_negative_num_adapters_raises(self): def test_adapter_token_ids_wrong_length_raises(self): with pytest.raises(ValueError, match="adapter_token_ids length"): - GraniteSwitchConfig(**_valid_kwargs( - adapter_token_ids=[500, 501, 502], # length 3, expected 2 - )) - - def test_missing_adapter_ranks_raises(self): + GraniteSwitchConfig(**_valid_kwargs(adapter_token_ids=[500])) + + def test_substitute_ids_required_when_adapters_present(self): + with pytest.raises(ValueError, match="adapter_substitute_token_ids is required"): + GraniteSwitchConfig( + **_valid_kwargs(adapter_substitute_token_ids=None) + ) + + def test_substitute_ids_wrong_length_raises(self): + with pytest.raises(ValueError, match="adapter_substitute_token_ids length"): + GraniteSwitchConfig( + **_valid_kwargs(adapter_substitute_token_ids=[1]) + ) + + def test_substitute_ids_negative_raises(self): + with pytest.raises(ValueError, match=">= 0"): + GraniteSwitchConfig( + **_valid_kwargs(adapter_substitute_token_ids=[-1, 1]) + ) + + def test_duplicate_adapter_token_ids_raises(self): + with pytest.raises(ValueError, match="adapter_token_ids must be unique"): + GraniteSwitchConfig(**_valid_kwargs(adapter_token_ids=[500, 500])) + + def test_adapter_ranks_required(self): with pytest.raises(ValueError, match="adapter_ranks must be provided"): GraniteSwitchConfig(**_valid_kwargs(adapter_ranks=None)) - def test_adapter_ranks_wrong_length_raises(self): + def test_adapter_ranks_wrong_length(self): with pytest.raises(ValueError, match="adapter_ranks length"): - GraniteSwitchConfig(**_valid_kwargs( - adapter_ranks=[8], # length 1, expected 2 - )) + GraniteSwitchConfig(**_valid_kwargs(adapter_ranks=[8])) - def test_max_adapter_ranks_mismatch_raises(self): - with pytest.raises(ValueError, match="max.*adapter_ranks.*must equal max_lora_rank"): - GraniteSwitchConfig(**_valid_kwargs( - adapter_ranks=[4, 4], # max=4, but max_lora_rank=8 - )) + def test_max_lora_rank_must_match(self): + with pytest.raises(ValueError, match="max_lora_rank"): + GraniteSwitchConfig(**_valid_kwargs(max_lora_rank=4)) # ════════════════════════════════════════════════════════════════════ -# 2. Config defaults and derived properties +# 2. Defaults # ════════════════════════════════════════════════════════════════════ + class TestConfigDefaults: - def test_zero_adapters_no_validation(self): - """Config with 0 adapters should not require adapter_ranks or token_ids.""" - cfg = GraniteSwitchConfig( - vocab_size=256, hidden_size=64, intermediate_size=128, - num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4, - num_adapters=0, - ) + def test_zero_adapter_default(self): + cfg = GraniteSwitchConfig(num_adapters=0) assert cfg.num_adapters == 0 - assert cfg.adapter_ranks is None - - -# ════════════════════════════════════════════════════════════════════ -# 3. Hiding groups and policy -# ════════════════════════════════════════════════════════════════════ - -class TestHidingConfig: - - def test_hiding_groups_none_by_default(self): - """No hiding groups when not specified.""" - cfg = GraniteSwitchConfig(**_valid_kwargs()) - assert cfg.num_hiding_groups == 0 - assert cfg.hiding_group_names == [] - - def test_hiding_groups_count(self): - """num_hiding_groups reflects configured groups.""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - hiding_groups={ - "group_a": ["adapter_0"], - "group_b": ["adapter_1"], - }, - )) - assert cfg.num_hiding_groups == 2 - assert cfg.hiding_group_names == ["group_a", "group_b"] - - def test_control_dims_less_than_groups_raises(self): - """control_dims must be >= number of hiding groups.""" - with pytest.raises(ValueError, match="control_dims.*must be >= number of hiding groups"): - GraniteSwitchConfig(**_valid_kwargs( - control_dims=1, - hiding_groups={ - "g1": ["adapter_0"], - "g2": ["adapter_1"], - }, - )) - - def test_get_hiding_group_token_ids(self): - """Token IDs resolved correctly for SingleSwitch.""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - hiding_groups={"all_controls": ["adapter_0", "adapter_1"]}, - )) - group_tokens = cfg.get_hiding_group_token_ids() - # SingleSwitch: no base offset, adapter_0 → token 500, adapter_1 → token 501 - assert group_tokens == {0: [500, 501]} - - def test_get_hiding_group_token_ids_multiple_groups(self): - """Multiple groups with different adapter assignments.""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - hiding_groups={ - "group_a": ["adapter_0"], - "group_b": ["adapter_1"], - }, - )) - group_tokens = cfg.get_hiding_group_token_ids() - assert group_tokens == {0: [500], 1: [501]} - - def test_get_adapter_hiding_policy_matrix(self): - """Policy matrix built correctly from named config.""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - hiding_groups={ - "group_a": ["adapter_0"], - "group_b": ["adapter_1"], - }, - hiding_policy={ - "base": ["group_a", "group_b"], - "adapter_0": ["group_b"], - "adapter_1": ["group_a"], - }, - )) - matrix = cfg.get_adapter_hiding_policy_matrix() - # [base, adapter_0, adapter_1] x [group_a, group_b] - assert matrix == [ - [True, True], # base hides both - [False, True], # adapter_0 hides group_b only - [True, False], # adapter_1 hides group_a only - ] - - def test_get_adapter_hiding_policy_matrix_no_policy(self): - """Empty matrix when no policy configured.""" - cfg = GraniteSwitchConfig(**_valid_kwargs()) - assert cfg.get_adapter_hiding_policy_matrix() == [] - - -# ════════════════════════════════════════════════════════════════════ -# 4. Third-party adapter config -# ════════════════════════════════════════════════════════════════════ - -class TestAdapterThirdParty: - - def test_adapter_third_party_stored(self): - cfg = GraniteSwitchConfig(**_valid_kwargs( - adapter_third_party=["adapter_0"], - )) - assert cfg.adapter_third_party == ["adapter_0"] + assert cfg.adapter_token_ids is None + assert cfg.adapter_substitute_token_ids is None - def test_adapter_third_party_none_by_default(self): + def test_projection_head_dim_inferred_from_hidden_size(self): cfg = GraniteSwitchConfig(**_valid_kwargs()) - assert cfg.adapter_third_party is None + assert cfg.projection_head_dim == 64 // 4 diff --git a/tests/unit/test_config_edge_cases.py b/tests/unit/test_config_edge_cases.py index 55ca3c9..917d933 100644 --- a/tests/unit/test_config_edge_cases.py +++ b/tests/unit/test_config_edge_cases.py @@ -1,13 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -"""Additional config edge case tests for GraniteSwitchConfig. - -These tests cover edge cases not covered by the main test_config.py, -specifically targeting previously uncovered code paths: -- Line 99: shared_intermediate_size default from intermediate_size -- Line 119: negative control_dims validation -- Lines 220, 222: get_hiding_group_token_ids with missing configs -- Lines 250-259: get_third_party_adapter_mask functionality -""" +"""Additional config edge case tests for GraniteSwitchConfig.""" import pytest @@ -15,10 +7,7 @@ def _valid_kwargs(num_adapters=2, **overrides): - """Return kwargs for a valid SingleSwitch config, with optional overrides. - - Default mode: legacy KV hiding (control_dims=32). - """ + """Return kwargs for a valid token-exchange config.""" adapter_names = [f"adapter_{i}" for i in range(num_adapters)] base = dict( vocab_size=300, @@ -29,213 +18,64 @@ def _valid_kwargs(num_adapters=2, **overrides): num_key_value_heads=4, num_adapters=num_adapters, adapter_token_ids=list(range(500, 500 + num_adapters)), + adapter_substitute_token_ids=[1] * num_adapters, adapter_names=adapter_names, max_lora_rank=8, adapter_ranks=[8] * num_adapters, - control_dims=32, ) base.update(overrides) return base class TestSharedIntermediateSize: - """Tests for shared_intermediate_size default handling (line 99). - - Note: The parent GraniteMoeHybridConfig may have a non-None default, - so line 99 (the None check) may not always be hit. We test the - explicit case and verify the config has a sensible value. - """ + """The parent GraniteMoeHybridConfig may have a non-None default for + shared_intermediate_size. Verify our config has a sensible value.""" def test_shared_intermediate_size_has_value(self): - """shared_intermediate_size has a value (either explicit or parent default).""" cfg = GraniteSwitchConfig(**_valid_kwargs()) - # Should have a sensible value (not None) assert cfg.shared_intermediate_size is not None assert cfg.shared_intermediate_size > 0 def test_explicit_shared_intermediate_size_preserved(self): - """Explicit shared_intermediate_size is preserved.""" cfg = GraniteSwitchConfig(**_valid_kwargs( shared_intermediate_size=256, )) assert cfg.shared_intermediate_size == 256 -class TestControlDimsValidation: - """Tests for control_dims validation (line 119).""" - - def test_negative_control_dims_raises(self): - """Negative control_dims should raise ValueError.""" - with pytest.raises(ValueError, match="control_dims must be >= 0"): - GraniteSwitchConfig(**_valid_kwargs(control_dims=-1)) - - def test_zero_control_dims_valid_with_substitute_ids(self): - """Zero control_dims is valid when substitute ids are provided (token exchange).""" - cfg = GraniteSwitchConfig( - **_valid_kwargs(control_dims=0, adapter_substitute_token_ids=[1, 2]) - ) - assert cfg.control_dims == 0 - assert cfg.use_token_exchange is True - - def test_zero_control_dims_no_substitute_ids_raises(self): - """Zero control_dims without substitute ids must fail: no hiding and no exchange.""" - with pytest.raises(ValueError, match="either control_dims > 0"): - GraniteSwitchConfig(**_valid_kwargs(control_dims=0)) - - def test_positive_control_dims_valid(self): - """Positive control_dims is valid.""" - cfg = GraniteSwitchConfig(**_valid_kwargs(control_dims=64)) - assert cfg.control_dims == 64 - - -class TestGetHidingGroupTokenIds: - """Tests for get_hiding_group_token_ids edge cases (lines 220, 222).""" - - def test_no_hiding_groups_returns_empty(self): - """Empty dict when hiding_groups is None (line 219).""" - cfg = GraniteSwitchConfig(**_valid_kwargs(hiding_groups=None)) - result = cfg.get_hiding_group_token_ids() - assert result == {} - - def test_no_adapter_names_returns_empty(self): - """Empty dict when adapter_names is None (line 219).""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - adapter_names=None, - hiding_groups={"all": ["adapter_0"]}, - )) - result = cfg.get_hiding_group_token_ids() - assert result == {} - - def test_no_adapter_token_ids_returns_empty(self): - """Empty dict when adapter_token_ids is None (line 222).""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - adapter_token_ids=None, - hiding_groups={"all": ["adapter_0"]}, - )) - result = cfg.get_hiding_group_token_ids() - assert result == {} - - def test_partial_adapter_name_match(self): - """Only matching adapter names are included in result.""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - hiding_groups={"all": ["adapter_0", "nonexistent_adapter"]}, - )) - result = cfg.get_hiding_group_token_ids() - # Only adapter_0 should be in the result (token 500) - assert result == {0: [500]} - - -class TestGetThirdPartyAdapterMask: - """Tests for get_third_party_adapter_mask (lines 250-259).""" - - def test_no_third_party_returns_all_false(self): - """All-False mask when adapter_third_party is not configured.""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - adapter_third_party=None, - )) - mask = cfg.get_third_party_adapter_mask() - # Length = num_adapters + 1 (base + adapters) - assert len(mask) == 3 # base + 2 adapters - assert mask == [False, False, False] - - def test_empty_third_party_returns_all_false(self): - """All-False mask when adapter_third_party is empty list.""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - adapter_third_party=[], - )) - mask = cfg.get_third_party_adapter_mask() - assert mask == [False, False, False] - - def test_no_adapter_names_returns_all_false(self): - """All-False mask when adapter_names is None.""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - adapter_names=None, - adapter_third_party=["adapter_0"], - )) - mask = cfg.get_third_party_adapter_mask() - assert mask == [False, False, False] - - def test_single_third_party_adapter(self): - """Mask correctly identifies single third-party adapter.""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - adapter_third_party=["adapter_0"], - )) - mask = cfg.get_third_party_adapter_mask() - # Index 0 = base (never third-party) - # Index 1 = adapter_0 (third-party) - # Index 2 = adapter_1 (not third-party) - assert mask == [False, True, False] - - def test_multiple_third_party_adapters(self): - """Mask correctly identifies multiple third-party adapters.""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - adapter_third_party=["adapter_0", "adapter_1"], - )) - mask = cfg.get_third_party_adapter_mask() - # Both adapters are third-party - assert mask == [False, True, True] - - def test_base_never_third_party(self): - """Base adapter (index 0) is never marked as third-party.""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - adapter_third_party=["adapter_0", "adapter_1"], - )) - mask = cfg.get_third_party_adapter_mask() - assert mask[0] is False # Base is never third-party - - def test_mask_length_matches_num_adapters_plus_one(self): - """Mask length is num_adapters + 1 (includes base slot).""" - for num_adapters in [0, 1, 4, 10]: - cfg = GraniteSwitchConfig(**_valid_kwargs( - num_adapters=num_adapters, - adapter_token_ids=list(range(500, 500 + num_adapters)), - adapter_names=[f"adapter_{i}" for i in range(num_adapters)], - adapter_ranks=[8] * num_adapters if num_adapters > 0 else None, - )) - mask = cfg.get_third_party_adapter_mask() - assert len(mask) == num_adapters + 1 - - class TestLayerTypesDefault: - """Tests for layer_types default handling.""" + """layer_types defaults to all-attention with length == num_hidden_layers.""" - def test_layer_types_defaults_to_attention(self): - """layer_types defaults to all 'attention' when None.""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - layer_types=None, - num_hidden_layers=4, - )) - # Should have 4 attention layers (adapters add a switch layer at index 0, - # but the config has num_hidden_layers=4 which becomes 5 with switch) - # The default is set before parent init adds the switch layer - assert cfg.layer_types is not None + def test_default_layer_types_when_omitted(self): + cfg = GraniteSwitchConfig(num_adapters=0, num_hidden_layers=4) + assert cfg.layer_types == ["attention"] * 4 def test_explicit_layer_types_preserved(self): - """Explicit layer_types are preserved.""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - layer_types=["attention", "attention"], - num_hidden_layers=2, - )) - assert cfg.layer_types == ["attention", "attention"] + cfg = GraniteSwitchConfig( + num_adapters=0, + num_hidden_layers=3, + layer_types=["attention", "attention", "attention"], + ) + assert cfg.layer_types == ["attention", "attention", "attention"] class TestLoraTargetModulesDefault: - """Tests for lora_target_modules default handling.""" + """lora_target_modules defaults to qkv_proj/o_proj + shared_mlp pair + when num_adapters > 0; empty when num_adapters == 0.""" - def test_lora_target_modules_empty_when_no_adapters(self): - """lora_target_modules defaults to empty when num_adapters=0.""" - cfg = GraniteSwitchConfig( - vocab_size=256, hidden_size=64, intermediate_size=128, - num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4, - num_adapters=0, - ) + def test_no_adapters_empty_target_modules(self): + cfg = GraniteSwitchConfig(num_adapters=0) assert cfg.lora_target_modules == [] - def test_lora_target_modules_populated_with_adapters(self): - """lora_target_modules defaults to standard modules when adapters present.""" + def test_adapters_populate_target_modules(self): cfg = GraniteSwitchConfig(**_valid_kwargs()) - # Should include attention and MLP modules assert "qkv_proj" in cfg.lora_target_modules assert "o_proj" in cfg.lora_target_modules assert "shared_input_linear" in cfg.lora_target_modules assert "shared_output_linear" in cfg.lora_target_modules + + def test_explicit_target_modules_preserved(self): + cfg = GraniteSwitchConfig( + **_valid_kwargs(lora_target_modules=["qkv_proj"]) + ) + assert cfg.lora_target_modules == ["qkv_proj"] diff --git a/tests/unit/test_hiding_constant.py b/tests/unit/test_hiding_constant.py deleted file mode 100644 index 00c7724..0000000 --- a/tests/unit/test_hiding_constant.py +++ /dev/null @@ -1,118 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Tests for the K-side hiding constant used in control-dimension masking. - -The hiding mechanism sets K[d_g] = finfo(dtype).min for tokens in hiding group g. -This test validates that this constant behaves correctly across all supported -floating-point types: - -1. exp(constant) == 0 (softmax produces zero weight) -2. Accumulation of multiple constants (multiple groups) also exponentiates to zero -3. Adding realistic finite attention scores to the constant still exponentiates to zero -4. 0 * constant does NOT produce NaN (critical: Q[d_g]=0 for non-hiding adapters) - -Safety margin reporting (how large a positive value must be added before exp produces -a nonzero result) is part of the builder's verbose output, not tested here. -""" - -import pytest -import torch - -DTYPES = [torch.float16, torch.bfloat16, torch.float32, torch.float64] -DTYPE_IDS = ["float16", "bfloat16", "float32", "float64"] - -# Maximum number of hiding groups we might realistically accumulate in a dot product -MAX_GROUPS = 32 - - - -def hiding_constant(dtype: torch.dtype) -> torch.Tensor: - """The K-side hiding constant: the most negative finite value in the dtype.""" - return torch.tensor(torch.finfo(dtype).min, dtype=dtype) - - -@pytest.fixture(params=DTYPES, ids=DTYPE_IDS) -def dtype(request): - return request.param - - -class TestHidingConstantExponentiation: - """exp(hiding_constant) must be exactly zero — this is what makes softmax - assign zero attention weight to hidden tokens.""" - - def test_exp_of_constant_is_zero(self, dtype): - c = hiding_constant(dtype) - assert torch.exp(c).item() == 0.0 - - def test_exp_of_sum_of_constants_is_zero(self, dtype): - """A token in multiple groups: dot product accumulates multiple constants.""" - c = hiding_constant(dtype) - for n in [2, 4, MAX_GROUPS]: - accum = torch.zeros(1, dtype=dtype) - for _ in range(n): - accum = accum + c - assert torch.exp(accum).item() == 0.0, f"Failed for {n} groups" - - def test_exp_of_constant_plus_finite_is_zero(self, dtype): - """Normal attention score added to the constant must still exponentiate to zero.""" - c = hiding_constant(dtype) - for score in [0.0, 10.0, 100.0, 1000.0]: - s = torch.tensor(score, dtype=dtype) - result = torch.exp(c + s) - assert result.item() == 0.0, f"Failed for score={score}" - - -class TestHidingConstantNoNaN: - """0 * hiding_constant must NOT produce NaN. This is the scenario where - Q[d_g] = 0 (adapter does not hide group g) and K[d_g] = constant.""" - - def test_zero_times_constant_is_not_nan(self, dtype): - c = hiding_constant(dtype) - zero = torch.tensor(0.0, dtype=dtype) - result = zero * c - assert not result.isnan().item() - - def test_zero_times_constant_does_not_corrupt_dot_product(self, dtype): - """In a realistic dot product, 0 * constant contributions must not - change the result compared to a clean dot product without control dims.""" - torch.manual_seed(42) - head_dim = 128 - control_dims = 4 - total_dim = head_dim + control_dims - - Q = torch.randn(total_dim, dtype=dtype) - K = torch.randn(total_dim, dtype=dtype) - - c = hiding_constant(dtype) - # Token is in groups 0 and 2 - K[head_dim + 0] = c - K[head_dim + 1] = 0.0 - K[head_dim + 2] = c - K[head_dim + 3] = 0.0 - - # Query does NOT hide any group - Q[head_dim:] = 0.0 - - dot_with_ctrl = torch.dot(Q, K) - dot_clean = torch.dot(Q[:head_dim], K[:head_dim]) - - assert not dot_with_ctrl.isnan().item() - assert torch.isclose(dot_with_ctrl, dot_clean, atol=1e-2) - - -class TestHidingConstantSoftmax: - """End-to-end: softmax assigns exactly zero weight to hidden positions.""" - - def test_softmax_zero_weight_for_hidden(self, dtype): - scores = torch.tensor([5.0, 3.0, 7.0], dtype=dtype) - c = hiding_constant(dtype) - scores_with_hidden = scores.clone() - scores_with_hidden[1] = scores_with_hidden[1] + c # hide position 1 - - sm = torch.softmax(scores_with_hidden, dim=0) - assert sm[1].item() == 0.0 - # Non-hidden positions should get all the probability mass - assert sm[0].item() > 0.0 - assert sm[2].item() > 0.0 - assert abs(sm.sum().item() - 1.0) < 1e-3 - - diff --git a/tests/unit/test_token_exchange.py b/tests/unit/test_token_exchange.py index 3e74cb0..d24e968 100644 --- a/tests/unit/test_token_exchange.py +++ b/tests/unit/test_token_exchange.py @@ -1,12 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 """Unit tests for the token-exchange config path. -Covers the new fields and validators on GraniteSwitchConfig: -- adapter_substitute_token_ids length check -- use_token_exchange property -- rejection of num_adapters>0 with no hiding and no substitute ids -- rejection of duplicate adapter_token_ids (LUT would collide) -- default control_dims flipped to 0 +Verifies the validators and required-field semantics on +GraniteSwitchConfig, now that token-exchange is the only mode. """ import pytest @@ -25,6 +21,7 @@ def _base(num_adapters=2, **overrides): num_key_value_heads=4, num_adapters=num_adapters, adapter_token_ids=list(range(500, 500 + num_adapters)), + adapter_substitute_token_ids=[1] * num_adapters, adapter_names=names, max_lora_rank=8, adapter_ranks=[8] * num_adapters, @@ -34,69 +31,30 @@ def _base(num_adapters=2, **overrides): class TestDefaults: - def test_control_dims_default_is_zero(self): - cfg = GraniteSwitchConfig(num_adapters=0) - assert cfg.control_dims == 0 - def test_no_adapters_no_validation(self): cfg = GraniteSwitchConfig(num_adapters=0) - assert cfg.use_token_exchange is False - - -class TestUseTokenExchangeProperty: - def test_true_when_substitute_and_zero_dims(self): - cfg = GraniteSwitchConfig( - **_base( - control_dims=0, - adapter_substitute_token_ids=[1, 2], - ) - ) - assert cfg.use_token_exchange is True - - def test_false_when_legacy_hiding(self): - cfg = GraniteSwitchConfig(**_base(control_dims=32)) - assert cfg.use_token_exchange is False - - def test_false_when_no_substitute_ids_even_with_zero_dims_requires_validator(self): - # This combo is invalid — validator rejects it, so the property - # cannot be observed in a built config. Covered in TestValidation. - pass + assert cfg.adapter_substitute_token_ids is None class TestValidation: - def test_zero_dims_plus_missing_substitute_ids_raises(self): - with pytest.raises(ValueError, match="either control_dims > 0"): - GraniteSwitchConfig(**_base(control_dims=0)) + def test_substitute_ids_required_when_adapters_present(self): + with pytest.raises(ValueError, match="adapter_substitute_token_ids is required"): + GraniteSwitchConfig(**_base(adapter_substitute_token_ids=None)) def test_substitute_wrong_length_raises(self): with pytest.raises(ValueError, match="adapter_substitute_token_ids length"): - GraniteSwitchConfig( - **_base(control_dims=0, adapter_substitute_token_ids=[1]) - ) + GraniteSwitchConfig(**_base(adapter_substitute_token_ids=[1])) def test_duplicate_adapter_token_ids_raises(self): with pytest.raises(ValueError, match="adapter_token_ids must be unique"): - GraniteSwitchConfig( - **_base( - adapter_token_ids=[100, 100], - adapter_substitute_token_ids=[1, 2], - control_dims=0, - ) - ) - + GraniteSwitchConfig(**_base(adapter_token_ids=[100, 100])) -class TestLegacyPathStillWorks: - def test_control_dims_positive_without_substitute_ids(self): - cfg = GraniteSwitchConfig(**_base(control_dims=32)) - assert cfg.control_dims == 32 - assert cfg.use_token_exchange is False - # Expanded head_dim reflects the legacy path. - assert cfg.expanded_head_dim == cfg.projection_head_dim + 32 + def test_negative_substitute_id_raises(self): + with pytest.raises(ValueError, match=">= 0"): + GraniteSwitchConfig(**_base(adapter_substitute_token_ids=[-1, 1])) -class TestExpandedHeadDim: - def test_token_exchange_gives_native_head_dim(self): - cfg = GraniteSwitchConfig( - **_base(control_dims=0, adapter_substitute_token_ids=[1, 2]) - ) - assert cfg.expanded_head_dim == cfg.projection_head_dim +class TestProjectionHeadDim: + def test_inferred_from_hidden_size(self): + cfg = GraniteSwitchConfig(**_base()) + assert cfg.projection_head_dim == cfg.hidden_size // cfg.num_attention_heads diff --git a/tests/vllm/_kv_hiding_gap_tests.py b/tests/vllm/_kv_hiding_gap_tests.py deleted file mode 100644 index be29394..0000000 --- a/tests/vllm/_kv_hiding_gap_tests.py +++ /dev/null @@ -1,71 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""KV hiding gap equivalence tests (inner file — run by test_kv_hiding_gap_equivalence.py). - -Verify that a hidden control token creates a transparent gap in vLLM. -Requires CUDA GPU and vLLM installed. -""" - -import pytest -import torch - -from tests.shared.granite4_equivalence import GRANITE4_MINI -from tests.shared.gap_equivalence import extract_visible_flat - - -_CUDA_AVAILABLE = torch.cuda.is_available() - - -def _try_import_vllm(): - try: - from vllm import LLM # noqa: F401 - return True - except ImportError: - return False - - -_VLLM_AVAILABLE = _try_import_vllm() if _CUDA_AVAILABLE else False - -_CFG_NAME = "4.0-350m" - - -@pytest.mark.skipif( - not _CUDA_AVAILABLE or not _VLLM_AVAILABLE, - reason="requires CUDA GPU and vLLM installed", -) -class TestKVHidingGapEquivalence: - - @pytest.fixture - def gap_runner(self, tmp_path): - cfg_dict = GRANITE4_MINI[_CFG_NAME] - - def run(seq_len, ctrl_pos): - from tests.shared.vllm_equivalence import run_gap_equivalence - return run_gap_equivalence( - cfg_dict, - seq_len=seq_len, ctrl_pos=ctrl_pos, - tmpdir=tmp_path, - ) - - return run - - def _assert_gap(self, run, seq_len, ctrl_pos, atol=0, rtol=0): - upstream_lp, switch_lp = run(seq_len, ctrl_pos) - visible_lp = extract_visible_flat(switch_lp, ctrl_pos) - - torch.testing.assert_close( - visible_lp, upstream_lp, - atol=atol, rtol=rtol, - msg=( - f"{_CFG_NAME}: visible logprobs diverge " - f"(seq={seq_len}, ctrl={ctrl_pos})" - ), - ) - - def test_gap_short(self, gap_runner): - self._assert_gap(gap_runner, seq_len=16, ctrl_pos=2) - - def test_gap_ctrl_at_1(self, gap_runner): - self._assert_gap(gap_runner, seq_len=16, ctrl_pos=1) - - def test_gap_near_end(self, gap_runner): - self._assert_gap(gap_runner, seq_len=16, ctrl_pos=14) diff --git a/tests/vllm/_position_zero_nan_tests.py b/tests/vllm/_position_zero_nan_tests.py deleted file mode 100644 index 69b3360..0000000 --- a/tests/vllm/_position_zero_nan_tests.py +++ /dev/null @@ -1,479 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""NaN regression tests — control token at sequence position 0 (vLLM backend). -Inner file — run by test_position_zero_nan.py in subprocess. - -Combines: - - vLLM-specific unit tests for GraniteLoRAEmbeddedAttention._expand_with_control_dimensions - (flat token layout: [num_tokens, num_heads * head_dim]) - - Shared SDPANaNCases and ModelFinitenessCases from tests/shared/position_zero_nan_cases.py - -Requires CUDA GPU and vLLM installed. -""" - -import types -import json -import os -import tempfile - -import pytest -import torch - -from tests.shared.position_zero_nan_cases import ModelFinitenessCases, SDPANaNCases - -_CUDA_AVAILABLE = torch.cuda.is_available() - - -def _try_import_vllm(): - try: - from vllm.config import VllmConfig # noqa: F401 - from vllm.model_executor.layers.attention.attention import Attention # noqa: F401 - from vllm.forward_context import ForwardContext, override_forward_context # noqa: F401 - return True - except ImportError: - return False - - -_VLLM_AVAILABLE = _try_import_vllm() if _CUDA_AVAILABLE else False - -pytestmark = pytest.mark.skipif( - not _CUDA_AVAILABLE or not _VLLM_AVAILABLE, - reason="requires CUDA GPU and vLLM installed", -) - -if _VLLM_AVAILABLE: - from vllm.config import VllmConfig, ModelConfig, set_current_vllm_config - from vllm.forward_context import ForwardContext, override_forward_context - from granite_switch.config import GraniteSwitchConfig - from granite_switch.vllm.granite_switch_model import GraniteSwitchForCausalLM - from granite_switch.vllm.core.decoder import GraniteLoRAEmbeddedAttention - -from tests.shared.vllm_distributed import ensure_distributed as _ensure_distributed - -BLOCK_SIZE = 16 -MAX_TOKENS = 512 -SEED = 42 - - -# ── vLLM config + ctrl token ID ──────────────────────────────────── - - -def _make_config(): - return GraniteSwitchConfig( - vocab_size=300, - hidden_size=64, - intermediate_size=128, - num_hidden_layers=3, - num_attention_heads=2, - num_key_value_heads=2, - num_adapters=2, - adapter_token_ids=[250, 251], - adapter_names=["adapter_1", "adapter_2"], - hiding_groups={"all_controls": ["adapter_1", "adapter_2"]}, - hiding_policy={ - "base": ["all_controls"], - "adapter_1": ["all_controls"], - "adapter_2": ["all_controls"], - }, - adapter_third_party=["adapter_1", "adapter_2"], - max_lora_rank=4, - adapter_ranks=[4, 4], - switch_head_dim=32, - control_dims=32, - max_position_embeddings=MAX_TOKENS, - attention_multiplier=1.0, - embedding_multiplier=1.0, - residual_multiplier=1.0, - logits_scaling=1.0, - ) - - -_CTRL_TOKEN = 250 - - -# ── vLLM model runner ─────────────────────────────────────────────── - - -def _make_vllm_config(config): - from granite_switch.vllm import register - register() - tmpdir = tempfile.mkdtemp(prefix="gs_nan_test_") - cfg_dict = config.to_dict() - cfg_dict["architectures"] = ["GraniteSwitchForCausalLM"] - with open(os.path.join(tmpdir, "config.json"), "w") as f: - json.dump(cfg_dict, f) - model_config = ModelConfig( - model=tmpdir, - dtype="bfloat16", - max_model_len=config.max_position_embeddings, - enforce_eager=True, - ) - return VllmConfig(model_config=model_config) - - -def _init_weights(model): - torch.manual_seed(SEED) - with torch.no_grad(): - for name, param in model.named_parameters(): - if not param.is_floating_point(): - continue - if "lora_A" in name or "lora_B" in name: - continue - if "layernorm" in name or "norm" in name: - continue - param.data.normal_(0, 0.02) - - -def _setup_kv_caches(model, config, vllm_config, device): - kv_caches = [] - attention_map = {} - num_blocks = (MAX_TOKENS + BLOCK_SIZE - 1) // BLOCK_SIZE + 1 - - def _add(attn, name): - attn.kv_cache_torch_dtype = torch.bfloat16 - shape = attn.attn_backend.get_kv_cache_shape( - num_blocks, BLOCK_SIZE, attn.num_kv_heads, attn.head_size, - ) - kv = torch.zeros(shape, device=device, dtype=torch.bfloat16) - attn.kv_cache = kv - kv_caches.append(kv) - attention_map[name] = attn - - sw = model.model.switch - _add(sw.attn, "switch.layers.0") - num_decoder = config.num_hidden_layers - sw.num_cache_layers - for i in range(num_decoder): - _add(model.model.layers[i].self_attn.attn, f"model.layers.{i}.self_attn.attn") - - return kv_caches, attention_map - - -def _build_metadata(attention_map, seq_len, device): - slot_mapping = torch.arange(seq_len, dtype=torch.int64, device=device) - num_blocks = (seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE - block_table = torch.arange(num_blocks, dtype=torch.int32, device=device).unsqueeze(0) - query_start_loc = torch.tensor([0, seq_len], dtype=torch.int32, device=device) - seq_lens = torch.tensor([seq_len], dtype=torch.int32, device=device) - - backend_name = list(attention_map.values())[0].attn_backend.get_name() - if backend_name == "FLASH_ATTN": - from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata - - # scheduler_metadata is FA3-only — passing it on FA2 (Ampere/A100) - # forces FA3 kernel dispatch and crashes with "no kernel image - # is available". Only compute it when get_flash_attn_version() == 3 - # (Hopper SM90+). - scheduler_metadata = None - try: - from vllm.v1.attention.backends.fa_utils import ( - get_flash_attn_version, - get_scheduler_metadata, - ) - if get_flash_attn_version() == 3: - first_attn = list(attention_map.values())[0] - scheduler_metadata = get_scheduler_metadata( - batch_size=1, - max_seqlen_q=seq_len, - max_seqlen_k=seq_len, - num_heads_q=first_attn.num_heads, - num_heads_kv=first_attn.num_kv_heads, - headdim=first_attn.head_size, - cache_seqlens=seq_lens, - qkv_dtype=torch.bfloat16, - cu_seqlens_q=query_start_loc, - page_size=BLOCK_SIZE, - causal=True, - num_splits=0, - ) - except ImportError: - pass - - metadata = FlashAttentionMetadata( - num_actual_tokens=seq_len, - max_query_len=seq_len, - query_start_loc=query_start_loc, - max_seq_len=seq_len, - seq_lens=seq_lens, - block_table=block_table, - slot_mapping=slot_mapping, - use_cascade=False, - common_prefix_len=0, - cu_prefix_query_lens=None, - prefix_kv_lens=None, - suffix_kv_lens=None, - causal=True, - scheduler_metadata=scheduler_metadata, - ) - else: - pytest.skip(f"Backend {backend_name}: metadata not implemented for this test") - - return metadata, slot_mapping - - -def _run_vllm_forward_is_finite(ctrl_pos, seq_len, seed): - """Create a vLLM switch model and check for finite output at the given ctrl_pos.""" - _ensure_distributed() - device = torch.device("cuda") - config = _make_config() - - old_dtype = torch.get_default_dtype() - torch.set_default_dtype(torch.bfloat16) - try: - vllm_config = _make_vllm_config(config) - with set_current_vllm_config(vllm_config): - model = GraniteSwitchForCausalLM(vllm_config=vllm_config).to(device) - finally: - torch.set_default_dtype(old_dtype) - - _init_weights(model) - kv_caches, attention_map = _setup_kv_caches(model, config, vllm_config, device) - - # Build input: ctrl_token at ctrl_pos, random content elsewhere - torch.manual_seed(seed) - ctrl_id = _CTRL_TOKEN - content = torch.randint(0, 100, (seq_len,)).tolist() - ids_list = content[:ctrl_pos] + [ctrl_id] + content[ctrl_pos:] - total_len = len(ids_list) - - input_ids = torch.tensor(ids_list, dtype=torch.long, device=device) - positions = torch.arange(total_len, dtype=torch.long, device=device) - metadata, slot_mapping = _build_metadata(attention_map, total_len, device) - - layer_names = list(attention_map.keys()) - attn_metadata = {n: metadata for n in layer_names} - slot_mapping_dict = {n: slot_mapping for n in layer_names} - - forward_ctx = ForwardContext( - no_compile_layers=vllm_config.compilation_config.static_forward_context, - attn_metadata=attn_metadata, - slot_mapping=slot_mapping_dict, - ) - - old_direct = {n: attention_map[n].use_direct_call for n in layer_names} - for n in layer_names: - attention_map[n].use_direct_call = True - - try: - for kv in kv_caches: - kv.zero_() - with override_forward_context(forward_ctx): - hidden = model.forward(input_ids=input_ids, positions=positions) - logits = model.compute_logits(hidden) - finally: - for n in layer_names: - attention_map[n].use_direct_call = old_direct[n] - - sfc = vllm_config.compilation_config.static_forward_context - for n in layer_names: - sfc.pop(n, None) - - return bool(logits.isfinite().all()) - - -# ════════════════════════════════════════════════════════════════════ -# 1. vLLM-specific unit tests: _expand_with_control_dimensions -# Tensor layout: [num_tokens, num_heads * head_dim] (flat) -# ════════════════════════════════════════════════════════════════════ - - -def _vllm_stub(num_heads=2, num_kv_heads=2, head_dim=32, control_dims=1): - """Minimal namespace for vLLM _expand_with_control_dimensions.""" - return types.SimpleNamespace( - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - control_dims=control_dims, - expanded_head_dim=head_dim + control_dims, - ) - - -def _vllm_expand(stub, q, k, v, membership, suppression): - return GraniteLoRAEmbeddedAttention._expand_with_control_dimensions( - stub, q, k, v, membership, suppression - ) - - -def _vllm_qkv(stub, num_tokens): - """Create flat vLLM-layout Q/K/V tensors.""" - q = torch.randn(num_tokens, stub.num_heads * stub.head_dim) - k = torch.randn(num_tokens, stub.num_kv_heads * stub.head_dim) - v = torch.randn(num_tokens, stub.num_kv_heads * stub.head_dim) - return q, k, v - - -class TestExpandControlDimensions: - """Direct tests of _expand_with_control_dimensions (vLLM flat tensor layout). - - Input shape: [num_tokens, num_heads * head_dim] - Output shape: [num_tokens, num_heads * expanded_head_dim] - """ - - _HEAD_DIM = 32 - _CTRL_DIMS = 1 - - def test_control_token_q_hide_zero_at_position_zero(self): - """Core fix: control token at pos 0 must not activate Q-side hiding.""" - stub = _vllm_stub(control_dims=self._CTRL_DIMS) - membership = torch.ones(1, 1, dtype=torch.bool) - suppression = torch.ones(1, 1, dtype=torch.bool) - q, k, v = _vllm_qkv(stub, num_tokens=1) - - q_exp, _, _ = _vllm_expand(stub, q, k, v, membership, suppression) - - q_reshaped = q_exp.view(1, stub.num_heads, stub.expanded_head_dim) - q_ctrl = q_reshaped[0, :, self._HEAD_DIM:] - assert q_ctrl.eq(0).all(), f"Control token at pos 0: q_control must be 0, got {q_ctrl}" - - def test_adapter_generated_tokens_q_hide_one(self): - """Adapter-generated tokens (non-members) keep q_control=1.""" - stub = _vllm_stub(control_dims=self._CTRL_DIMS) - num_tokens = 5 - membership = torch.zeros(num_tokens, 1, dtype=torch.bool) - membership[0, 0] = True # control token at pos 0 - suppression = torch.ones(num_tokens, 1, dtype=torch.bool) - q, k, v = _vllm_qkv(stub, num_tokens) - - q_exp, _, _ = _vllm_expand(stub, q, k, v, membership, suppression) - q_reshaped = q_exp.view(num_tokens, stub.num_heads, stub.expanded_head_dim) - - assert q_reshaped[0, :, self._HEAD_DIM:].eq(0).all(), "Control token: q_control must be 0" - for pos in range(1, num_tokens): - assert q_reshaped[pos, :, self._HEAD_DIM:].eq(1).all(), ( - f"Adapter-generated token at pos {pos}: q_control must be 1" - ) - - def test_k_side_finfo_min_for_control_token(self): - """K-side branding is unaffected by the fix.""" - stub = _vllm_stub(control_dims=self._CTRL_DIMS) - membership = torch.ones(1, 1, dtype=torch.bool) - q, k, v = _vllm_qkv(stub, num_tokens=1) - - _, k_exp, _ = _vllm_expand(stub, q, k, v, membership, None) - - k_reshaped = k_exp.view(1, stub.num_kv_heads, stub.expanded_head_dim) - k_ctrl = k_reshaped[0, :, self._HEAD_DIM:] - expected_min = torch.finfo(k.dtype).min - torch.testing.assert_close(k_ctrl, torch.full_like(k_ctrl, expected_min)) - - def test_both_none_leaves_all_control_dims_zero(self): - """With both tensors None, all control dims remain zero.""" - stub = _vllm_stub(control_dims=2) - q, k, v = _vllm_qkv(stub, num_tokens=4) - q_exp, k_exp, v_exp = _vllm_expand(stub, q, k, v, None, None) - - exp_head = stub.expanded_head_dim - assert q_exp.view(4, stub.num_heads, exp_head)[:, :, self._HEAD_DIM:].eq(0).all() - assert k_exp.view(4, stub.num_kv_heads, exp_head)[:, :, self._HEAD_DIM:].eq(0).all() - assert v_exp.view(4, stub.num_kv_heads, exp_head)[:, :, self._HEAD_DIM:].eq(0).all() - - def test_original_dimensions_preserved(self): - """Original head dims are unchanged; only control dims appended.""" - stub = _vllm_stub(control_dims=2) - q, k, v = _vllm_qkv(stub, num_tokens=3) - q_exp, k_exp, v_exp = _vllm_expand(stub, q, k, v, None, None) - - exp_head = stub.expanded_head_dim - torch.testing.assert_close( - q_exp.view(3, stub.num_heads, exp_head)[:, :, :self._HEAD_DIM], - q.view(3, stub.num_heads, stub.head_dim), - ) - - -# ════════════════════════════════════════════════════════════════════ -# 2. Shared SDPA cases -# ════════════════════════════════════════════════════════════════════ - - -class TestSDPANaN(SDPANaNCases): - pass - - -# ════════════════════════════════════════════════════════════════════ -# 3. Shared model finiteness cases — vLLM backend -# ════════════════════════════════════════════════════════════════════ - - -class TestModelFiniteness(ModelFinitenessCases): - def _assert_no_nan(self, switch_type, ctrl_pos, seq_len, seed): - is_finite = _run_vllm_forward_is_finite(ctrl_pos, seq_len, seed) - assert is_finite, ( - f"[vLLM] ctrl_pos={ctrl_pos}: logits contain NaN/Inf" - ) - - -# ════════════════════════════════════════════════════════════════════ -# 4. Mutation test — proves TestModelFiniteness is sensitive to the fix -# ════════════════════════════════════════════════════════════════════ - - -def _buggy_expand( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - token_group_membership, - query_group_suppression, -) -> tuple: - """Pre-fix version: omits q_hide *= (1 - membership), causing NaN at ctrl_pos=0.""" - num_tokens = q.size(0) - device = q.device - dtype = q.dtype - - q = q.view(num_tokens, self.num_heads, self.head_dim) - k = k.view(num_tokens, self.num_kv_heads, self.head_dim) - v = v.view(num_tokens, self.num_kv_heads, self.head_dim) - - q_control = torch.zeros(num_tokens, self.num_heads, self.control_dims, device=device, dtype=dtype) - k_control = torch.zeros(num_tokens, self.num_kv_heads, self.control_dims, device=device, dtype=dtype) - v_control = torch.zeros(num_tokens, self.num_kv_heads, self.control_dims, device=device, dtype=dtype) - - if token_group_membership is not None: - num_groups = token_group_membership.shape[-1] - hiding_constant = torch.finfo(dtype).min - k_control[:, :, :num_groups] = ( - token_group_membership.unsqueeze(1) - .expand(-1, self.num_kv_heads, -1) - .to(dtype) * hiding_constant - ) - - if query_group_suppression is not None: - num_groups = query_group_suppression.shape[-1] - q_hide = query_group_suppression.to(dtype) - # BUG: missing `q_hide *= (1 - token_group_membership)` — control token - # at position 0 gets q_ctrl=1, causing softmax([-inf]) = NaN. - q_control[:, :, :num_groups] = q_hide.unsqueeze(1).expand(-1, self.num_heads, -1) - - q = torch.cat([q, q_control], dim=-1) - k = torch.cat([k, k_control], dim=-1) - v = torch.cat([v, v_control], dim=-1) - - q = q.view(num_tokens, self.num_heads * self.expanded_head_dim) - k = k.view(num_tokens, self.num_kv_heads * self.expanded_head_dim) - v = v.view(num_tokens, self.num_kv_heads * self.expanded_head_dim) - - return q, k, v - - -class TestFixSensitivity: - """Mutation test: revert the fix and confirm NaN is produced. - - Patches _expand_with_control_dimensions with the pre-fix (buggy) version. - If _run_vllm_forward_is_finite still returns True, the model-level tests - are not actually sensitive to the fix and must be reconsidered. - """ - - def test_buggy_expand_produces_nan_at_ctrl_pos_zero(self): - """Without the fix, ctrl_pos=0 must produce non-finite logits in vLLM.""" - from granite_switch.vllm.core.decoder import GraniteLoRAEmbeddedAttention - from unittest.mock import patch - - with patch.object( - GraniteLoRAEmbeddedAttention, - "_expand_with_control_dimensions", - _buggy_expand, - ): - is_finite = _run_vllm_forward_is_finite(ctrl_pos=0, seq_len=8, seed=99) - - assert not is_finite, ( - "[vLLM] Expected NaN with buggy expand at ctrl_pos=0, " - "but output was finite — test is not sensitive to the fix" - ) diff --git a/tests/vllm/test_kv_hiding_gap_equivalence.py b/tests/vllm/test_kv_hiding_gap_equivalence.py deleted file mode 100644 index 13e6f34..0000000 --- a/tests/vllm/test_kv_hiding_gap_equivalence.py +++ /dev/null @@ -1,39 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""KV hiding gap equivalence tests (subprocess wrapper). - -Runs _kv_hiding_gap_tests.py in a subprocess so the parent pytest process -never creates a CUDA context. -""" - -import importlib.util -import subprocess -import sys -from pathlib import Path - -import pytest - -_VLLM_AVAILABLE = importlib.util.find_spec("vllm") is not None - -pytestmark = pytest.mark.skipif( - not _VLLM_AVAILABLE, - reason="requires vLLM installed (GPU checked by inner tests)", -) - -_INNER = Path(__file__).parent / "_kv_hiding_gap_tests.py" -_TIMEOUT = 600 - - -def _run_inner_class(class_name): - cmd = [sys.executable, "-m", "pytest", str(_INNER), - "-v", "-s", "--tb=short", "-k", class_name] - result = subprocess.run(cmd, capture_output=True, text=True, timeout=_TIMEOUT) - if result.stdout: - print(result.stdout[-4000:]) - if result.stderr: - print("STDERR:", result.stderr[-2000:]) - assert result.returncode == 0, f"Inner tests failed (exit {result.returncode})" - - -class TestKVHidingGapEquivalence: - def test_suite(self): - _run_inner_class("TestKVHidingGapEquivalence") From 8b77f1b5f0667a7322becf4cd2a6a30b272a877d Mon Sep 17 00:00:00 2001 From: AlonMalach Date: Thu, 14 May 2026 15:21:57 +0300 Subject: [PATCH 09/18] Fix base-weight validator rejecting control_to_substitute_lut buffer (#8) The new switch buffer was failing compose-pipeline validation because buffer_keywords still listed the deleted legacy buffer names instead of the new one. Replace token_to_group_mask / adapter_hiding_matrix / all_hiding_group_token_ids with control_to_substitute_lut in arch.py and in the two test_granite4_mini parameter-allowlist assertions. --- .gitignore | 7 ++++++- src/granite_switch/composer/arch.py | 4 +--- tests/hf/test_granite4_mini.py | 2 +- tests/vllm/test_granite4_mini.py | 2 +- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index f4f70e5..a5dcceb 100644 --- a/.gitignore +++ b/.gitignore @@ -49,4 +49,9 @@ TEST_STATUS_REPORT.md htmlcov/ # pyenv version file (local dev preference) -.python-version \ No newline at end of file +.python-version + +# Local design/planning doc (keep on disk, do not version) +docs/KV_CACHE_OVERHEAD_REMOVAL.md +docs/KV_CACHE_OVERHEAD_REMOVAL.html +docs/KV_CACHE_OVERHEAD_REMOVAL*.html \ No newline at end of file diff --git a/src/granite_switch/composer/arch.py b/src/granite_switch/composer/arch.py index f41735e..22e70fb 100644 --- a/src/granite_switch/composer/arch.py +++ b/src/granite_switch/composer/arch.py @@ -119,9 +119,7 @@ class ArchDescriptor: default_factory=lambda: [ "adapter_token_ids", "adapter_scalings", - "token_to_group_mask", - "adapter_hiding_matrix", - "all_hiding_group_token_ids", + "control_to_substitute_lut", ] ) diff --git a/tests/hf/test_granite4_mini.py b/tests/hf/test_granite4_mini.py index c70412d..3de3884 100644 --- a/tests/hf/test_granite4_mini.py +++ b/tests/hf/test_granite4_mini.py @@ -178,7 +178,7 @@ def _make_zero_adapter_pair(cfg_dict): for name in unloaded: assert any(k in name for k in ( "lora_A", "lora_B", "switch", "adapter_token_ids", - "token_to_group_mask", "adapter_hiding_matrix", + "control_to_substitute_lut", )), f"Unexpected unloaded parameter: {name}" # Zero all LoRA weights defensively diff --git a/tests/vllm/test_granite4_mini.py b/tests/vllm/test_granite4_mini.py index 363edaf..6f34688 100644 --- a/tests/vllm/test_granite4_mini.py +++ b/tests/vllm/test_granite4_mini.py @@ -95,7 +95,7 @@ def test_weight_transfer(self, model_name): for name in unloaded: assert any(k in name for k in ( "lora_A", "lora_B", "switch", "adapter_token_ids", - "token_to_group_mask", "adapter_hiding_matrix", + "control_to_substitute_lut", )), f"Unexpected unloaded parameter: {name}" assert len(unloaded) > 0, "Expected LoRA/switch params to be unloaded" From 2543eaf1f2841b4295998ab1e6a0349c66946105 Mon Sep 17 00:00:00 2001 From: AlonMalach Date: Thu, 14 May 2026 15:31:30 +0300 Subject: [PATCH 10/18] Remove dead hiding-constant report from compose output (#8) The report described safety margins for the finfo.min K-side hiding constant. Hiding is gone, so the section is meaningless. Drop the module, the call site in compose_report.py, and the package re-exports. --- .../composer/reporting/__init__.py | 3 - .../composer/reporting/compose_report.py | 5 -- .../reporting/hiding_constant_report.py | 58 ------------------- 3 files changed, 66 deletions(-) delete mode 100644 src/granite_switch/composer/reporting/hiding_constant_report.py diff --git a/src/granite_switch/composer/reporting/__init__.py b/src/granite_switch/composer/reporting/__init__.py index 687c2f7..14fca07 100644 --- a/src/granite_switch/composer/reporting/__init__.py +++ b/src/granite_switch/composer/reporting/__init__.py @@ -4,7 +4,6 @@ from .population_table import generate_adapter_population_table, print_adapter_population_table from .compose_report import generate_compose_report from .adapter_analysis import print_source_adapter_analysis -from .hiding_constant_report import compute_hiding_constant_safety, print_hiding_constant_safety from .model_card import render_model_card, write_model_card, write_build_doc __all__ = [ @@ -12,8 +11,6 @@ 'print_adapter_population_table', 'generate_compose_report', 'print_source_adapter_analysis', - 'compute_hiding_constant_safety', - 'print_hiding_constant_safety', 'render_model_card', 'write_model_card', 'write_build_doc', diff --git a/src/granite_switch/composer/reporting/compose_report.py b/src/granite_switch/composer/reporting/compose_report.py index b392f12..7537e77 100644 --- a/src/granite_switch/composer/reporting/compose_report.py +++ b/src/granite_switch/composer/reporting/compose_report.py @@ -325,11 +325,6 @@ def _print_summary( if len(base_source_not_connected) > 10: print(f" ... and {len(base_source_not_connected) - 10} more") - # Hiding constant safety margin - if model is not None: - from .hiding_constant_report import print_hiding_constant_safety - print_hiding_constant_safety(model.dtype) - print(f"\nDetailed report saved to: {report_path}") print("="*80) diff --git a/src/granite_switch/composer/reporting/hiding_constant_report.py b/src/granite_switch/composer/reporting/hiding_constant_report.py deleted file mode 100644 index c54860a..0000000 --- a/src/granite_switch/composer/reporting/hiding_constant_report.py +++ /dev/null @@ -1,58 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Safety margin report for the K-side hiding constant. - -The hiding mechanism uses finfo(dtype).min as the K-side control dimension value -for tokens in hiding groups. This module computes and reports the safety margin: -how large a positive attention score would need to be before the hiding breaks -(i.e., exp(fmin + score) > 0). -""" - -import torch - - -def _find_exp_underflow_threshold(dtype: torch.dtype) -> float: - """Find the smallest (most negative) x where exp(x) > 0 for the given dtype. - - Searches from -500 upward in steps of 0.5. - """ - for x_int in range(-1000, 0): - x = x_int * 0.5 - x_t = torch.tensor(x, dtype=dtype) - if torch.exp(x_t).item() > 0.0: - return x - return 0.0 # fallback: exp(x) > 0 for all tested x - - -def compute_hiding_constant_safety(dtype: torch.dtype) -> dict: - """Compute safety margin data for the hiding constant at the given dtype. - - Returns a dict with: - fmin: the hiding constant value - exp_underflow_threshold: smallest x where exp(x) > 0 - safety_margin: positive value that must be added to fmin to break hiding - """ - fmin_val = torch.finfo(dtype).min - exp_threshold = _find_exp_underflow_threshold(dtype) - safety_margin = abs(fmin_val) + exp_threshold # exp_threshold is negative - - return { - "dtype": str(dtype), - "fmin": fmin_val, - "exp_underflow_threshold": exp_threshold, - "safety_margin": safety_margin, - } - - -def print_hiding_constant_safety(dtype: torch.dtype): - """Print the hiding constant safety margin report for the given dtype.""" - data = compute_hiding_constant_safety(dtype) - - print(f"\n{'='*80}") - print("CONTROL DIMENSION HIDING CONSTANT") - print(f"{'='*80}") - print(f" Model dtype: {data['dtype']}") - print(f" Hiding constant (finfo.min): {data['fmin']:.6e}") - print(f" exp(hiding_constant) underflows to zero: True") - print(f" exp() underflow threshold: {data['exp_underflow_threshold']}") - print(f" Safety margin: a positive attention score of {data['safety_margin']:.6e}") - print(f" would be needed to break hiding (make exp(fmin + score) > 0)") From aeda2d55f771f29835b71b17a0c35864c383df26 Mon Sep 17 00:00:00 2001 From: AlonMalach Date: Thu, 14 May 2026 15:57:02 +0300 Subject: [PATCH 11/18] =?UTF-8?q?Update=20tests=20for=20removed=20legacy?= =?UTF-8?q?=20hiding=20fields=20(#8)=20=E2=80=94=20partial?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace control_dims / hiding_groups / hiding_policy / adapter_third_party references with adapter_substitute_token_ids in test fixtures, and drop TestControlTokenKVInvisibility (tested the deleted hiding mechanism). This is a partial sweep — vLLM workers, hf/test_single_switch_e2e.py, and shared/granite4_equivalence.py still need follow-up edits. --- tests/composer/test_save_load_compose.py | 18 ++--- tests/integration/test_hf_to_vllm_weights.py | 9 +-- tests/vllm/_model_forward_tests.py | 73 +------------------- 3 files changed, 7 insertions(+), 93 deletions(-) diff --git a/tests/composer/test_save_load_compose.py b/tests/composer/test_save_load_compose.py index e1008c3..f7e579f 100644 --- a/tests/composer/test_save_load_compose.py +++ b/tests/composer/test_save_load_compose.py @@ -492,14 +492,14 @@ def test_pipeline_metadata_files_exist(self, phase1): ) def test_config_adapter_identity(self, phase1): - """num_adapters, token IDs, names, third_party survive save→load.""" + """num_adapters, token IDs, names, substitute IDs survive save→load.""" built = phase1["built_config"] loaded = phase1["loaded_config"] assert loaded.num_adapters == built.num_adapters assert loaded.adapter_token_ids == built.adapter_token_ids assert loaded.adapter_names == built.adapter_names - assert loaded.adapter_third_party == built.adapter_third_party + assert loaded.adapter_substitute_token_ids == built.adapter_substitute_token_ids def test_config_lora(self, phase1): """adapter_ranks, max_lora_rank, lora_target_modules survive save→load.""" @@ -511,22 +511,13 @@ def test_config_lora(self, phase1): assert loaded.lora_target_modules == built.lora_target_modules def test_config_switch(self, phase1): - """switch head_dim, control_dims, gain survive save→load.""" + """switch_head_dim and control_token_gain survive save→load.""" built = phase1["built_config"] loaded = phase1["loaded_config"] assert loaded.switch_head_dim == built.switch_head_dim - assert loaded.control_dims == built.control_dims assert loaded.control_token_gain == built.control_token_gain - def test_config_hiding(self, phase1): - """hiding_groups and hiding_policy survive save→load.""" - built = phase1["built_config"] - loaded = phase1["loaded_config"] - - assert loaded.hiding_groups == built.hiding_groups - assert loaded.hiding_policy == built.hiding_policy - def test_config_granite_scaling(self, phase1): """Granite-specific scaling parameters survive save→load.""" built = phase1["built_config"] @@ -848,8 +839,7 @@ def test_config_matches(self, phase2): assert c2.adapter_ranks == c1.adapter_ranks assert c2.max_lora_rank == c1.max_lora_rank assert c2.lora_target_modules == c1.lora_target_modules - assert c2.hiding_groups == c1.hiding_groups - assert c2.hiding_policy == c1.hiding_policy + assert c2.adapter_substitute_token_ids == c1.adapter_substitute_token_ids assert c2.logits_scaling == c1.logits_scaling assert c2.attention_multiplier == c1.attention_multiplier assert c2.vocab_size == c1.vocab_size diff --git a/tests/integration/test_hf_to_vllm_weights.py b/tests/integration/test_hf_to_vllm_weights.py index 1daeecd..711652b 100644 --- a/tests/integration/test_hf_to_vllm_weights.py +++ b/tests/integration/test_hf_to_vllm_weights.py @@ -373,17 +373,10 @@ def _config(self): num_adapters=2, adapter_token_ids=[250, 251], adapter_names=["adapter_1", "adapter_2"], - hiding_groups={"all_controls": ["adapter_1", "adapter_2"]}, - hiding_policy={ - "base": ["all_controls"], - "adapter_1": ["all_controls"], - "adapter_2": ["all_controls"], - }, - adapter_third_party=["adapter_1", "adapter_2"], + adapter_substitute_token_ids=[1, 1], max_lora_rank=4, adapter_ranks=[4, 4], switch_head_dim=32, - control_dims=32, max_position_embeddings=512, attention_multiplier=1.0, embedding_multiplier=1.0, diff --git a/tests/vllm/_model_forward_tests.py b/tests/vllm/_model_forward_tests.py index faf4cf4..4caeebb 100644 --- a/tests/vllm/_model_forward_tests.py +++ b/tests/vllm/_model_forward_tests.py @@ -63,40 +63,10 @@ def _tiny_vllm_config(): num_adapters=2, adapter_token_ids=[250, 251], adapter_names=["adapter_1", "adapter_2"], - hiding_groups={"all_controls": ["adapter_1", "adapter_2"]}, - hiding_policy={"base": ["all_controls"], "adapter_1": ["all_controls"], "adapter_2": ["all_controls"]}, - adapter_third_party=["adapter_1", "adapter_2"], + adapter_substitute_token_ids=[1, 1], max_lora_rank=4, adapter_ranks=[4, 4], switch_head_dim=32, - control_dims=32, - max_position_embeddings=512, - attention_multiplier=1.0, - embedding_multiplier=1.0, - residual_multiplier=1.0, - logits_scaling=1.0, - ) - - -def _tiny_vllm_mixed_tp_config(): - """SingleSwitch config where only adapter_1 is third-party.""" - return GraniteSwitchConfig( - vocab_size=300, - hidden_size=64, - intermediate_size=128, - num_hidden_layers=3, - num_attention_heads=2, - num_key_value_heads=2, - num_adapters=2, - adapter_token_ids=[250, 251], - adapter_names=["adapter_1", "adapter_2"], - hiding_groups={"all_controls": ["adapter_1", "adapter_2"]}, - hiding_policy={"base": ["all_controls"], "adapter_1": ["all_controls"], "adapter_2": ["all_controls"]}, - adapter_third_party=["adapter_1"], - max_lora_rank=4, - adapter_ranks=[4, 4], - switch_head_dim=32, - control_dims=32, max_position_embeddings=512, attention_multiplier=1.0, embedding_multiplier=1.0, @@ -449,46 +419,7 @@ def test_different_adapters_produce_different_logits(self): # ════════════════════════════════════════════════════════════════════ -# 5. Control token KV invisibility -# ════════════════════════════════════════════════════════════════════ - -class TestControlTokenKVInvisibility(_VLLMModelTestBase): - - def test_control_token_invisible_to_future_positions(self): - torch.manual_seed(SEED) - self.model.eval() - - seq = [10, 20, 250, 30, 40, 50, 60, 70] - - with torch.no_grad(): - logits_a = self._run_forward_and_logits(seq) - - with torch.no_grad(): - perturbation = torch.randn( - self.config.hidden_size, device=self.device, dtype=torch.bfloat16 - ) * 10.0 - self.model.model.embed_tokens.weight.data[250] += perturbation - - with torch.no_grad(): - logits_b = self._run_forward_and_logits(seq) - - torch.testing.assert_close( - logits_a[:2], logits_b[:2], - msg="Pre-control logits should be identical" - ) - - assert not torch.allclose(logits_a[2], logits_b[2]), \ - "Control token logits should differ after perturbation" - - torch.testing.assert_close( - logits_a[3:], logits_b[3:], - msg="Post-control logits should be identical " - "(control token KV masked by control_dims)" - ) - - -# ════════════════════════════════════════════════════════════════════ -# 6. KV visibility tests +# 5. KV visibility tests # ════════════════════════════════════════════════════════════════════ class TestKVVisibility(_VLLMModelTestBase): From f342f1fde2631c6fd9ece0bbdf266f2d970858a9 Mon Sep 17 00:00:00 2001 From: AlonMalach Date: Thu, 14 May 2026 17:53:23 +0300 Subject: [PATCH 12/18] Strip remaining legacy hiding-field references from tests/docs (#8) - tests/hf/test_single_switch_e2e.py: drop CONTROL_DIMS_MODES axis; one parametrization on attention_multiplier only. Fixture returns a 3-tuple. - tests/vllm/_generation_equivalence_worker.py and _tp_integration_worker.py: remove control_dims/hiding_groups/hiding_policy/adapter_third_party from composer calls; pass adapter_substitute_token_ids instead. - tests/vllm/_single_switch_worker.py: mock_config uses projection_head_dim. - tests/vllm/test_generation_equivalence.py: docstring updated. - tests/shared/granite4_equivalence.py: rationale comments updated for token-exchange (no behavior change). - src/granite_switch/composer/compose_utils.py: docstring/comment cleanup. --- src/granite_switch/composer/compose_utils.py | 7 +- tests/hf/test_single_switch_e2e.py | 64 +++++---------- tests/shared/granite4_equivalence.py | 82 +++++++++----------- tests/vllm/_generation_equivalence_worker.py | 11 +-- tests/vllm/_single_switch_worker.py | 2 +- tests/vllm/_tp_integration_worker.py | 5 +- tests/vllm/test_generation_equivalence.py | 7 +- 7 files changed, 70 insertions(+), 108 deletions(-) diff --git a/src/granite_switch/composer/compose_utils.py b/src/granite_switch/composer/compose_utils.py index 66704b2..dabc47f 100644 --- a/src/granite_switch/composer/compose_utils.py +++ b/src/granite_switch/composer/compose_utils.py @@ -50,9 +50,8 @@ def from_base_and_adapters( adapter_token_ids: Token IDs for adapter control. Required when ``adapter_paths`` is non-empty. adapter_substitute_token_ids: Token IDs whose embeddings replace - control-token embeddings in token-exchange mode. One per adapter. - Pass ``None`` to run the legacy KV-hiding path (requires - ``control_dims > 0`` in ``**kwargs``). + control-token embeddings at the switch. Required when + ``adapter_paths`` is non-empty; one per adapter. adapter_names: Display names for each adapter (external + built-in). When ``None``, derived from the directory structure. built_in_adapter_names: Names for built-in (empty LoRA) adapter slots. @@ -160,7 +159,7 @@ def from_base_and_adapters( } ) - # Merge caller-provided overrides (switch_head_dim, control_dims, etc.) + # Merge caller-provided overrides (switch_head_dim, etc.) config_kwargs.update(kwargs) switch_config = GraniteSwitchConfig(**config_kwargs) diff --git a/tests/hf/test_single_switch_e2e.py b/tests/hf/test_single_switch_e2e.py index e02856b..89401d4 100644 --- a/tests/hf/test_single_switch_e2e.py +++ b/tests/hf/test_single_switch_e2e.py @@ -11,8 +11,8 @@ GraniteSwitchConfig → create_switch() → SingleSwitch.__init__ → model forward → _last_adapter_indices -Parametrized over both PRODUCTION_ATTENTION_MULTIPLIERS and both control_dims -modes (native and hiding) to catch config-flow regressions in either code path. +Parametrized over PRODUCTION_ATTENTION_MULTIPLIERS to catch config-flow +regressions across the realistic multiplier values. CPU-only. Does not exercise vLLM gain compensation — HF SingleSwitch hardcodes scaling=1.0 regardless of config. Compensation is tested in the Tier 2 composer @@ -35,11 +35,6 @@ ) from tests.shared.single_switch_cases import ADAPTER_TOKEN_IDS_LIST, NUM_ADAPTERS -# control_dims=0 → native mode (no KV hiding). control_dims=32 → hiding mode. -# Both take different code paths through SingleSwitch.__init__ (expanded_head_dim) -# and through GraniteSwitchModel.forward (hiding-group mask construction). -CONTROL_DIMS_MODES = [0, 32] - # TEXT_TOKEN matches tests/shared/single_switch_cases.py convention. Any # non-adapter token ID works — 50 is outside ADAPTER_TOKEN_IDS_LIST (1000+). TEXT_TOKEN = 50 @@ -62,35 +57,20 @@ ) -def _build_e2e_overrides(base_cfg, *, num_adapters=NUM_ADAPTERS, control_dims=32): - """Build config overrides for a production-ish E2E test model. - - Three overrides beyond the `single_overrides()` defaults: - - vocab_size: large enough to hold every adapter token ID (derived). - - max_position_embeddings: supports the long-context test matrix (derived). - - control_dims parametrized: native (0) vs hiding (32+). - """ +def _build_e2e_overrides(base_cfg, *, num_adapters=NUM_ADAPTERS): + """Build config overrides for a production-ish E2E test model.""" adapter_names = [f"adapter_{i}" for i in range(num_adapters)] - overrides = { + return { "vocab_size": _E2E_VOCAB_SIZE, "max_position_embeddings": _E2E_MAX_POSITION_EMBEDDINGS, "num_adapters": num_adapters, "adapter_ranks": [8] * num_adapters, "adapter_token_ids": ADAPTER_TOKEN_IDS_LIST[:num_adapters], "adapter_names": adapter_names, - "control_dims": control_dims, + "adapter_substitute_token_ids": [1] * num_adapters, "num_hidden_layers": len(base_cfg["layer_types"]) + 1, "layer_types": ["attention"] + base_cfg["layer_types"], } - if control_dims > 0: - # Hiding mode needs hiding_groups + hiding_policy + adapter_third_party. - overrides["hiding_groups"] = {"all_controls": adapter_names} - overrides["hiding_policy"] = { - n: ["all_controls"] for n in ["base"] + adapter_names - } - overrides["adapter_third_party"] = adapter_names - # control_dims == 0 → native mode → no hiding_groups/policy. - return overrides def _make_e2e_model(base_cfg, overrides): @@ -107,29 +87,28 @@ def _make_e2e_model(base_cfg, overrides): # Module scope would save ~19s across the long-context matrix but would require # auditing that no test mutates model state — not worth it. @pytest.fixture( - params=[(m, cd) for m in PRODUCTION_ATTENTION_MULTIPLIERS for cd in CONTROL_DIMS_MODES], - ids=lambda p: f"mult={p[0]}-cd={p[1]}", + params=PRODUCTION_ATTENTION_MULTIPLIERS, + ids=lambda m: f"mult={m}", ) def e2e_model(request): - """GraniteSwitchForCausalLM parametrized over (attention_multiplier, control_dims).""" - multiplier, control_dims = request.param + """GraniteSwitchForCausalLM parametrized over attention_multiplier.""" + multiplier = request.param base_cfg = {**DENSE_CFG, "attention_multiplier": multiplier} - overrides = _build_e2e_overrides(base_cfg, control_dims=control_dims) + overrides = _build_e2e_overrides(base_cfg) model, config = _make_e2e_model(base_cfg, overrides) - return model, config, multiplier, control_dims + return model, config, multiplier @pytest.fixture def e2e_model_32adapter(): """Single-variant fixture for the 32-adapter stress test. - The adapter-ID rounding math is independent of (multiplier, control_dims), - so we don't parametrize this fixture — TestE2EBasicAdapterActivation already - covers the cross-product. Chosen variant: hiding mode (control_dims=32) with - the most common production multiplier (0.0078125, granite-4.0-h-1b/tiny/small/4.1-8b/30b). + The adapter-ID rounding math is independent of multiplier, so we don't + parametrize. Chosen variant: most common production multiplier + (0.0078125, granite-4.0-h-1b/tiny/small/4.1-8b/30b). """ base_cfg = {**DENSE_CFG, "attention_multiplier": 0.0078125} - overrides = _build_e2e_overrides(base_cfg, control_dims=32) + overrides = _build_e2e_overrides(base_cfg) model, config = _make_e2e_model(base_cfg, overrides) return model, config @@ -160,10 +139,9 @@ def test_pre_control_is_zero_post_control_matches_adapter(self, e2e_model): from position 2 onward; positions before it remain at 0 (base). Proves the full chain config → create_switch → forward → - _last_adapter_indices works on a production-ish multiplier/control_dims - combination. + _last_adapter_indices works on a production-ish multiplier. """ - model, config, mult, cd = e2e_model + model, config, mult = e2e_model ctrl_token = config.adapter_token_ids[0] # adapter_0 → expected index 1 input_ids = torch.tensor([[10, 20, ctrl_token, 30, 40, 50, 60, 70]]) with torch.no_grad(): @@ -227,7 +205,7 @@ def test_long_context_e2e(self, e2e_model, seq_len, adapter_idx, control_positio Default CI runs seq_len ∈ {10K, 32K}; `-m slow` adds 65K and 131K. """ - model, config, mult, cd = e2e_model + model, config, mult = e2e_model ctrl_token = config.adapter_token_ids[adapter_idx] expected_id = adapter_idx + 1 ctrl_pos = _control_position(seq_len, control_position) @@ -244,10 +222,10 @@ def test_long_context_e2e(self, e2e_model, seq_len, adapter_idx, control_positio assert (ai[:ctrl_pos] == 0).all(), ( f"pre-control slice should be all 0; failed at seq_len={seq_len}, " f"adapter_idx={adapter_idx}, position={control_position}, " - f"mult={mult}, cd={cd}" + f"mult={mult}" ) assert (ai[ctrl_pos:] == expected_id).all(), ( f"post-control slice should be all {expected_id}; failed at seq_len={seq_len}, " f"adapter_idx={adapter_idx}, position={control_position}, " - f"mult={mult}, cd={cd}" + f"mult={mult}" ) diff --git a/tests/shared/granite4_equivalence.py b/tests/shared/granite4_equivalence.py index 68f58ba..3237c52 100644 --- a/tests/shared/granite4_equivalence.py +++ b/tests/shared/granite4_equivalence.py @@ -122,31 +122,29 @@ def transfer_weights_strict(upstream_sd, switch_sd): # zero LoRA weights the delta is zero for the LoRA path. # # All control tokens (adapter_token_ids) are KV-hidden by default. -# Test sequences use adapter tokens; hidden positions are excluded +# Test sequences use adapter tokens; the control position is excluded # from comparison via get_visible_mask(). # -# KV hiding: control tokens get K=finfo.min masking on control dims, -# zeroing their attention contribution at hidden positions. +# Token-exchange: at the switch, each control token's id is rewritten to +# its configured substitute id before the decoder embeds. Upstream sees +# the original control id; switch sees the substitute. The two embeddings +# differ at the control position only. # -# Error sources (with control_dims=32, exact K=-inf masking): +# Error sources: # -# 1. Hidden token attention contribution removal (primary): -# The upstream model attends to control tokens normally. The switch -# model masks them exactly (K=-inf -> zero attention weight). The diff -# is the value of the removed attention contribution -- fundamental -# and unavoidable. Error scales with hidden_tokens / seq_len. +# 1. Embedding divergence at control positions (primary): +# Upstream embeds the original control id; switch embeds the substitute. +# The control position itself is excluded from comparison. Visible +# positions attend to the control position too — the attention +# contribution from the (substitute vs original) embedding propagates. +# Error scales with control_tokens / seq_len. # -# 2. Expanded tensor FP rounding (secondary): -# control_dims adds extra dimensions to Q/K/V, changing the dot -# product accumulation in the attention kernel (D+32 vs D elements). -# This introduces small FP differences even for real token positions. -# -# 3. Mamba conv1d zero-gap (additional, hybrid only): -# Input zeroing writes zeros into conv1d's sliding window, perturbing -# K-1 subsequent real tokens per hidden token (issue #5). +# 2. Mamba conv1d effects (hybrid only): +# Conv1d's sliding window over substituted vs original tokens at the +# control position perturbs K-1 subsequent real tokens. # # Token allocation (within vocab_size=256): -# 101+ = adapter_token_ids (KV-hidden, activate switch) +# 101+ = adapter_token_ids (rewritten to substitute by switch) # Random fill: [0, 100) -- guaranteed no collisions with control tokens # Control token IDs -- low-vocab, valid embeddings in the base model @@ -165,9 +163,9 @@ def augment_cfg_with_adapters(cfg_dict, num_adapters=2, rank=8): - num_hidden_layers += 1 (1 cache slot for SingleSwitch) - layer_types prepended with "attention" (switch layer type) - LoRA adapter config fields - - adapter_token_ids (KV-hidden) + - adapter_token_ids (rewritten to substitute ids by the switch) + - adapter_substitute_token_ids (token-exchange substitutes) - adapter_names for name-to-index mapping - - control_dims=32 (default: exact K=-inf masking, no softmax dilution) """ cfg = dict(cfg_dict) @@ -230,13 +228,14 @@ def zero_lora_weights(model): def get_visible_mask(input_ids): - """Return boolean mask of non-hidden (visible) positions. + """Return boolean mask of non-control (visible) positions. - Positions in a hiding group get K=finfo.min masking on control dims, - making their logits intentionally different from upstream. This mask - identifies positions that should be compared in equivalence tests. + Control positions hold the substitute embedding in the switch model + versus the original control-token embedding upstream — their logits + are intentionally different. This mask identifies positions that + should be compared in equivalence tests. - All adapter tokens (>= _ADAPTER_TOKEN_BASE) are KV-hidden. + All adapter tokens (>= _ADAPTER_TOKEN_BASE) are control positions. Fill tokens from [0, 100) are visible. """ is_adapter = input_ids >= _ADAPTER_TOKEN_BASE @@ -249,38 +248,33 @@ def get_visible_mask(input_ids): def get_tolerances(layer_types, long_sequence=False, has_kv_hidden=False): """Return (atol, rtol) for a given architecture. - Error sources (systematic analysis): - - 1. **No hiding, no adapters**: GraniteSwitch with num_adapters=0 is - bit-exact vs upstream Granite (all configs). Fused QKV matmul is - bit-identical to separate Q/K/V matmuls in float32. + Error sources: - 2. **Hidden token attention contribution removal**: When control tokens - are hidden, the switch model masks them exactly (K=-inf, zero attention - weight via control_dims). Visible tokens lose the attention contribution - that the upstream model gets from those positions. Fundamental and - unavoidable — error scales with hidden_tokens / seq_len. + 1. **No adapters**: GraniteSwitch with num_adapters=0 is bit-exact vs + upstream Granite. Fused QKV matmul is bit-identical to separate + Q/K/V matmuls in float32. - 3. **Expanded tensor FP rounding**: control_dims adds extra dimensions - to Q/K/V, changing the attention kernel's dot product accumulation - (D+32 vs D elements). Small FP rounding differences at real positions. + 2. **Token-exchange embedding divergence**: With adapters and a control + token in the input, the switch embeds the substitute id at that + position while upstream embeds the original control id. Visible + positions attending to the control position pick up that delta. Args: layer_types: list of "attention" strings long_sequence: unused (kept for API compatibility) - has_kv_hidden: True when control token hiding is active + has_kv_hidden: True when adapters are active and control tokens + are present (kept name for API compatibility — the parameter + now means "control tokens get substituted"). Returns: (atol, rtol) tuple, or None if bit-exact match expected. """ if not has_kv_hidden: - # No hiding: bit-exact (fused QKV is numerically identical, - # control_dims expansion adds exactly 0 to dot products). + # Pure base-model path: bit-exact (fused QKV numerically identical). return None else: - # Attention-only with hiding (control_dims=32): hidden token - # attention contribution removed. - # Worst observed: ~5.0e-2 (multi 1b, seed-dependent). + # Substitute-embedding propagates through attention to visible + # positions. Worst observed: ~5.0e-2 (multi 1b, seed-dependent). return (6e-2, 6e-2) diff --git a/tests/vllm/_generation_equivalence_worker.py b/tests/vllm/_generation_equivalence_worker.py index 0609b8f..17c4e40 100644 --- a/tests/vllm/_generation_equivalence_worker.py +++ b/tests/vllm/_generation_equivalence_worker.py @@ -9,8 +9,8 @@ python worker.py compare --work-dir --label **build**: Loads config for dtype/vocab, generates a deterministic 64-token prompt, -builds a GraniteSwitch model with 1 built-in adapter (zero LoRA weights) and -control_dims=32. Saves the switch model and inputs to ``/``. +builds a GraniteSwitch model with 1 built-in adapter (zero LoRA weights). +Saves the switch model and inputs to ``/``. **run**: Loads inputs from ``/inputs.json``, loads model in vLLM, runs greedy autoregressive generation (temperature=0, max_tokens=32), saves generated @@ -83,17 +83,14 @@ def cmd_build(args): print(f" saved inputs to {inputs_path}") # Build switch model with 1 built-in adapter - print(f"\nBuilding GraniteSwitch (1 built-in adapter, control_dims=32)...") + print(f"\nBuilding GraniteSwitch (1 built-in adapter)...") skin_dir = os.path.join(work_dir, "switch") model = GraniteSwitchComposer.from_base_and_adapters( model_name, built_in_adapter_names=["test"], adapter_names=["test"], adapter_token_ids=[adapter_token_id], - control_dims=32, - hiding_groups={"all_controls": ["test"]}, - hiding_policy={"base": ["all_controls"], "test": ["all_controls"]}, - adapter_third_party=["test"], + adapter_substitute_token_ids=[1], torch_dtype=dtype, ) diff --git a/tests/vllm/_single_switch_worker.py b/tests/vllm/_single_switch_worker.py index d876dc4..8f879de 100644 --- a/tests/vllm/_single_switch_worker.py +++ b/tests/vllm/_single_switch_worker.py @@ -54,7 +54,7 @@ def _setup(): mock_config = SimpleNamespace( num_attention_heads=4, num_key_value_heads=2, - expanded_head_dim=64, + projection_head_dim=64, attention_multiplier=0.125, ) diff --git a/tests/vllm/_tp_integration_worker.py b/tests/vllm/_tp_integration_worker.py index 6ebf2ca..1729c01 100644 --- a/tests/vllm/_tp_integration_worker.py +++ b/tests/vllm/_tp_integration_worker.py @@ -47,12 +47,9 @@ def cmd_build(args): built_in_adapter_names=["test"], adapter_names=["test"], adapter_token_ids=[adapter_token_id], + adapter_substitute_token_ids=[1], muted_adapter_token_ids=[muted_token_id], - control_dims=32, switch_type="single", - hiding_groups={"all_controls": ["test"]}, - hiding_policy={"base": ["all_controls"], "test": ["all_controls"]}, - adapter_third_party=["test"], ) model.save_pretrained(output_dir) diff --git a/tests/vllm/test_generation_equivalence.py b/tests/vllm/test_generation_equivalence.py index 8f64e1f..d967107 100644 --- a/tests/vllm/test_generation_equivalence.py +++ b/tests/vllm/test_generation_equivalence.py @@ -2,13 +2,10 @@ """Verify greedy generation equivalence: upstream model vs zero-adapter switch model. Tests that autoregressive generation produces identical token sequences when a -GraniteSwitch model has a single built-in adapter with zero LoRA weights and -control_dims=32 (KV hiding infrastructure active, standard third-party mode). +GraniteSwitch model has a single built-in adapter with zero LoRA weights. No control tokens appear in the prompt, so: -- Switch layer → adapter_indices=0 everywhere -- hidden_count=0 → no RoPE gap correction -- K control dims = 0 for all tokens → QK dot product unchanged +- Switch layer → adapter_indices=0 everywhere, no token rewrite - LoRA delta = 0 → decoder layers produce identical output Each model runs in its own set of subprocesses so CUDA context is fully torn From 41d656b825aa5e95eb752c9b1da94702c64ad95c Mon Sep 17 00:00:00 2001 From: AlonMalach Date: Thu, 14 May 2026 18:59:04 +0300 Subject: [PATCH 13/18] Drop tensor.any() gate from switch token-exchange rewrite (#8) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The vLLM decoder is wrapped in @support_torch_compile; Dynamo cannot trace data-dependent branching like ``if is_control.any()``. The gate broke engine init on GPU runs. Replace it with an unconditional torch.where in both backends — keeps HF and vLLM symmetric, costs one indexed gather + one elementwise select per forward, and makes the switch compile-safe. --- src/granite_switch/hf/switch/single.py | 12 +++++------- src/granite_switch/vllm/switch/single.py | 13 +++++++------ 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/granite_switch/hf/switch/single.py b/src/granite_switch/hf/switch/single.py index ea2508b..6891fbd 100644 --- a/src/granite_switch/hf/switch/single.py +++ b/src/granite_switch/hf/switch/single.py @@ -234,16 +234,14 @@ def forward( # so the decoder sees a clean, unified input_ids and never has to # know about substitutes. Skipped only when the LUT was not built # (no substitute ids configured — e.g. a non-token-exchange test - # fixture). + # fixture). Kept symmetric with the vLLM switch, which forbids the + # `tensor.any()` short-circuit under @support_torch_compile. if self.control_to_substitute_lut is not None: sub_id_per_pos = self.control_to_substitute_lut[input_ids] is_control = sub_id_per_pos >= 0 - if is_control.any(): - modified_input_ids = torch.where( - is_control, sub_id_per_pos, input_ids - ) - else: - modified_input_ids = input_ids + modified_input_ids = torch.where( + is_control, sub_id_per_pos, input_ids + ) else: modified_input_ids = input_ids diff --git a/src/granite_switch/vllm/switch/single.py b/src/granite_switch/vllm/switch/single.py index a250270..6a95c0d 100644 --- a/src/granite_switch/vllm/switch/single.py +++ b/src/granite_switch/vllm/switch/single.py @@ -198,15 +198,16 @@ def forward( # Token-exchange rewrite: see the HF switch for the rationale. # Skipped only when no LUT was built (no substitute ids configured). + # No data-dependent gate here — the surrounding decoder is wrapped in + # @support_torch_compile, which forbids `tensor.any()` branching. + # `torch.where` runs every step; the cost is one indexed gather and + # one elementwise select on the flat input. if self.control_to_substitute_lut is not None: sub_id_per_pos = self.control_to_substitute_lut[input_ids] is_control = sub_id_per_pos >= 0 - if is_control.any(): - modified_input_ids = torch.where( - is_control, sub_id_per_pos, input_ids - ) - else: - modified_input_ids = input_ids + modified_input_ids = torch.where( + is_control, sub_id_per_pos, input_ids + ) else: modified_input_ids = input_ids From c40707c32b1c022bf156492481b1c7f15b2a22c1 Mon Sep 17 00:00:00 2001 From: AlonMalach Date: Thu, 14 May 2026 19:42:15 +0300 Subject: [PATCH 14/18] Fix vLLM test runners after switch tuple-return + dead-class purge (#8) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three fixes uncovered by GPU run: 1. tests/vllm/_single_switch_worker.py: switch.forward now returns (adapter_indices, modified_input_ids); unpack and return only the indices. Worker was calling .cpu() on a tuple → every parametrized test in tests/vllm/test_single_switch.py failed at the same point. 2. tests/vllm/test_model_forward.py: drop the TestControlTokenKVInvisibility class stub. The inner class was deleted with the legacy hiding tests in 0ddaf0e, but the parametrized runner still referenced it. 3. tests/vllm/test_position_zero_nan.py: deleted. The inner _position_zero_nan_tests.py was removed (only existed for the legacy hiding path); the runner became orphan and pytest reported "file or directory not found" on every parametrized variant. The flash_api.cpp:697 "no kernel image" failures in test_model_forward are pre-existing GPU/FlashAttention environment issues, not branch bugs. --- tests/vllm/_single_switch_worker.py | 4 +-- tests/vllm/test_model_forward.py | 5 --- tests/vllm/test_position_zero_nan.py | 54 ---------------------------- 3 files changed, 2 insertions(+), 61 deletions(-) delete mode 100644 tests/vllm/test_position_zero_nan.py diff --git a/tests/vllm/_single_switch_worker.py b/tests/vllm/_single_switch_worker.py index 8f879de..91439fe 100644 --- a/tests/vllm/_single_switch_worker.py +++ b/tests/vllm/_single_switch_worker.py @@ -213,7 +213,7 @@ def _run(harness, seq, num_adapters, control_token_gain): try: with override_forward_context(forward_ctx): - result = switch.forward( + adapter_indices, _modified_input_ids = switch.forward( input_ids=input_ids, adapter_token_ids=adapter_token_ids, ) @@ -223,7 +223,7 @@ def _run(harness, seq, num_adapters, control_token_gain): switch.effective_gain = orig_effective_gain switch.num_adapters = orig_num_adapters - return result.cpu().tolist() + return adapter_indices.cpu().tolist() def _query_geometry(harness): diff --git a/tests/vllm/test_model_forward.py b/tests/vllm/test_model_forward.py index b70daca..17f98be 100644 --- a/tests/vllm/test_model_forward.py +++ b/tests/vllm/test_model_forward.py @@ -50,11 +50,6 @@ def test_suite(self): _run_inner_class("TestAdapterIndicesWiring") -class TestControlTokenKVInvisibility: - def test_suite(self): - _run_inner_class("TestControlTokenKVInvisibility") - - class TestKVVisibility: def test_suite(self): _run_inner_class("TestKVVisibility") diff --git a/tests/vllm/test_position_zero_nan.py b/tests/vllm/test_position_zero_nan.py deleted file mode 100644 index 0bb0644..0000000 --- a/tests/vllm/test_position_zero_nan.py +++ /dev/null @@ -1,54 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""NaN regression tests — control token at sequence position 0 (vLLM backend). - -Runs _position_zero_nan_tests.py in a subprocess so the parent pytest process -never creates a CUDA context. -""" - -import importlib.util -import subprocess -import sys -from pathlib import Path - -import pytest - -_VLLM_AVAILABLE = importlib.util.find_spec("vllm") is not None - -pytestmark = pytest.mark.skipif( - not _VLLM_AVAILABLE, - reason="requires vLLM installed (GPU checked by inner tests)", -) - -_INNER = Path(__file__).parent / "_position_zero_nan_tests.py" -_TIMEOUT = 600 - - -def _run_inner_class(class_name): - cmd = [sys.executable, "-m", "pytest", str(_INNER), - "-v", "-s", "--tb=short", "-k", class_name] - result = subprocess.run(cmd, capture_output=True, text=True, timeout=_TIMEOUT) - if result.stdout: - print(result.stdout[-4000:]) - if result.stderr: - print("STDERR:", result.stderr[-2000:]) - assert result.returncode == 0, f"Inner tests failed (exit {result.returncode})" - - -class TestExpandControlDimensions: - def test_suite(self): - _run_inner_class("TestExpandControlDimensions") - - -class TestSDPANaN: - def test_suite(self): - _run_inner_class("TestSDPANaN") - - -class TestModelFiniteness: - def test_suite(self): - _run_inner_class("TestModelFiniteness") - - -class TestFixSensitivity: - def test_suite(self): - _run_inner_class("TestFixSensitivity") From 3e4b4672c2d74709c6516ccd1456abf76c250e43 Mon Sep 17 00:00:00 2001 From: AlonMalach Date: Sat, 16 May 2026 22:24:23 +0300 Subject: [PATCH 15/18] Document the LoRA substitute probe's data-independence assumption (#8) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per review feedback: tighten the probe's docstring to state the assumption it relies on — that the chat template emits a constant input_ids[0] regardless of message content, system-prompt presence, or generation-prompt flag — and call out that this is verified empirically for Granite 4.x (every realistic render shape produces <|start_of_role|>). Note what would need to change if a future base model's template breaks the assumption. No behavior change. --- .../composer/compose_granite_switch.py | 41 +++++++++++++------ 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/src/granite_switch/composer/compose_granite_switch.py b/src/granite_switch/composer/compose_granite_switch.py index 0f90b2a..0ea871f 100755 --- a/src/granite_switch/composer/compose_granite_switch.py +++ b/src/granite_switch/composer/compose_granite_switch.py @@ -78,19 +78,34 @@ def _load_tokenizer(model_name_or_path): def _probe_lora_substitute_token_id(tokenizer) -> int: - """Ask the tokenizer which token naturally appears at the start of a - rendered no-adapter chat. - - The LoRA prefix insertion prepends the adapter control token at the very - beginning of the rendered output, so whatever the template emits first - for a normal user turn is exactly what sits at position 1 after the - control token — and therefore the right substitute whose embedding - should land at the swap site. - - By deriving this from the tokenizer's own chat template at compose - time, we avoid hard-coding a Granite-4.x-specific token string - (<|start_of_role|>). Other base models with different chat templates - get the correct substitute for their template by construction. + """Ask the tokenizer which token naturally appears at sequence position 0 + of a rendered no-adapter chat. + + The LoRA prefix insertion places the adapter control token at sequence + position 0 of the rendered output. Whatever token would otherwise have + occupied position 0 (in a no-adapter render) is the right substitute + whose embedding should land at the swap site so the post-swap sequence + is indistinguishable from a no-adapter render. + + Assumption (Granite 4.x): the chat template emits a constant + ``input_ids[0]`` regardless of message content, system prompt presence, + or generation-prompt flag. Empirically verified — every realistic render + of the Granite 4.1 template yields ``<|start_of_role|>`` (id 100264) at + position 0. The probe renders a single minimal chat to read that + constant out of the template. + + A future model whose chat template branches on inputs at position 0 + (e.g. emits BOS only when no system message is present) would break + this assumption: the probe would still return *some* valid id, but it + might not match position 0 in another render mode at runtime, leaving + the LoRA control token swapped to an embedding the model doesn't + expect at that position. ``tests/composer/test_lora_substitute_probe.py`` + pins the Granite 4.x behavior; if you port to another base model with + a more dynamic template, extend the probe to render multiple shapes + and verify they all agree. + + By deriving the substitute from the tokenizer's own chat template at + compose time we avoid hard-coding a Granite-specific token string. Raises ``ValueError`` if the template is missing, fails to render, or emits an unknown token. From 9bb2a0c94f7c7a838228b8c80496e7c4bd799a5a Mon Sep 17 00:00:00 2001 From: AlonMalach Date: Mon, 18 May 2026 17:09:31 +0300 Subject: [PATCH 16/18] Remove residual position-correction references (#8) The position-correction code path was removed when token-exchange became the default. Drop the dead test class and the two stale parenthetical comments that still mentioned it. No behavioural change. --- src/granite_switch/hf/modeling_granite_switch.py | 10 +++------- tests/hf/test_token_exchange.py | 16 ---------------- 2 files changed, 3 insertions(+), 23 deletions(-) diff --git a/src/granite_switch/hf/modeling_granite_switch.py b/src/granite_switch/hf/modeling_granite_switch.py index dd23110..a6275d2 100644 --- a/src/granite_switch/hf/modeling_granite_switch.py +++ b/src/granite_switch/hf/modeling_granite_switch.py @@ -292,10 +292,9 @@ def forward( position_ids=position_ids, ) - # Compute adapter_indices using switch (BEFORE RoPE for position correction). - # The switch also returns modified_input_ids: input_ids with each - # control token rewritten to its substitute id, so the decoder can - # embed once without any token-exchange awareness. + # The switch returns adapter_indices alongside modified_input_ids: + # input_ids with each control token rewritten to its substitute id, + # so the decoder can embed once without any token-exchange awareness. modified_input_ids = input_ids if self.switch is not None: adapter_indices, modified_input_ids = self.switch( @@ -320,9 +319,6 @@ def forward( # Expose adapter_indices for tests and debugging. self._last_adapter_indices = adapter_indices - # Position embeddings (only if RoPE is configured). Control tokens - # in token-exchange mode count as real positions, so position_ids - # is used directly — no hidden_count subtraction. position_embeddings = None if self.rotary_emb is not None: position_embeddings = self.rotary_emb(inputs_embeds, position_ids=position_ids) diff --git a/tests/hf/test_token_exchange.py b/tests/hf/test_token_exchange.py index 11818e1..ae13392 100644 --- a/tests/hf/test_token_exchange.py +++ b/tests/hf/test_token_exchange.py @@ -107,19 +107,3 @@ def test_adapter_indices_still_activate(self): assert adapter_indices[0, 2].item() == 1 assert adapter_indices[0, 3].item() == 1 assert adapter_indices[0, 4].item() == 1 - - -class TestPositionCorrectionSkipped: - """In token-exchange mode, position correction is a no-op.""" - - def test_no_position_shift_in_te_mode(self): - """RoPE positions should equal the input positions (no hidden_count subtraction).""" - config = _build(substitute_ids=(5, 7)) - model = GraniteSwitchForCausalLM(config).eval() - input_ids = torch.tensor([[10, 100, 20, 30]], dtype=torch.long) - # Forward runs without error; the guarded branch would otherwise fire - # and shift positions by 1 for tokens 2/3. - with torch.no_grad(): - out = model(input_ids=input_ids) - # Sanity: logits shape matches input_ids shape. - assert out.logits.shape[:2] == input_ids.shape From 3c652601363c5729a2867ec5875ac65210aa8b9f Mon Sep 17 00:00:00 2001 From: AlonMalach Date: Tue, 19 May 2026 17:23:29 +0300 Subject: [PATCH 17/18] Add tests/vllm/test_token_exchange.py covering vLLM token-exchange path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mirrors tests/hf/test_token_exchange.py for the vLLM SingleSwitch backend. Closes the coverage gap raised by @antonpibm on PR #34: the token-exchange LUT and modified_input_ids rewrite were tested only on HF — vLLM had zero direct assertions on either. New file `tests/vllm/test_token_exchange.py`: - TestLUTMapping: query_lut command returns control_to_substitute_lut; asserts lut[ctrl_id] == sub_id for each adapter, lut[other] == -1 - TestInputRewrite: forward_with_modified command returns both adapter_indices and modified_input_ids; asserts non-control positions unchanged, control positions rewritten to substitute, multi-control sequence handles each independently, and adapter detection still fires on the original (pre-rewrite) input_ids Worker changes (`tests/vllm/_single_switch_worker.py`): - Mock config now populates adapter_token_ids + adapter_substitute_token_ids so SingleSwitch builds the LUT — production configs always have these. Existing tests are unaffected (they discard modified_input_ids). - New _run_with_modified() helper returns both forward outputs as lists - New "forward_with_modified" and "query_lut" commands wired into the request loop. The pre-existing "forward" command is unchanged. Substitute mapping in worker: control id (1000+i) → substitute id (i+1), matched by ADAPTER_SUBSTITUTE_TOKEN_IDS_LIST in the new test file. --- tests/vllm/_single_switch_worker.py | 81 +++++++++ tests/vllm/test_token_exchange.py | 254 ++++++++++++++++++++++++++++ 2 files changed, 335 insertions(+) create mode 100644 tests/vllm/test_token_exchange.py diff --git a/tests/vllm/_single_switch_worker.py b/tests/vllm/_single_switch_worker.py index 91439fe..cdd8ad4 100644 --- a/tests/vllm/_single_switch_worker.py +++ b/tests/vllm/_single_switch_worker.py @@ -48,14 +48,23 @@ def _setup(): MAX_TOKENS = 131_072 NUM_ADAPTERS = 32 ADAPTER_TOKEN_IDS_LIST = list(range(1000, 1000 + NUM_ADAPTERS)) + # Deterministic substitute mapping for token-exchange tests: + # control id 1000+i → substitute id i+1 (i.e. 1, 2, ..., NUM_ADAPTERS). + # Substitute ids must be < vocab_size and != any control id. + ADAPTER_SUBSTITUTE_TOKEN_IDS_LIST = [i + 1 for i in range(NUM_ADAPTERS)] # Mock config with realistic backbone geometry (GQA: 4Q/2KV, head_dim=64) # so unit tests exercise the multi-head path, not the fallback. + # adapter_token_ids + adapter_substitute_token_ids enable the + # control_to_substitute_lut path, which production configs always have. mock_config = SimpleNamespace( num_attention_heads=4, num_key_value_heads=2, projection_head_dim=64, attention_multiplier=0.125, + vocab_size=2000, + adapter_token_ids=ADAPTER_TOKEN_IDS_LIST, + adapter_substitute_token_ids=ADAPTER_SUBSTITUTE_TOKEN_IDS_LIST, ) device = torch.device("cuda") @@ -98,6 +107,7 @@ def _setup(): "kv_cache": kv_cache, "device": device, "layer_name": layer_name, + "adapter_substitute_token_ids_list": ADAPTER_SUBSTITUTE_TOKEN_IDS_LIST, "backend_name": backend_name, "block_size": BLOCK_SIZE, "adapter_token_ids_list": ADAPTER_TOKEN_IDS_LIST, @@ -226,6 +236,66 @@ def _run(harness, seq, num_adapters, control_token_gain): return adapter_indices.cpu().tolist() +def _run_with_modified(harness, seq, num_adapters, control_token_gain): + """Execute SingleSwitch.forward and return BOTH outputs as lists. + + Used by token-exchange tests that need to inspect modified_input_ids + in addition to adapter_indices. Same setup as _run, but the response + is a dict {"adapter_indices": [...], "modified_input_ids": [...]}. + """ + from vllm.forward_context import ForwardContext, override_forward_context + + switch = harness["switch"] + vllm_config = harness["vllm_config"] + kv_cache = harness["kv_cache"] + device = harness["device"] + layer_name = harness["layer_name"] + adapter_token_ids_list = harness["adapter_token_ids_list"] + + seq_len = len(seq) + kv_cache.zero_() + + orig_gain = switch.control_token_gain + orig_effective_gain = switch.effective_gain + orig_num_adapters = switch.num_adapters + switch.control_token_gain = control_token_gain + switch.effective_gain = control_token_gain / switch.scaling + switch.num_adapters = num_adapters + + input_ids = torch.tensor(seq, dtype=torch.long, device=device) + adapter_token_ids = torch.tensor( + adapter_token_ids_list[:num_adapters], dtype=torch.long, device=device, + ) + + metadata, slot_mapping = _build_metadata(harness, seq_len) + + forward_ctx = ForwardContext( + no_compile_layers=vllm_config.compilation_config.static_forward_context, + attn_metadata={layer_name: metadata}, + slot_mapping={layer_name: slot_mapping}, + ) + + old_direct = switch.attn.use_direct_call + switch.attn.use_direct_call = True + + try: + with override_forward_context(forward_ctx): + adapter_indices, modified_input_ids = switch.forward( + input_ids=input_ids, + adapter_token_ids=adapter_token_ids, + ) + finally: + switch.attn.use_direct_call = old_direct + switch.control_token_gain = orig_gain + switch.effective_gain = orig_effective_gain + switch.num_adapters = orig_num_adapters + + return { + "adapter_indices": adapter_indices.cpu().tolist(), + "modified_input_ids": modified_input_ids.cpu().tolist(), + } + + def _query_geometry(harness): """Return switch geometry and cache info for infrastructure tests.""" switch = harness["switch"] @@ -316,6 +386,17 @@ def main(): control_token_gain=req.get("control_token_gain", 15.0), ) resp = {"result": result} + elif command == "forward_with_modified": + result = _run_with_modified( + harness, + seq=req["seq"], + num_adapters=req.get("num_adapters", 32), + control_token_gain=req.get("control_token_gain", 15.0), + ) + resp = {"result": result} + elif command == "query_lut": + lut = harness["switch"].control_to_substitute_lut + resp = {"result": lut.cpu().tolist() if lut is not None else None} else: resp = {"error": f"Unknown command: {command}"} except Exception: diff --git a/tests/vllm/test_token_exchange.py b/tests/vllm/test_token_exchange.py new file mode 100644 index 0000000..faac66f --- /dev/null +++ b/tests/vllm/test_token_exchange.py @@ -0,0 +1,254 @@ +# SPDX-License-Identifier: Apache-2.0 +"""vLLM backend tests for token-exchange mode. + +Mirrors tests/hf/test_token_exchange.py — verifies that on the vLLM +SingleSwitch path: + +1. The control_to_substitute_lut tensor maps each adapter control token id + to its configured substitute id, and leaves all other ids at -1. +2. Non-control positions in modified_input_ids are unchanged from the + original input_ids tensor. +3. Control positions in modified_input_ids are rewritten to the + substitute id from the LUT. + +Tests #2 and #3 require a forward pass, so they go through the long-lived +SingleSwitch worker subprocess (the same one used by test_single_switch.py) +via two new commands: 'query_lut' and 'forward_with_modified'. The worker's +mock config now populates adapter_token_ids + adapter_substitute_token_ids +so the LUT path is exercised — see _single_switch_worker.py:_setup. + +Requires CUDA GPU and vLLM installed. All tests skipped otherwise. +All GPU work happens in the subprocess worker — the parent pytest process +never creates a CUDA context (required for Exclusive_Process GPU mode). +""" + +import atexit +import importlib.util +import json +import subprocess +import sys +import threading +from pathlib import Path + +import pytest + +_VLLM_AVAILABLE = importlib.util.find_spec("vllm") is not None + +pytestmark = pytest.mark.skipif( + not _VLLM_AVAILABLE, + reason="requires vLLM installed (GPU checked by worker)", +) + +from tests.shared.single_switch_cases import ( + NUM_ADAPTERS, + TEXT_TOKEN, + ADAPTER_TOKEN_IDS_LIST, +) + +# Worker's deterministic substitute mapping: control_id (1000+i) → sub_id (i+1). +# Matches ADAPTER_SUBSTITUTE_TOKEN_IDS_LIST in _single_switch_worker.py:_setup. +ADAPTER_SUBSTITUTE_TOKEN_IDS_LIST = [i + 1 for i in range(NUM_ADAPTERS)] + + +# ── Worker management ───────────────────────────────────────────── +# Same pattern as test_single_switch.py — own module-private worker so +# pytest can run the two files independently or together. + +_WORKER_PATH = Path(__file__).parent / "_single_switch_worker.py" +_worker_proc = None +_worker_lock = threading.Lock() +_fatal_startup_error = None + + +def _ensure_worker(): + global _worker_proc, _fatal_startup_error + if _fatal_startup_error is not None: + pytest.fail(_fatal_startup_error, pytrace=False) + if _worker_proc is not None and _worker_proc.poll() is None: + return + proc = subprocess.Popen( + [sys.executable, str(_WORKER_PATH)], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, + ) + ready_line = proc.stdout.readline() + if not ready_line: + stderr = proc.stderr.read() + raise RuntimeError(f"Worker failed to start:\n{stderr}") + ready = json.loads(ready_line) + if "fatal" in ready: + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + proc.kill() + stderr_tail = (proc.stderr.read() or "")[-2000:] + backend = ready.get("backend_name", "unknown") + _fatal_startup_error = ( + f"vLLM worker cannot start: {ready['fatal']}\n" + f"Backend: {backend}\n" + f"Hint: {ready.get('hint', '')}\n" + f"--- worker stderr (tail) ---\n{stderr_tail}" + ) + pytest.fail(_fatal_startup_error, pytrace=False) + assert ready.get("ready"), f"Unexpected ready message: {ready}" + _worker_proc = proc + atexit.register(_shutdown_worker) + + +def _shutdown_worker(): + global _worker_proc + if _worker_proc is not None and _worker_proc.poll() is None: + _worker_proc.stdin.close() + _worker_proc.wait(timeout=30) + _worker_proc = None + + +def _send_command(req): + """Send a JSON request to the worker and return its 'result' field.""" + _ensure_worker() + with _worker_lock: + _worker_proc.stdin.write(json.dumps(req) + "\n") + _worker_proc.stdin.flush() + resp_line = _worker_proc.stdout.readline() + if not resp_line: + stderr = _worker_proc.stderr.read() + raise RuntimeError(f"Worker died unexpectedly:\n{stderr}") + resp = json.loads(resp_line) + if "error" in resp: + raise RuntimeError(f"Worker error:\n{resp['error']}") + return resp["result"] + + +@pytest.fixture(autouse=True, scope="module") +def _worker_lifecycle(): + yield + _shutdown_worker() + + +# ── Tests ───────────────────────────────────────────────────────── + + +class TestLUTMapping: + """control_to_substitute_lut is the canonical control→substitute table. + + It is built once at SingleSwitch construction from + config.adapter_token_ids + config.adapter_substitute_token_ids; tested + here against the worker's mock config (control 1000+i → substitute i+1). + """ + + def test_lut_maps_control_to_substitute(self): + lut = _send_command({"command": "query_lut"}) + assert lut is not None, ( + "control_to_substitute_lut was None — adapter_substitute_token_ids " + "missing from worker mock config?" + ) + for ctrl_id, sub_id in zip( + ADAPTER_TOKEN_IDS_LIST, ADAPTER_SUBSTITUTE_TOKEN_IDS_LIST + ): + assert lut[ctrl_id] == sub_id, ( + f"lut[{ctrl_id}]={lut[ctrl_id]}, expected substitute {sub_id}" + ) + + def test_lut_marks_non_control_with_sentinel(self): + lut = _send_command({"command": "query_lut"}) + assert lut is not None + # TEXT_TOKEN (50) and a few arbitrary non-control ids should be -1. + for non_control in [TEXT_TOKEN, 0, 51, 52, 999]: + assert lut[non_control] == -1, ( + f"lut[{non_control}]={lut[non_control]}, expected -1 sentinel" + ) + + +class TestInputRewrite: + """SingleSwitch.forward returns (adapter_indices, modified_input_ids). + + modified_input_ids must equal input_ids at non-control positions and + equal lut[ctrl_id] (the substitute) at control positions. The decoder + embeds modified_input_ids; the switch itself reads the original + input_ids so adapter detection is unaffected. + """ + + def test_non_control_positions_unchanged(self): + # Mix of non-control tokens with one control token in the middle. + ctrl_id = ADAPTER_TOKEN_IDS_LIST[0] + seq = [TEXT_TOKEN, 51, ctrl_id, 53, 54] + result = _send_command( + { + "command": "forward_with_modified", + "seq": seq, + "num_adapters": 4, + "control_token_gain": 15.0, + } + ) + modified = result["modified_input_ids"] + # Positions 0, 1, 3, 4 are non-control — must be unchanged. + assert modified[0] == seq[0] + assert modified[1] == seq[1] + assert modified[3] == seq[3] + assert modified[4] == seq[4] + + def test_control_positions_rewritten_to_substitute(self): + # Control token at position 2 — must be rewritten to its substitute. + ctrl_id = ADAPTER_TOKEN_IDS_LIST[0] + expected_sub = ADAPTER_SUBSTITUTE_TOKEN_IDS_LIST[0] + seq = [TEXT_TOKEN, 51, ctrl_id, 53, 54] + result = _send_command( + { + "command": "forward_with_modified", + "seq": seq, + "num_adapters": 4, + "control_token_gain": 15.0, + } + ) + modified = result["modified_input_ids"] + assert modified[2] == expected_sub, ( + f"control position rewrite failed: got {modified[2]}, " + f"expected substitute {expected_sub}" + ) + + def test_multiple_control_tokens_each_rewritten(self): + # Two distinct control tokens; each must map to its own substitute. + ctrl0 = ADAPTER_TOKEN_IDS_LIST[0] + ctrl1 = ADAPTER_TOKEN_IDS_LIST[1] + sub0 = ADAPTER_SUBSTITUTE_TOKEN_IDS_LIST[0] + sub1 = ADAPTER_SUBSTITUTE_TOKEN_IDS_LIST[1] + seq = [TEXT_TOKEN, ctrl0, TEXT_TOKEN, ctrl1, TEXT_TOKEN] + result = _send_command( + { + "command": "forward_with_modified", + "seq": seq, + "num_adapters": 4, + "control_token_gain": 15.0, + } + ) + modified = result["modified_input_ids"] + assert modified[0] == TEXT_TOKEN + assert modified[1] == sub0 + assert modified[2] == TEXT_TOKEN + assert modified[3] == sub1 + assert modified[4] == TEXT_TOKEN + + def test_switch_still_detects_adapter_after_rewrite(self): + # The rewrite must NOT confuse adapter detection — the switch reads + # the original input_ids before the rewrite happens. + ctrl_id = ADAPTER_TOKEN_IDS_LIST[2] + seq = [TEXT_TOKEN, ctrl_id, TEXT_TOKEN, TEXT_TOKEN] + result = _send_command( + { + "command": "forward_with_modified", + "seq": seq, + "num_adapters": 4, + "control_token_gain": 15.0, + } + ) + adapter_indices = result["adapter_indices"] + # Position 0 fires before any control: adapter 0 (base). + # Position 1 is the control for adapter index 3 (1-indexed: ctrl_idx 2 → adapter 3). + # SingleSwitch persists adapter id once fired → positions 1+ all 3. + assert adapter_indices[0] == 0 + assert adapter_indices[1] == 3 + assert adapter_indices[2] == 3 + assert adapter_indices[3] == 3 From 0d382042c4c877167e79fa2281b05a5508e68de0 Mon Sep 17 00:00:00 2001 From: AlonMalach Date: Tue, 19 May 2026 17:30:10 +0300 Subject: [PATCH 18/18] Move LUT buffer to CUDA in worker setup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previous commit added adapter_substitute_token_ids to the worker's mock config to enable token-exchange tests. SingleSwitch.__init__ registers control_to_substitute_lut as a CPU buffer (no device specified), and the worker never explicitly calls switch.to(device) — the Q/K/V tensors are built on CUDA directly in forward(). On Ampere this caused the very first forward to raise: RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu) …inside the new lut[input_ids] indexing operation. The fail-fast probe caught it cleanly and broke every existing test that uses the worker. Fix: after switch construction, move control_to_substitute_lut to the same CUDA device the worker uses for everything else. Test-only fix; production code path (SingleSwitch.__init__) is untouched. --- tests/vllm/_single_switch_worker.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/vllm/_single_switch_worker.py b/tests/vllm/_single_switch_worker.py index cdd8ad4..4f7e632 100644 --- a/tests/vllm/_single_switch_worker.py +++ b/tests/vllm/_single_switch_worker.py @@ -81,6 +81,14 @@ def _setup(): control_token_gain=15.0, config=mock_config, ) + # Move buffers to CUDA so the LUT (registered as a CPU buffer in + # SingleSwitch.__init__) can index the CUDA input_ids during forward. + # The Q/K/V tensors in SingleSwitch.forward() are constructed directly + # on CUDA so they don't otherwise force a .to() call. + if switch.control_to_substitute_lut is not None: + switch.control_to_substitute_lut = ( + switch.control_to_substitute_lut.to(device) + ) finally: torch.set_default_dtype(old_dtype)