def _probe_lora_substitute_token_id(tokenizer) -> int:
    """Return the token id a no-adapter chat render places at sequence position 0.

    The LoRA prefix insertion places the adapter control token at sequence
    position 0 of the rendered output. Whatever token would otherwise have
    occupied position 0 (in a no-adapter render) is the right substitute
    whose embedding should land at the swap site, so the post-swap sequence
    is indistinguishable from a no-adapter render.

    Assumption (Granite 4.x): the chat template emits a constant
    ``input_ids[0]`` regardless of message content, system prompt presence,
    or generation-prompt flag. Empirically verified — every realistic render
    of the Granite 4.1 template yields ``<|start_of_role|>`` (id 100264) at
    position 0. The probe renders a single minimal chat to read that
    constant out of the template.

    A future model whose chat template branches on inputs at position 0
    (e.g. emits BOS only when no system message is present) would break
    this assumption: the probe would still return *some* valid id, but it
    might not match position 0 in another render mode at runtime, leaving
    the LoRA control token swapped to an embedding the model doesn't
    expect at that position. ``tests/composer/test_lora_substitute_probe.py``
    pins the Granite 4.x behavior; if you port to another base model with
    a more dynamic template, extend the probe to render multiple shapes
    and verify they all agree.

    By deriving the substitute from the tokenizer's own chat template at
    compose time we avoid hard-coding a Granite-specific token string.

    Args:
        tokenizer: Tokenizer-like object. The probe reads ``chat_template``
            and ``unk_token_id`` / ``unk_token``, calls
            ``apply_chat_template`` and calls the tokenizer itself,
            expecting a result with an ``input_ids`` attribute.

    Returns:
        The first token id of the tokenized probe render.

    Raises:
        ValueError: If the tokenizer has no chat template, the template
            fails to render (the original exception is chained as
            ``__cause__``), the render tokenizes to nothing, or the first
            token is the tokenizer's unknown token.
    """
    # getattr rather than attribute access: an object that lacks the
    # attribute altogether is reported with the same documented ValueError
    # as one whose template is None, instead of leaking an AttributeError.
    if getattr(tokenizer, "chat_template", None) is None:
        raise ValueError(
            "Tokenizer has no chat_template; cannot probe the LoRA "
            "substitute token."
        )
    try:
        probe_text = tokenizer.apply_chat_template(
            [{"role": "user", "content": "probe"}],
            tokenize=False,
            add_generation_prompt=False,
        )
    except Exception as e:
        # Deliberately broad: any rendering failure is surfaced as a
        # ValueError for the caller, with the real cause chained.
        raise ValueError(
            "Failed to render a probe chat via tokenizer.apply_chat_template "
            f"while detecting the LoRA substitute token: {e!r}."
        ) from e
    # add_special_tokens=False so position 0 reflects what the template
    # rendered, not a token the tokenizer would add in front of the text.
    ids = tokenizer(probe_text, add_special_tokens=False).input_ids
    if not ids:
        raise ValueError(
            "Probe chat tokenized to an empty id list; cannot determine the "
            "LoRA substitute token."
        )
    sub_id = ids[0]
    # Tokenizers without an unknown token report None; in that case there
    # is nothing to compare against and the first id is accepted as-is.
    unk_id = getattr(tokenizer, "unk_token_id", None)
    if unk_id is not None and sub_id == unk_id:
        unk_token = getattr(tokenizer, "unk_token", None)
        raise ValueError(
            "First token of the rendered probe chat is the unknown token "
            f"{unk_token!r} (id {sub_id}); the template "
            "appears to emit content outside the tokenizer's vocabulary."
        )
    return sub_id
if has_built_in and not has_external: build_mode = "native" elif has_external: @@ -692,7 +750,6 @@ def build(): # Extract fields from 4-tuples (path, name, tech, source) adapter_paths = [t[0] for t in all_discovered if t[0] is not None] adapter_names = [t[1] for t in all_discovered] - external_names = [t[1] for t in external_discovered] built_in_names = [name for name in (args.built_in_adapters or [])] print(f"\nBuild mode: {build_mode}") @@ -747,33 +804,30 @@ def build(): optional_kwargs = {} if args.switch_head_dim is not None: optional_kwargs["switch_head_dim"] = args.switch_head_dim - if args.control_dims is not None: - optional_kwargs["control_dims"] = args.control_dims - - # Per-mode hiding configuration - if build_mode == "native": - # Mode A (native): no hiding, control_dims=0 (unless overridden) - hiding_groups = None - hiding_policy = None - adapter_third_party = None - if "control_dims" not in optional_kwargs: - optional_kwargs["control_dims"] = 0 - else: - # Mode B (third-party): full hiding for external adapters - hiding_groups = {"all_controls": list(adapter_names)} - hiding_policy = {name: ["all_controls"] for name in adapter_names} - hiding_policy["base"] = ["all_controls"] - # Only external adapters are third-party - adapter_third_party = list(external_names) + + # Token-exchange substitute choice (must mirror the token that appears + # right after the control token in the rendered chat prompt, so the + # swap keeps the residual stream in-distribution): + # - ALoRA: first token of the adapter's alora_invocation_tokens. + # - LoRA/builtin: whatever the tokenizer's chat template emits at + # the very start of a no-adapter user turn. For Granite 4.x that's + # <|start_of_role|>; the probe derives this from the template at + # compose time so other base models work by construction. 
+ lora_sub_id = _probe_lora_substitute_token_id(tokenizer) + adapter_substitute_token_ids = [] + for adapter_path, _name, technology, _source in all_discovered: + if technology == "alora": + sub_id = get_alora_first_invocation_token_id(adapter_path) + else: + sub_id = lora_sub_id + adapter_substitute_token_ids.append(sub_id) model = GraniteSwitchComposer.from_base_and_adapters( base_model_name_or_path=base_model_local_path, adapter_paths=adapter_paths, adapter_token_ids=adapter_token_ids, + adapter_substitute_token_ids=adapter_substitute_token_ids, adapter_names=adapter_names, - hiding_groups=hiding_groups, - hiding_policy=hiding_policy, - adapter_third_party=adapter_third_party, built_in_adapter_names=built_in_names, built_in_lora_rank=args.lora_rank, built_in_lora_alpha=args.lora_alpha if args.lora_alpha is not None else float(args.lora_rank), diff --git a/src/granite_switch/composer/compose_utils.py b/src/granite_switch/composer/compose_utils.py index d230f27..dabc47f 100644 --- a/src/granite_switch/composer/compose_utils.py +++ b/src/granite_switch/composer/compose_utils.py @@ -25,6 +25,7 @@ def from_base_and_adapters( base_model_name_or_path: str, adapter_paths: Optional[List[str]] = None, adapter_token_ids: Optional[List[int]] = None, + adapter_substitute_token_ids: Optional[List[int]] = None, adapter_names: Optional[List[str]] = None, built_in_adapter_names: Optional[List[str]] = None, built_in_lora_rank: int = 8, @@ -48,6 +49,9 @@ def from_base_and_adapters( empty for zero-adapter skinning (base model only). adapter_token_ids: Token IDs for adapter control. Required when ``adapter_paths`` is non-empty. + adapter_substitute_token_ids: Token IDs whose embeddings replace + control-token embeddings at the switch. Required when + ``adapter_paths`` is non-empty; one per adapter. adapter_names: Display names for each adapter (external + built-in). When ``None``, derived from the directory structure. 
built_in_adapter_names: Names for built-in (empty LoRA) adapter slots. @@ -112,10 +116,6 @@ def from_base_and_adapters( source_analysis = {} # --- Step 4: Build switch config from arch descriptor --- - hiding_groups = kwargs.pop("hiding_groups", None) - hiding_policy = kwargs.pop("hiding_policy", None) - adapter_third_party = kwargs.pop("adapter_third_party", None) - # Copy config fields driven by architecture descriptor config_kwargs: Dict = {} @@ -151,17 +151,15 @@ def from_base_and_adapters( { "num_adapters": num_total, "adapter_token_ids": adapter_token_ids, + "adapter_substitute_token_ids": adapter_substitute_token_ids, "adapter_names": adapter_names, - "hiding_groups": hiding_groups, - "hiding_policy": hiding_policy, - "adapter_third_party": adapter_third_party, "max_lora_rank": lora_rank, "adapter_ranks": adapter_ranks, "lora_target_modules": lora_target_modules, } ) - # Merge caller-provided overrides (switch_head_dim, control_dims, etc.) + # Merge caller-provided overrides (switch_head_dim, etc.) 
config_kwargs.update(kwargs) switch_config = GraniteSwitchConfig(**config_kwargs) diff --git a/src/granite_switch/composer/reporting/__init__.py b/src/granite_switch/composer/reporting/__init__.py index 687c2f7..14fca07 100644 --- a/src/granite_switch/composer/reporting/__init__.py +++ b/src/granite_switch/composer/reporting/__init__.py @@ -4,7 +4,6 @@ from .population_table import generate_adapter_population_table, print_adapter_population_table from .compose_report import generate_compose_report from .adapter_analysis import print_source_adapter_analysis -from .hiding_constant_report import compute_hiding_constant_safety, print_hiding_constant_safety from .model_card import render_model_card, write_model_card, write_build_doc __all__ = [ @@ -12,8 +11,6 @@ 'print_adapter_population_table', 'generate_compose_report', 'print_source_adapter_analysis', - 'compute_hiding_constant_safety', - 'print_hiding_constant_safety', 'render_model_card', 'write_model_card', 'write_build_doc', diff --git a/src/granite_switch/composer/reporting/compose_report.py b/src/granite_switch/composer/reporting/compose_report.py index b392f12..7537e77 100644 --- a/src/granite_switch/composer/reporting/compose_report.py +++ b/src/granite_switch/composer/reporting/compose_report.py @@ -325,11 +325,6 @@ def _print_summary( if len(base_source_not_connected) > 10: print(f" ... 
and {len(base_source_not_connected) - 10} more") - # Hiding constant safety margin - if model is not None: - from .hiding_constant_report import print_hiding_constant_safety - print_hiding_constant_safety(model.dtype) - print(f"\nDetailed report saved to: {report_path}") print("="*80) diff --git a/src/granite_switch/composer/reporting/hiding_constant_report.py b/src/granite_switch/composer/reporting/hiding_constant_report.py deleted file mode 100644 index c54860a..0000000 --- a/src/granite_switch/composer/reporting/hiding_constant_report.py +++ /dev/null @@ -1,58 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Safety margin report for the K-side hiding constant. - -The hiding mechanism uses finfo(dtype).min as the K-side control dimension value -for tokens in hiding groups. This module computes and reports the safety margin: -how large a positive attention score would need to be before the hiding breaks -(i.e., exp(fmin + score) > 0). -""" - -import torch - - -def _find_exp_underflow_threshold(dtype: torch.dtype) -> float: - """Find the smallest (most negative) x where exp(x) > 0 for the given dtype. - - Searches from -500 upward in steps of 0.5. - """ - for x_int in range(-1000, 0): - x = x_int * 0.5 - x_t = torch.tensor(x, dtype=dtype) - if torch.exp(x_t).item() > 0.0: - return x - return 0.0 # fallback: exp(x) > 0 for all tested x - - -def compute_hiding_constant_safety(dtype: torch.dtype) -> dict: - """Compute safety margin data for the hiding constant at the given dtype. 
- - Returns a dict with: - fmin: the hiding constant value - exp_underflow_threshold: smallest x where exp(x) > 0 - safety_margin: positive value that must be added to fmin to break hiding - """ - fmin_val = torch.finfo(dtype).min - exp_threshold = _find_exp_underflow_threshold(dtype) - safety_margin = abs(fmin_val) + exp_threshold # exp_threshold is negative - - return { - "dtype": str(dtype), - "fmin": fmin_val, - "exp_underflow_threshold": exp_threshold, - "safety_margin": safety_margin, - } - - -def print_hiding_constant_safety(dtype: torch.dtype): - """Print the hiding constant safety margin report for the given dtype.""" - data = compute_hiding_constant_safety(dtype) - - print(f"\n{'='*80}") - print("CONTROL DIMENSION HIDING CONSTANT") - print(f"{'='*80}") - print(f" Model dtype: {data['dtype']}") - print(f" Hiding constant (finfo.min): {data['fmin']:.6e}") - print(f" exp(hiding_constant) underflows to zero: True") - print(f" exp() underflow threshold: {data['exp_underflow_threshold']}") - print(f" Safety margin: a positive attention score of {data['safety_margin']:.6e}") - print(f" would be needed to break hiding (make exp(fmin + score) > 0)") diff --git a/src/granite_switch/composer/reporting/model_card.py b/src/granite_switch/composer/reporting/model_card.py index 721e4cb..8c3505f 100644 --- a/src/granite_switch/composer/reporting/model_card.py +++ b/src/granite_switch/composer/reporting/model_card.py @@ -391,7 +391,9 @@ def _short_source(source): "lora_rank": getattr(args, "lora_rank", None) if built_in else None, "lora_alpha": getattr(args, "lora_alpha", None) if built_in else None, "switch_head_dim": getattr(args, "switch_head_dim", None), - "control_dims": getattr(args, "control_dims", None), + "adapter_substitute_token_ids": getattr( + model.config, "adapter_substitute_token_ids", None + ), "target_model": getattr(args, "target_model", None), } # Parameter counts: base is captured during transfer (see diff --git 
def _decode_alora_invocation_text(adapter_path: str, tokenizer) -> str:
    """Render an ALoRA adapter's invocation token sequence as text.

    The ids are the ``alora_invocation_tokens`` loaded for *adapter_path*.
    The activation control token has to be inserted immediately before the
    first token of that sequence, so the decoded string is the span to
    search for in rendered message content. Special tokens are kept in the
    decoded text (``skip_special_tokens=False``).
    """
    invocation_ids = _load_alora_invocation_token_ids(adapter_path)
    invocation_text = tokenizer.decode(invocation_ids, skip_special_tokens=False)
    return invocation_text


def get_alora_first_invocation_token_id(adapter_path: str) -> int:
    """Return the leading token id of an ALoRA adapter's invocation sequence.

    Token-exchange mode substitutes this token's embedding for the adapter's
    control token before the decoder runs.
    """
    invocation_ids = _load_alora_invocation_token_ids(adapter_path)
    leading_id = invocation_ids[0]
    return leading_id
the first + # character is a single '<' that the tokenizer emits as its own token, + # and the tail of the string retokenizes identically to the tail of the + # full string. So dropping the first character on the string side is + # equivalent to dropping exactly the first token on the tokenized side — + # no re-merging, no change to what follows. + alora_pass2 = """ {#- ALoRA Pass 2: inject activation token AND drop the first char of + the invocation text so the runtime-swapped embedding doesn't duplicate. -#} {%- if loop.index0 == ns.alora_target_idx %} {%- set _parts = content.val.rsplit(ns.adapter_invocation_text, 1) %} {%- if _parts | length > 1 %} - {%- set content.val = _parts[0] + ns.adapter_token + ns.adapter_invocation_text + _parts[1] %} + {%- set content.val = _parts[0] + ns.adapter_token + ns.adapter_invocation_text[1:] + _parts[1] %} {%- endif %} {%- endif %} """ # Fallback for adapters whose invocation sequence is the assistant role tokens: # Pass 1 never sets alora_target_idx >= 0 for those, so we emit here instead. + # Also arm skip_next_start_of_role so the generation-prompt <|start_of_role|> + # that would immediately follow is suppressed — mirrors the LoRA rationale: + # the runtime swap replaces the control token's embedding with the first + # invocation token's embedding (<|start_of_role|>), so without this drop the + # sequence would carry two identical embeddings back-to-back. alora_insertion = """{#- ALoRA fallback: insert activation token right before generation prompt. Only fires when Pass 1 found no user message with the invocation text (alora_target_idx == -1), meaning the adapter activates at the assistant role token boundary rather than inside a user message. 
-#} {%- if ns.adapter_token and ns.adapter_type == 'alora' and ns.alora_target_idx == -1 %} {{- ns.adapter_token }} +{%- set ns.skip_next_start_of_role = true %} {%- endif %} """ @@ -271,7 +316,8 @@ def configure_chat_template( "\n adapter_token=adapter_token," "\n adapter_type=adapter_type," "\n adapter_invocation_text=adapter_invocation_text," - "\n alora_target_idx=-1" + "\n alora_target_idx=-1," + "\n skip_next_start_of_role=false" "\n )" ) modified_chat_template = ( @@ -322,6 +368,61 @@ def configure_chat_template( else: modified_chat_template += "\n" + alora_insertion + # Skip-once wrapper for every <|start_of_role|> emission in the template. + # ns.skip_next_start_of_role is set to true immediately after a LoRA or + # assistant-boundary ALoRA control token is emitted; the very next role + # marker consumes the flag and is suppressed. Prevents a duplicate + # embedding at position 1 (see lora_prefix_insertion / alora_insertion + # comments). + # + # Every <|start_of_role|> in the base template appears inside a string + # literal, either merged with the following role text ('<|start_of_role|>user<|end_of_role|>') + # or standalone ('<|start_of_role|>' + message.role + ...). We split at + # the '<|start_of_role|>' boundary and route only that fragment through + # the skip-once Jinja block. + skip_once_block = ( + "{%- if ns.skip_next_start_of_role %}" + "{%- set ns.skip_next_start_of_role = false %}" + "{%- else %}" + "{{- '<|start_of_role|>' }}" + "{%- endif %}" + ) + # Case A: '<|start_of_role|>' as a standalone literal, possibly at the + # start of a concatenation ({{- '<|start_of_role|>' + expr + ... }}). + # Replace the literal emission with the skip block; the rest of the + # expression stays. Handles sites 77 and 79 directly. 
+ modified_chat_template = re.sub( + r"\{\{-\s*'<\|start_of_role\|>'\s*\+\s*", + skip_once_block + "\n {{- ", + modified_chat_template, + ) + # Case B: '<|start_of_role|>ROLE<|end_of_role|>' merged literal (with or + # without trailing concatenation). Split the literal so only the + # '<|start_of_role|>' prefix goes through the skip block and the rest + # ('ROLE<|end_of_role|>' + anything) emits normally. + # Pattern: {{- 'literal_starting_with_start_of_role' (+ expr | ) }} + def _split_merged(match: "re.Match") -> str: + remainder = match.group(1) # text after <|start_of_role|> up to end of literal + tail = match.group(2) # trailing + expr or empty + return ( + skip_once_block + + "\n {{- '" + + remainder + + "'" + + tail + + " }}" + ) + + # Merged literal like '<|start_of_role|>system<|end_of_role|>' followed by + # optional " + expr + ...". The first group captures everything inside the + # literal after <|start_of_role|>; the second captures any trailing + # concatenation up to the closing }}. + modified_chat_template = re.sub( + r"\{\{-\s*'<\|start_of_role\|>([^']*)'((?:\s*\+\s*[^}]+?)?)\s*\}\}", + _split_merged, + modified_chat_template, + ) + tokenizer.chat_template = modified_chat_template print( f"Chat template configured with {len(adapter_mapping)} adapter mappings:" diff --git a/src/granite_switch/config.py b/src/granite_switch/config.py index 026797e..7824002 100644 --- a/src/granite_switch/config.py +++ b/src/granite_switch/config.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """Configuration for Granite model with adapter switching.""" -from typing import List, Optional, Dict +from typing import List, Optional from transformers import GraniteMoeHybridConfig @@ -9,38 +9,30 @@ class GraniteSwitchConfig(GraniteMoeHybridConfig): """Configuration class for GraniteSwitch model. - Extends the Granite base config with parameters for adapter switching using - the SingleSwitch mechanism. 
- - Inherits from GraniteMoeHybridConfig (the transformers base class for - Granite 4 models) and adds adapter routing parameters. + Extends the Granite base config with parameters for adapter switching + using the SingleSwitch mechanism. Control tokens are handled exclusively + via token exchange: the switch reads ``input_ids``, decides the active + adapter, and rewrites each control token to its substitute id (from + ``adapter_substitute_token_ids``) before the decoder embeds the + sequence. The decoder is unaware of the substitution. Args: num_adapters (int): Number of LoRA adapters available. Default: 0 (no adapters). This counts real LoRA adapters only (not base). Index 0 always means "base / no adapter". adapter_token_ids (List[int]): Token IDs for adapter control. - Length: num_adapters (one token per real adapter). + Length: num_adapters (one token per real adapter). Must be unique. adapter_token_ids[i] activates adapter i+1 (1-indexed output). Output 0 = base (implicit default, no token needed to return to base). NOTE: SingleSwitch cannot transition back to base mid-sequence. + adapter_substitute_token_ids (List[int]): Token IDs whose embeddings + replace the control-token embeddings before the decoder runs. + Length: num_adapters. Required when num_adapters > 0. SingleSwitch parameters: control_token_gain (float): Attention gain for control/non-control separation. Default: 15.0. switch_head_dim (int): Dimension of Q/K/V vectors in switch attention. Default: 32. - control_dims (int): Extra dimensions for K/V to mask control tokens. Must be >= 0. Default: 32. adapter_names (List[str]): Ordered adapter names for name-to-index mapping. - Used by hiding_groups and hiding_policy to resolve names to indices. - hiding_groups (Dict[str, List[str]]): Hiding group definitions. - Maps group_name → list of adapter names whose control tokens belong to this group. - Each group uses one control dimension. Requires control_dims >= len(hiding_groups). 
- hiding_policy (Dict[str, List[str]]): Per-adapter hiding policy. - Maps adapter_name → list of group names that adapter hides. Use "base" for the - base adapter (adapter_index 0). - adapter_third_party (List[str]): Adapter names that are third-party (externally trained). - Third-party adapters were not trained with control tokens in their vocabulary, - which affects KV hiding policy. - max_lora_rank (int): Maximum rank across all LoRA adapters (for allocation). Default: 8. adapter_ranks (List[int]): Per-adapter ranks. Must have length equal to num_adapters. lora_target_modules (List[str]): List of module GROUP names to apply LoRA to. @@ -55,16 +47,12 @@ def __init__( self, num_adapters: int = 0, adapter_token_ids: Optional[List[int]] = None, + adapter_substitute_token_ids: Optional[List[int]] = None, # SingleSwitch parameters control_token_gain: float = 15.0, switch_head_dim: int = 32, - control_dims: int = 32, - # Hiding groups and policy - adapter_names: Optional[List[str]] = None, - hiding_groups: Optional[Dict[str, List[str]]] = None, - hiding_policy: Optional[Dict[str, List[str]]] = None, - adapter_third_party: Optional[List[str]] = None, # Adapter parameters + adapter_names: Optional[List[str]] = None, max_lora_rank: int = 8, adapter_ranks: List[int] = None, lora_target_modules: Optional[List[str]] = None, @@ -109,40 +97,52 @@ def __init__( f"adapter_token_ids length ({len(adapter_token_ids)}) must equal " f"num_adapters ({num_adapters})." ) + # Token-exchange builds the control→substitute LUT keyed by adapter token id; + # duplicates would silently collapse to a single slot. + if len(set(adapter_token_ids)) != len(adapter_token_ids): + raise ValueError( + f"adapter_token_ids must be unique; got {adapter_token_ids}" + ) self.adapter_token_ids = adapter_token_ids + # Validate adapter_substitute_token_ids — required when num_adapters > 0. 
+ if num_adapters > 0: + if adapter_substitute_token_ids is None: + raise ValueError( + "adapter_substitute_token_ids is required when num_adapters > 0. " + "Every adapter needs a substitute token id whose embedding replaces " + "the control-token embedding before the decoder runs." + ) + if len(adapter_substitute_token_ids) != num_adapters: + raise ValueError( + f"adapter_substitute_token_ids length " + f"({len(adapter_substitute_token_ids)}) must equal num_adapters " + f"({num_adapters})." + ) + if any(sid < 0 for sid in adapter_substitute_token_ids): + raise ValueError( + f"adapter_substitute_token_ids must all be >= 0 (real token ids); " + f"got {adapter_substitute_token_ids}" + ) + if adapter_token_ids is None: + raise ValueError( + "adapter_token_ids is required when adapter_substitute_token_ids " + "is provided (token-exchange maps control ids to substitute ids)." + ) + self.adapter_substitute_token_ids = adapter_substitute_token_ids + # SingleSwitch parameters self.control_token_gain = control_token_gain self.switch_head_dim = switch_head_dim - if control_dims < 0: - raise ValueError( - f"control_dims must be >= 0 (got {control_dims}). " - "Use control_dims=0 for native mode (no KV hiding). " - "Use control_dims >= 1 for third-party mode (KV cache masking)." - ) - self.control_dims = control_dims self.fused_add_norm = fused_add_norm - # Hiding groups and policy + # Adapter names self.adapter_names = adapter_names - self.hiding_groups = hiding_groups - self.hiding_policy = hiding_policy - self.adapter_third_party = adapter_third_party - # Validate control_dims >= num_hiding_groups - if hiding_groups is not None and len(hiding_groups) > control_dims: - raise ValueError( - f"control_dims ({control_dims}) must be >= number of hiding groups " - f"({len(hiding_groups)}). Each hiding group uses one control dimension." - ) - - # KV cache head dimension vs. projection dimension. + # Projection head dimension. 
# The QKV projection outputs vectors of size projection_head_dim - # (= hidden_size / num_attention_heads). The KV cache stores larger - # vectors (projection_head_dim + control_dims) for exact attention - # masking of control tokens. The expanded size is communicated to - # vLLM via a custom ModelArchConfigConvertor (registered in vllm/__init__.py) - # so that hybrid page-size calculations use the correct value. + # (= hidden_size / num_attention_heads). The KV cache stores native- + # head_dim tensors — no expansion under token exchange. # We do NOT set head_dim here because HF's RoPE also reads it. # Use explicit head_dim from kwargs when available (some models have # head_dim != hidden_size // num_attention_heads). @@ -192,99 +192,3 @@ def __init__( ]) self.lora_target_modules = lora_target_modules - - @property - def expanded_head_dim(self) -> int: - """KV cache head dimension: projection_head_dim + control_dims when adapters are active.""" - if self.num_adapters > 0 and self.control_dims > 0: - return self.projection_head_dim + self.control_dims - return self.projection_head_dim - - @property - def num_hiding_groups(self) -> int: - """Number of hiding groups (each uses one control dimension).""" - if self.hiding_groups is None: - return 0 - return len(self.hiding_groups) - - @property - def hiding_group_names(self) -> List[str]: - """Ordered list of hiding group names (determines control dim indices).""" - if self.hiding_groups is None: - return [] - return list(self.hiding_groups.keys()) - - def get_hiding_group_token_ids(self) -> Dict[int, List[int]]: - """Map group index → list of token IDs in that group. - - Resolves adapter names to their activating token IDs using - adapter_names and adapter_token_ids. - - Returns empty dict if no hiding groups configured. 
- """ - if self.hiding_groups is None or self.adapter_names is None: - return {} - if self.adapter_token_ids is None: - return {} - - # Build name → token ID mapping (no offset for SingleSwitch) - name_to_token_id = {} - for i, name in enumerate(self.adapter_names): - name_to_token_id[name] = self.adapter_token_ids[i] - - result = {} - for group_idx, group_name in enumerate(self.hiding_group_names): - adapter_names_in_group = self.hiding_groups[group_name] - token_ids = [] - for name in adapter_names_in_group: - if name in name_to_token_id: - token_ids.append(name_to_token_id[name]) - result[group_idx] = token_ids - return result - - def get_third_party_adapter_mask(self) -> List[bool]: - """Return per-adapter-slot boolean: True if the adapter is third-party. - - Index 0 = base (never third-party). Index 1+ = real adapters. - Length = num_adapters + 1 (one slot per adapter index including base). - - Third-party adapters were not trained with control tokens in their - vocabulary, which affects KV hiding policy. - - Returns all-False list if adapter_third_party is not configured. - """ - num_slots = self.num_adapters + 1 # base + adapters - if not self.adapter_third_party or not self.adapter_names: - return [False] * num_slots - - tp_set = set(self.adapter_third_party) - # Index 0 = base (never third-party) - mask = [False] - for name in self.adapter_names: - mask.append(name in tp_set) - return mask - - def get_adapter_hiding_policy_matrix(self) -> List[List[bool]]: - """Build adapter hiding policy matrix: [num_adapter_slots][num_groups]. - - Index 0 = base adapter. Index 1+ = real adapters (matching adapter_names order). - Each entry is True if that adapter hides that group. - - Returns empty list if no hiding policy configured. - """ - if self.hiding_policy is None or self.adapter_names is None: - return [] - - num_groups = self.num_hiding_groups - group_names = self.hiding_group_names - - # Build ordered adapter list: [base, adapter_0, adapter_1, ...] 
- all_adapter_names = ["base"] + list(self.adapter_names) - num_slots = len(all_adapter_names) - - matrix = [] - for adapter_name in all_adapter_names: - groups_to_hide = self.hiding_policy.get(adapter_name, []) - row = [gn in groups_to_hide for gn in group_names] - matrix.append(row) - return matrix diff --git a/src/granite_switch/hf/core/lora.py b/src/granite_switch/hf/core/lora.py index 97e97ce..648abc3 100644 --- a/src/granite_switch/hf/core/lora.py +++ b/src/granite_switch/hf/core/lora.py @@ -371,13 +371,6 @@ def __init__(self, config: GraniteSwitchConfig, layer_idx: int): self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps) - # Control dimension expansion for KV cache masking. - # Expand only when adapters present AND control_dims > 0. - # control_dims=0 means native mode: no KV hiding, no expansion. - self.expand_control_dims = config.num_adapters > 0 and config.control_dims > 0 - self.control_dims = config.control_dims - self.expanded_head_dim = self.head_dim + self.control_dims - # Fused QKV projection - conditionally add LoRA based on config q_size = self.num_heads * self.head_dim kv_size = self.num_key_value_heads * self.head_dim @@ -425,94 +418,10 @@ def __init__(self, config: GraniteSwitchConfig, layer_idx: int): ) self.has_o_lora = False - def _expand_with_control_dimensions( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - token_group_membership: Optional[torch.Tensor], - query_group_suppression: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Expand Q, K, V with control dimensions for group-based KV cache hiding. - - Always called when num_adapters > 0 (static shape decision). 
- Each hiding group g uses one control dimension: - - K-side: finfo(dtype).min for tokens that are members of group g - - Q-side: 1.0 for queries whose adapter suppresses group g, - except for tokens that are themselves in group g - - When both tensors are None, all control dims are zero (no masking effect). - - Args: - q: Query tensor [batch, seq_len, num_heads, head_dim] - k: Key tensor [batch, seq_len, num_kv_heads, head_dim] - v: Value tensor [batch, seq_len, num_kv_heads, head_dim] - token_group_membership: [batch, seq_len, num_groups] — True if token is in group g - query_group_suppression: [batch, seq_len, num_groups] — True if token's adapter suppresses group g - - Returns: - Expanded Q, K, V tensors with control_dims added to head_dim - """ - batch_size, seq_len = q.shape[:2] - device = q.device - dtype = q.dtype - - # Allocate control dimensions (initialized to zero) - q_control = torch.zeros( - batch_size, seq_len, self.num_heads, self.control_dims, - device=device, dtype=dtype - ) - k_control = torch.zeros( - batch_size, seq_len, self.num_key_value_heads, self.control_dims, - device=device, dtype=dtype - ) - v_control = torch.zeros( - batch_size, seq_len, self.num_key_value_heads, self.control_dims, - device=device, dtype=dtype - ) - - # K-side: brand each group-member token's key with finfo.min in its group's - # control dim so that suppressing queries score it as −∞. - # token_group_membership: [batch, seq, num_groups] - # → expand to [batch, seq, num_kv_heads, num_groups] - if token_group_membership is not None: - num_groups = token_group_membership.shape[-1] - hiding_constant = torch.finfo(dtype).min - k_control[:, :, :, :num_groups] = ( - token_group_membership.unsqueeze(2) - .expand(-1, -1, self.num_key_value_heads, -1) - .to(dtype) * hiding_constant - ) - - # Q-side: set control dim g to 1.0 for queries whose adapter suppresses group g. 
- # query_group_suppression: [batch, seq, num_groups] - # → expand to [batch, seq, num_heads, num_groups] - # Tokens that are themselves in group g are excluded: when the control token - # sits at position 0 it has no other causal key to attend to, so suppressing - # its own key yields softmax([−∞]) = NaN. - if query_group_suppression is not None: - num_groups = query_group_suppression.shape[-1] - q_hide = query_group_suppression.to(dtype) - if token_group_membership is not None: - q_hide = q_hide * (1 - token_group_membership.to(dtype)) - q_control[:, :, :, :num_groups] = ( - q_hide.unsqueeze(2) - .expand(-1, -1, self.num_heads, -1) - ) - - # Concatenate original dims + control dims - q = torch.cat([q, q_control], dim=-1) # [batch, seq_len, num_heads, head_dim + control_dims] - k = torch.cat([k, k_control], dim=-1) # [batch, seq_len, num_kv_heads, head_dim + control_dims] - v = torch.cat([v, v_control], dim=-1) # [batch, seq_len, num_kv_heads, head_dim + control_dims] - - return q, k, v - def forward( self, hidden_states: torch.Tensor, adapter_indices: torch.Tensor, - token_group_membership: Optional[torch.Tensor], - query_group_suppression: Optional[torch.Tensor], position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Cache] = None, @@ -525,8 +434,6 @@ def forward( Args: hidden_states: Input tensor [batch, seq_len, hidden_size] adapter_indices: Per-token adapter selection [batch, seq_len] - token_group_membership: [batch, seq_len, num_groups] — True if token is in group g, or None - query_group_suppression: [batch, seq_len, num_groups] — True if token's adapter suppresses group g, or None position_embeddings: Precomputed (cos, sin) for RoPE attention_mask: Attention mask past_key_values: Cache object for KV caching @@ -573,15 +480,6 @@ def forward( query_states = query_states_t.transpose(1, 2) key_states = key_states_t.transpose(1, 2) - # Control dimension expansion: always when adapters 
are present. - # Group masks control which tokens/groups get K=finfo.min masking - # (can be None if no hiding groups, but expansion still happens). - if self.expand_control_dims: - query_states, key_states, value_states = self._expand_with_control_dimensions( - query_states, key_states, value_states, - token_group_membership, query_group_suppression, - ) - # Belief that both cache and attention expect [batch, heads, seq, dim] key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) @@ -612,12 +510,6 @@ def forward( sliding_window=getattr(self.config, "sliding_window", None), ) - # Trim control dimensions from output - if self.expand_control_dims: - # attn_output shape: [batch, num_heads, seq_len, expanded_head_dim] - # Trim to original head_dim - attn_output = attn_output[..., :self.head_dim] - # Reshape and project output - conditionally use LoRA attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) if self.has_o_lora: diff --git a/src/granite_switch/hf/modeling_granite_switch.py b/src/granite_switch/hf/modeling_granite_switch.py index 277d947..a6275d2 100644 --- a/src/granite_switch/hf/modeling_granite_switch.py +++ b/src/granite_switch/hf/modeling_granite_switch.py @@ -82,8 +82,6 @@ def forward( cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, adapter_indices: Optional[torch.Tensor] = None, - token_group_membership: Optional[torch.Tensor] = None, - query_group_suppression: Optional[torch.Tensor] = None, **kwargs, ) -> tuple: residual = hidden_states @@ -93,8 +91,6 @@ def forward( hidden_states, self_attn_weights, present_key_values = self.self_attn( hidden_states=hidden_states, adapter_indices=adapter_indices, - token_group_membership=token_group_membership, - query_group_suppression=query_group_suppression, position_embeddings=position_embeddings, attention_mask=attention_mask, past_key_values=past_key_values, @@ -188,44 +184,14 @@ def __init__(self, 
config: GraniteSwitchConfig): torch.zeros(config.num_adapters, dtype=torch.long), ) - # --- Hiding group buffers --- - # token_to_group_mask: [vocab_size, num_groups] lookup table. - # For each token ID, True at group g if that token belongs to group g. - # Enables O(1) per-token group membership via: mask = table[input_ids] - num_groups = config.num_hiding_groups - if num_groups > 0: - group_token_ids = config.get_hiding_group_token_ids() - # Size must cover all token IDs including added control tokens - # which may have IDs >= config.vocab_size. - all_known_ids = [tid for tids in group_token_ids.values() for tid in tids] - if config.adapter_token_ids: - all_known_ids.extend(config.adapter_token_ids) - max_tid = max(all_known_ids) if all_known_ids else -1 - table_size = max(config.vocab_size, max_tid + 1) - token_to_group_mask = torch.zeros( - table_size, num_groups, dtype=torch.bool - ) - for g, tids in group_token_ids.items(): - for tid in tids: - token_to_group_mask[tid, g] = True - self.register_buffer("token_to_group_mask", token_to_group_mask) - - # adapter_hiding_matrix: [num_adapter_slots, num_groups] boolean. - # Index 0 = base, 1+ = adapters. True if adapter hides group g. - policy_matrix = config.get_adapter_hiding_policy_matrix() - self.register_buffer( - "adapter_hiding_matrix", - torch.tensor(policy_matrix, dtype=torch.bool), - ) - else: - self.token_to_group_mask = None - self.adapter_hiding_matrix = None + # Token-exchange LUT lives on the switch module (see hf/switch/ + # single.py); the switch rewrites input_ids in-place during its + # forward pass, so this model class no longer needs a decoder- + # side substitute table. 
else: self.switch = None self.adapter_token_ids = None - self.token_to_group_mask = None - self.adapter_hiding_matrix = None # Decoder layers if config.num_adapters > 0: @@ -287,87 +253,75 @@ def forward( ) use_cache = False - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - inputs_embeds = inputs_embeds * self.embedding_multiplier - # Initialize cache if use_cache and past_key_values is None: past_key_values = DynamicCache(config=self.config) + # Determine sequence shape and device. With input_ids we get them + # directly; with pre-supplied inputs_embeds we read from the tensor. + if input_ids is not None: + batch_size, seq_length = input_ids.shape + device = input_ids.device + else: + batch_size, seq_length = inputs_embeds.shape[:2] + device = inputs_embeds.device + if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + past_seen_tokens, past_seen_tokens + seq_length, device=device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) - # Causal mask (4D for attention layers) + # Causal mask (4D for attention layers). create_causal_mask only + # uses the embedding tensor for batch/query/dtype inference; we + # haven't embedded yet (the switch call below may rewrite input_ids + # first), so pass a stub of the right shape/dtype. 
+ embed_dtype = self.embed_tokens.weight.dtype + mask_shape_proxy = inputs_embeds if inputs_embeds is not None else torch.empty( + batch_size, seq_length, 1, device=device, dtype=embed_dtype + ) causal_mask = create_causal_mask( config=self.config, - input_embeds=inputs_embeds, + input_embeds=mask_shape_proxy, attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, position_ids=position_ids, ) - # Compute adapter_indices using switch (BEFORE RoPE for position correction) - hidden_count = None + # The switch returns adapter_indices alongside modified_input_ids: + # input_ids with each control token rewritten to its substitute id, + # so the decoder can embed once without any token-exchange awareness. + modified_input_ids = input_ids if self.switch is not None: - adapter_indices = self.switch( + adapter_indices, modified_input_ids = self.switch( input_ids=input_ids, adapter_token_ids=self.adapter_token_ids, attention_mask=causal_mask, past_key_values=past_key_values, cache_position=cache_position, ) - - # Compute group-based hiding masks from lookup tables. - if self.token_to_group_mask is not None: - # token_group_membership: True at [b, i, g] if token i is a member of group g - token_group_membership = self.token_to_group_mask[input_ids] - # query_group_suppression: True at [b, i, g] if token i's adapter suppresses group g - query_group_suppression = self.adapter_hiding_matrix[adapter_indices] - else: - token_group_membership = None - query_group_suppression = None - - # Compute hidden_count for position correction (SingleSwitch). - # SingleSwitch fires once: hidden_count is 0 before the control - # token and 1 at/after it, which is exactly (adapter_indices > 0). 
- if hidden_count is None: - hidden_count = (adapter_indices > 0).long() else: - batch_size, seq_length = inputs_embeds.shape[:2] adapter_indices = torch.zeros( - (batch_size, seq_length), - dtype=torch.long, - device=inputs_embeds.device + (batch_size, seq_length), dtype=torch.long, device=device, ) - token_group_membership = None - query_group_suppression = None + + # Embed once, on the (possibly-rewritten) input_ids. The decoder is + # token-exchange-agnostic — it just embeds whatever the switch + # passed through. + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(modified_input_ids) + inputs_embeds = inputs_embeds * self.embedding_multiplier # Expose adapter_indices for tests and debugging. self._last_adapter_indices = adapter_indices - # Position correction: adjust position_ids to close gaps from hidden tokens. - # Clamp to >= 0: pre-init tokens have no hidden tokens in their causal - # past, but the counting mechanism returns capacity-1 when all attention - # keys are masked, which would produce negative positions and OOB RoPE - # cache indices. 
- if hidden_count is not None: - adjusted_position_ids = torch.clamp(position_ids - hidden_count, min=0) - else: - adjusted_position_ids = position_ids - - # Position embeddings (only if RoPE is configured) position_embeddings = None if self.rotary_emb is not None: - position_embeddings = self.rotary_emb(inputs_embeds, position_ids=adjusted_position_ids) + position_embeddings = self.rotary_emb(inputs_embeds, position_ids=position_ids) # Decoder layers hidden_states = inputs_embeds @@ -388,8 +342,6 @@ def forward( cache_position=cache_position, position_embeddings=position_embeddings, adapter_indices=adapter_indices, - token_group_membership=token_group_membership, - query_group_suppression=query_group_suppression, **kwargs, ) diff --git a/src/granite_switch/hf/switch/single.py b/src/granite_switch/hf/switch/single.py index 7a26a29..6891fbd 100644 --- a/src/granite_switch/hf/switch/single.py +++ b/src/granite_switch/hf/switch/single.py @@ -14,7 +14,7 @@ import torch import torch.nn as nn -from typing import Optional +from typing import Optional, Tuple from transformers.cache_utils import Cache from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.models.granite.modeling_granite import eager_attention_forward @@ -57,11 +57,14 @@ def __init__( self.control_token_gain = control_token_gain self.config = config - # Use expanded_head_dim to align with decoder layers across both backends. - if config is not None and hasattr(config, 'expanded_head_dim') and getattr(config, 'num_adapters', 0) > 0: - self.head_dim = config.expanded_head_dim - elif config is not None: - self.head_dim = config.hidden_size // config.num_attention_heads + # Align with the decoder's native head_dim. (Under token exchange the + # KV cache no longer carries any expansion, so this is just the + # base-model projection_head_dim.) 
+ if config is not None: + self.head_dim = getattr( + config, "projection_head_dim", + config.hidden_size // config.num_attention_heads, + ) else: self.head_dim = switch_head_dim @@ -78,6 +81,28 @@ def __init__( # Switch is layer 0, decoder layers are 1 to num_hidden_layers self.layer_idx = layer_idx + # control_to_substitute_lut: [vocab_size_or_higher], -1 at non-control + # ids and the substitute id at each control slot. The switch performs + # the runtime token-exchange: it rewrites input_ids in-place so that + # control-token positions carry the substitute id by the time the + # decoder embeds them. The decoder is then oblivious — it just calls + # embed_tokens(input_ids) and gets the right result by construction. + if ( + config is not None + and getattr(config, "adapter_token_ids", None) is not None + and getattr(config, "adapter_substitute_token_ids", None) is not None + ): + ctrl_ids = config.adapter_token_ids + sub_ids = config.adapter_substitute_token_ids + max_ctrl_id = max(ctrl_ids) + lut_size = max(getattr(config, "vocab_size", 0), max_ctrl_id + 1) + lut = torch.full((lut_size,), -1, dtype=torch.long) + for ctrl_id, sub_id in zip(ctrl_ids, sub_ids): + lut[ctrl_id] = sub_id + self.register_buffer("control_to_substitute_lut", lut) + else: + self.control_to_substitute_lut = None + @property def num_cache_layers(self) -> int: """Number of cache slots used.""" @@ -90,13 +115,19 @@ def forward( attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Compute adapter indices using single-head attention mechanism. + Compute adapter indices and rewrite control tokens via the LUT. + + The switch performs both halves of token-exchange: + 1. Adapter selection: read input_ids, detect control tokens via + input_ids == adapter_token_ids, emit per-token adapter_indices. + 2. 
Token rewrite: replace each control token's id in input_ids + with its substitute id (from control_to_substitute_lut). - The switch uses the same head_dim as decoder layers to share the model's Cache object, - ensuring standard HuggingFace behavior where past_key_values is exposed and managed - by the caller. + Returning the rewritten input_ids means the decoder is oblivious to + the swap — it simply embeds whatever it's given. There's no + decoder-side LUT, no per-forward scatter, no clone-guard. Args: input_ids: Input token IDs [batch, seq_len] @@ -110,7 +141,10 @@ def forward( cache_position: Position indices for caching [seq_len] Returns: - adapter_indices: [batch, seq_len] where 0 = base, 1+ = adapters + (adapter_indices, modified_input_ids): + adapter_indices: [batch, seq_len] where 0 = base, 1+ = adapters. + modified_input_ids: [batch, seq_len] with each control-token + id replaced by its substitute id. """ bsz, q_len = input_ids.shape device = input_ids.device @@ -195,4 +229,20 @@ def forward( f"adapter_indices shape {adapter_indices.shape} must match input_ids shape {input_ids.shape}" ) - return adapter_indices + # Token-exchange rewrite: replace each control token's id with its + # substitute id via the LUT. Done here (rather than in the decoder) + # so the decoder sees a clean, unified input_ids and never has to + # know about substitutes. Skipped only when the LUT was not built + # (no substitute ids configured — e.g. a non-token-exchange test + # fixture). Kept symmetric with the vLLM switch, which forbids the + # `tensor.any()` short-circuit under @support_torch_compile. 
+ if self.control_to_substitute_lut is not None: + sub_id_per_pos = self.control_to_substitute_lut[input_ids] + is_control = sub_id_per_pos >= 0 + modified_input_ids = torch.where( + is_control, sub_id_per_pos, input_ids + ) + else: + modified_input_ids = input_ids + + return adapter_indices, modified_input_ids diff --git a/src/granite_switch/vllm/__init__.py b/src/granite_switch/vllm/__init__.py index ba52afb..eb6401b 100644 --- a/src/granite_switch/vllm/__init__.py +++ b/src/granite_switch/vllm/__init__.py @@ -52,10 +52,12 @@ def register(): except Exception: pass - # Register custom ModelArchConfigConvertor so vLLM sees the correct - # KV cache head size. When adapters use control_dims, the decoder - # attention stores expanded vectors (projection_head_dim + control_dims) - # in the KV cache. + # Register custom ModelArchConfigConvertor so vLLM sees: + # 1. The correct decoder layer count (excluding the switch's KV-cache + # placeholder slot). + # 2. The native KV cache head size (projection_head_dim). Token + # exchange does not expand the head dim, so this is just the base + # model's head_dim. 
try: from vllm.transformers_utils.model_arch_config_convertor import ( MODEL_ARCH_CONFIG_CONVERTORS, @@ -76,15 +78,7 @@ def get_num_hidden_layers(self) -> int: def get_head_size(self) -> int: cfg = self.hf_text_config - if hasattr(cfg, 'expanded_head_dim'): - return cfg.expanded_head_dim - # Fallback for configs without the property - base = super().get_head_size() - num_adapters = getattr(cfg, "num_adapters", 0) - control_dims = getattr(cfg, "control_dims", 32) - if num_adapters > 0 and control_dims > 0: - return base + control_dims - return base + return getattr(cfg, "projection_head_dim", super().get_head_size()) MODEL_ARCH_CONFIG_CONVERTORS["granite_switch"] = ( _GraniteSwitchArchConfigConvertor diff --git a/src/granite_switch/vllm/core/decoder.py b/src/granite_switch/vllm/core/decoder.py index 66090d0..565d1cc 100644 --- a/src/granite_switch/vllm/core/decoder.py +++ b/src/granite_switch/vllm/core/decoder.py @@ -78,12 +78,6 @@ def __init__( self.kv_size = self.num_kv_heads * self.head_dim self.scaling = config.attention_multiplier - # Control dimension expansion: expand only when adapters present AND - # control_dims > 0. control_dims=0 means native mode (no KV hiding). - self.expand_control_dims = num_adapters > 0 and config.control_dims > 0 - self.control_dims = config.control_dims - self.expanded_head_dim = self.head_dim + self.control_dims - # QKV projection - conditionally add LoRA based on config base_qkv_proj = QKVParallelLinear( self.hidden_size, @@ -143,11 +137,10 @@ def __init__( else: self.rotary_emb = None - # Attention layer — use expanded head dim only when expansion is active - self.attn_head_dim = self.expanded_head_dim if self.expand_control_dims else self.head_dim + # Attention layer — head_dim is the native projection_head_dim. 
self.attn = Attention( self.num_heads, - self.attn_head_dim, + self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, @@ -155,81 +148,12 @@ def __init__( prefix=f"{prefix}.attn", ) - def _expand_with_control_dimensions( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - token_group_membership: Optional[torch.Tensor], - query_group_suppression: Optional[torch.Tensor], - ) -> tuple: - """Expand Q, K, V with control dimensions for group-based KV cache hiding. - - Always called when num_adapters > 0 (static shape decision). - Each hiding group g uses one control dimension: - - K-side: finfo(dtype).min for tokens that are members of group g - - Q-side: 1.0 for queries whose adapter suppresses group g, - except for tokens that are themselves in group g - - When both tensors are None, all control dims are zero (no masking effect). - """ - num_tokens = q.size(0) - device = q.device - dtype = q.dtype - - q = q.view(num_tokens, self.num_heads, self.head_dim) - k = k.view(num_tokens, self.num_kv_heads, self.head_dim) - v = v.view(num_tokens, self.num_kv_heads, self.head_dim) - - q_control = torch.zeros(num_tokens, self.num_heads, self.control_dims, device=device, dtype=dtype) - k_control = torch.zeros(num_tokens, self.num_kv_heads, self.control_dims, device=device, dtype=dtype) - v_control = torch.zeros(num_tokens, self.num_kv_heads, self.control_dims, device=device, dtype=dtype) - - # K-side: brand each group-member token's key with finfo.min in its group's - # control dim so that suppressing queries score it as −∞. 
- # token_group_membership: [num_tokens, num_groups] — True if token is in group g - if token_group_membership is not None: - num_groups = token_group_membership.shape[-1] - hiding_constant = torch.finfo(dtype).min - k_control[:, :, :num_groups] = ( - token_group_membership.unsqueeze(1) - .expand(-1, self.num_kv_heads, -1) - .to(dtype) * hiding_constant - ) - - # Q-side: set control dim g to 1.0 for queries whose adapter suppresses group g. - # query_group_suppression: [num_tokens, num_groups] — True if this token's - # adapter suppresses group g. - # Tokens that are themselves in group g are excluded: when the control token - # sits at position 0 it has no other causal key to attend to, so suppressing - # its own key yields softmax([−∞]) = NaN. - if query_group_suppression is not None: - num_groups = query_group_suppression.shape[-1] - q_hide = query_group_suppression.to(dtype) - if token_group_membership is not None: - q_hide = q_hide * (1 - token_group_membership.to(dtype)) - q_control[:, :, :num_groups] = ( - q_hide.unsqueeze(1) - .expand(-1, self.num_heads, -1) - ) - - q = torch.cat([q, q_control], dim=-1) - k = torch.cat([k, k_control], dim=-1) - v = torch.cat([v, v_control], dim=-1) - - q = q.view(num_tokens, self.num_heads * self.expanded_head_dim) - k = k.view(num_tokens, self.num_kv_heads * self.expanded_head_dim) - v = v.view(num_tokens, self.num_kv_heads * self.expanded_head_dim) - - return q, k, v - def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: - # SwitchedLoRALinear reads LoRA metadata from shared LoRAContext; - # hiding group masks for control dims also come from the context. + # SwitchedLoRALinear reads LoRA metadata from the shared LoRAContext. 
qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -242,20 +166,7 @@ def forward( if self.rotary_emb is not None: q, k = self.rotary_emb(positions, q, k) - if self.expand_control_dims: - token_group_membership = self._lora_ctx.token_group_membership if self._lora_ctx is not None else None - query_group_suppression = self._lora_ctx.query_group_suppression if self._lora_ctx is not None else None - q, k, v = self._expand_with_control_dimensions( - q, k, v, token_group_membership, query_group_suppression, - ) - attn_output = self.attn(q, k, v) - - if self.expand_control_dims: - attn_output = attn_output.view(-1, self.num_heads, self.expanded_head_dim)[ - ..., :self.head_dim - ].reshape(-1, self.num_heads * self.head_dim) - output, _ = self.o_proj(attn_output) return output diff --git a/src/granite_switch/vllm/granite_switch_model.py b/src/granite_switch/vllm/granite_switch_model.py index b94fb61..3f90278 100644 --- a/src/granite_switch/vllm/granite_switch_model.py +++ b/src/granite_switch/vllm/granite_switch_model.py @@ -81,17 +81,10 @@ class GraniteSwitchModel(nn.Module): 3. Base transformer layers with LoRA 4. LM head - The switch detects special tokens and selects the appropriate adapter. - Adapter indices are passed as arguments to LoRA layers. - - To mitigate the contribution of control tokens in the base model and adapter computations: - - Each layer's k and v values are augmented with a control dimension set to - k=-inf for control tokens and k=0 otherwise (v=0 throughout), prior to attention - calculation. After softmax attention is computed, the value is reduced to its - original dimension. 
- - The logits for control tokens are set to -inf in compute_logits() to prevent - the sampler from generating control tokens - - Position correction via hidden_count closes RoPE gaps from KV-hidden control tokens + The switch detects special tokens, selects the appropriate adapter, and + rewrites each control token's id to its substitute id (token exchange). + The decoder embeds the rewritten ids and is otherwise oblivious to the + substitution. Adapter indices are passed as arguments to LoRA layers. """ def __init__( @@ -155,6 +148,11 @@ def __init__( torch.zeros(num_adapters, dtype=torch.long), ) + # Token-exchange LUT lives on the switch module + # (see vllm/switch/single.py); the switch rewrites input_ids + # in-place during its forward pass, so this model class no + # longer needs a decoder-side substitute table. + # Initialize compile-friendly LoRA metadata handler # This replaces vLLM's LoRAKernelMeta with a torch.compile-compatible version # that avoids data-dependent branching @@ -164,38 +162,10 @@ def __init__( dtype=torch.bfloat16, ) - # --- Hiding group buffers --- - num_groups = config.num_hiding_groups - if num_groups > 0: - group_token_ids = config.get_hiding_group_token_ids() - all_known_ids = [tid for tids in group_token_ids.values() for tid in tids] - if config.adapter_token_ids: - all_known_ids.extend(config.adapter_token_ids) - max_tid = max(all_known_ids) if all_known_ids else -1 - table_size = max(config.vocab_size, max_tid + 1) - token_to_group_mask = torch.zeros( - table_size, num_groups, dtype=torch.bool - ) - for g, tids in group_token_ids.items(): - for tid in tids: - token_to_group_mask[tid, g] = True - self.register_buffer("token_to_group_mask", token_to_group_mask) - - policy_matrix = config.get_adapter_hiding_policy_matrix() - self.register_buffer( - "adapter_hiding_matrix", - torch.tensor(policy_matrix, dtype=torch.bool), - ) - else: - self.token_to_group_mask = None - self.adapter_hiding_matrix = None - else: self.switch = None 
self.adapter_token_ids = None self.lora_meta = None - self.token_to_group_mask = None - self.adapter_hiding_matrix = None # 3. Base transformer layers with custom LoRA # @@ -285,19 +255,6 @@ def make_empty_intermediate_tensors( ), } - num_groups = self.config.num_hiding_groups - if num_groups > 0: - tensors["token_group_membership"] = torch.zeros( - (batch_size, num_groups), - dtype=torch.bool, - device=device, - ) - tensors["query_group_suppression"] = torch.zeros( - (batch_size, num_groups), - dtype=torch.bool, - device=device, - ) - return IntermediateTensors(tensors) def forward( @@ -327,63 +284,33 @@ def forward( # COMPILED: Switch + Metadata preparation # ═══════════════════════════════════════════════════════════════ - # Step 1: Switch - determine adapter for each token via switch - # Switch only runs on first rank - hidden_count = None + # Step 1: Switch — determine adapter for each token and rewrite + # control tokens via token-exchange. Only runs on first rank. if get_pp_group().is_first_rank: if self.switch is not None: - adapter_indices = self.switch( + adapter_indices, modified_input_ids = self.switch( input_ids=input_ids, adapter_token_ids=self.adapter_token_ids, ) else: - # No switch - all tokens use base model (adapter_id = 0) + # No switch — all tokens use base model (adapter_id = 0). num_tokens = input_ids.shape[0] adapter_indices = torch.zeros( - num_tokens, - dtype=torch.long, - device=input_ids.device + num_tokens, dtype=torch.long, device=input_ids.device, ) + modified_input_ids = input_ids - # Step 2: Compute group-based hiding masks. 
- if self.token_to_group_mask is not None: - # token_group_membership: True at [i, g] if token i is a member of group g - token_group_membership = self.token_to_group_mask[input_ids] # [num_tokens, num_groups] - # query_group_suppression: True at [i, g] if token i's adapter suppresses group g - query_group_suppression = self.adapter_hiding_matrix[adapter_indices] # [num_tokens, num_groups] - else: - token_group_membership = None - query_group_suppression = None - - # Compute hidden_count for position correction (SingleSwitch) - if hidden_count is None: - hidden_count = (adapter_indices > 0).long() - - # Position correction: adjust positions to close gaps from hidden tokens. - # Clamp to >= 0: pre-init tokens have no hidden tokens in their causal - # past, but the counting mechanism returns capacity-1 when all attention - # keys are masked, which would produce negative positions and OOB RoPE - # cache indices. - if hidden_count is not None: - positions = torch.clamp(positions - hidden_count, min=0) - - # Step 3: Prepare LoRA metadata ONCE for all decoder layers. + # Step 2: Prepare LoRA metadata ONCE for all decoder layers. # Stored on the shared LoRAContext — every SwitchedLoRALinear reads from it. if self.lora_meta is not None and self.lora_ctx is not None: # Convert to Punica convention: 0=base -> -1=base punica_indices = adapter_indices - 1 self.lora_meta.prepare_and_store(punica_indices, self.lora_ctx) - self.lora_ctx.token_group_membership = token_group_membership - self.lora_ctx.query_group_suppression = query_group_suppression - # Store metadata in intermediate_tensors for pipeline parallelism + # Store metadata in intermediate_tensors for pipeline parallelism. 
if intermediate_tensors is None: intermediate_tensors = IntermediateTensors({}) intermediate_tensors["adapter_indices"] = adapter_indices - if token_group_membership is not None: - intermediate_tensors["token_group_membership"] = token_group_membership - if query_group_suppression is not None: - intermediate_tensors["query_group_suppression"] = query_group_suppression else: # Subsequent ranks: recompute fixed-size LoRA metadata from # token-leading adapter_indices received through PP. @@ -392,18 +319,6 @@ def forward( if self.lora_ctx is not None: punica_indices = adapter_indices - 1 self.lora_meta.prepare_and_store(punica_indices, self.lora_ctx) - self.lora_ctx.token_group_membership = ( - _get_intermediate_tensor( - intermediate_tensors, "token_group_membership", - ) - ) - self.lora_ctx.query_group_suppression = ( - _get_intermediate_tensor( - intermediate_tensors, "query_group_suppression", - ) - ) - hidden_count = (adapter_indices > 0).long() - positions = torch.clamp(positions - hidden_count, min=0) else: # Fallback: no metadata available (should not happen in normal operation) num_tokens = input_ids.shape[0] if input_ids is not None else 0 @@ -420,7 +335,11 @@ def forward( if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.get_input_embeddings(input_ids) + # Embed the (possibly-rewritten) input_ids the switch returned. + # The switch already performed the token-exchange rewrite, so + # this single lookup produces the correct embeddings for both + # control positions (substitute id) and content positions. 
+ hidden_states = self.get_input_embeddings(modified_input_ids) hidden_states *= self.config.embedding_multiplier residual = None diff --git a/src/granite_switch/vllm/switch/single.py b/src/granite_switch/vllm/switch/single.py index 1171466..6a95c0d 100644 --- a/src/granite_switch/vllm/switch/single.py +++ b/src/granite_switch/vllm/switch/single.py @@ -2,7 +2,7 @@ """SingleSwitch using replicated one-hot attention for adapter selection. This switch uses the backbone's full head geometry (num_attention_heads, -num_key_value_heads, expanded_head_dim, attention_multiplier) so that all +num_key_value_heads, projection_head_dim, attention_multiplier) so that all attention layers share one FlashAttentionMetadataBuilder configuration. The same one-hot dim-0 pattern is replicated identically across every head: @@ -18,7 +18,7 @@ import torch import torch.nn as nn -from typing import Optional +from typing import Optional, Tuple from vllm.model_executor.layers.attention.attention import Attention from vllm.config import VllmConfig @@ -65,7 +65,7 @@ def __init__( self.num_kv_heads = total_kv // tp_size else: self.num_kv_heads = max(1, total_kv // tp_size) - self.head_dim = config.expanded_head_dim + self.head_dim = config.projection_head_dim self.scaling = config.attention_multiplier self.effective_gain = control_token_gain / self.scaling else: @@ -92,6 +92,29 @@ def __init__( prefix="switch.layers.0", ) + # control_to_substitute_lut: [vocab_size_or_higher], -1 at non-control + # ids and the substitute id at each control slot. The switch performs + # the runtime token-exchange: it rewrites input_ids in-place so that + # control-token positions carry the substitute id by the time the + # decoder embeds them. The decoder is then oblivious — it just calls + # get_input_embeddings(input_ids) and gets the right result by + # construction. 
+ if ( + config is not None + and getattr(config, "adapter_token_ids", None) is not None + and getattr(config, "adapter_substitute_token_ids", None) is not None + ): + ctrl_ids = config.adapter_token_ids + sub_ids = config.adapter_substitute_token_ids + max_ctrl_id = max(ctrl_ids) + lut_size = max(getattr(config, "vocab_size", 0), max_ctrl_id + 1) + lut = torch.full((lut_size,), -1, dtype=torch.long) + for ctrl_id, sub_id in zip(ctrl_ids, sub_ids): + lut[ctrl_id] = sub_id + self.register_buffer("control_to_substitute_lut", lut) + else: + self.control_to_substitute_lut = None + @property def num_cache_layers(self) -> int: """Number of KV cache slots used by this switch (1 Attention layer).""" @@ -101,9 +124,13 @@ def forward( self, input_ids: torch.Tensor, adapter_token_ids: torch.Tensor, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Compute adapter indices using replicated one-hot attention. + Compute adapter indices and rewrite control tokens via the LUT. + + See the HF SingleSwitch docstring for the full rationale. In short: + the switch performs both adapter selection and token-exchange + rewrite, so the decoder is agnostic to substitution. Args: input_ids: Input token IDs [total_tokens] - flattened by vLLM scheduler @@ -114,7 +141,10 @@ def forward( to transition back to base mid-sequence. Returns: - adapter_indices: [total_tokens] where 0 = base, 1+ = adapters + (adapter_indices, modified_input_ids): + adapter_indices: [total_tokens] where 0 = base, 1+ = adapters. + modified_input_ids: [total_tokens] with each control-token id + replaced by its substitute id. 
""" total_tokens = input_ids.shape[0] device = input_ids.device @@ -162,8 +192,23 @@ def forward( # Round to get integer adapter indices adapter_indices = torch.round(attn_output).long() - + # Clamp to valid range [0, num_adapters] adapter_indices = torch.clamp(adapter_indices, 0, self.num_adapters) - return adapter_indices + # Token-exchange rewrite: see the HF switch for the rationale. + # Skipped only when no LUT was built (no substitute ids configured). + # No data-dependent gate here — the surrounding decoder is wrapped in + # @support_torch_compile, which forbids `tensor.any()` branching. + # `torch.where` runs every step; the cost is one indexed gather and + # one elementwise select on the flat input. + if self.control_to_substitute_lut is not None: + sub_id_per_pos = self.control_to_substitute_lut[input_ids] + is_control = sub_id_per_pos >= 0 + modified_input_ids = torch.where( + is_control, sub_id_per_pos, input_ids + ) + else: + modified_input_ids = input_ids + + return adapter_indices, modified_input_ids diff --git a/tests/composer/test_built_in_adapters.py b/tests/composer/test_built_in_adapters.py deleted file mode 100644 index 783fe5e..0000000 --- a/tests/composer/test_built_in_adapters.py +++ /dev/null @@ -1,263 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Tests for built-in adapter support (Mode A native / Mode B mixed). 
- -Tests config-level behavior: -- Mode A: control_dims=0, no hiding -- Mode B: mixed built-in + external → control_dims>0, third_party = external only -- SSM rejection for mixed mode -- Model construction with control_dims=0 -""" - -import pytest -import torch - -from granite_switch.config import GraniteSwitchConfig -from granite_switch.hf import GraniteSwitchForCausalLM - - -# ── Fixtures ────────────────────────────────────────────────────────── - - -@pytest.fixture -def mode_a_config(): - """Mode A (native): built-in adapters only, control_dims=0.""" - return GraniteSwitchConfig( - vocab_size=300, - hidden_size=64, - intermediate_size=128, - num_hidden_layers=3, # 1 switch + 2 decoder - num_attention_heads=4, - num_key_value_heads=4, - num_adapters=2, - adapter_token_ids=[250, 251], - adapter_names=["router", "planner"], - max_lora_rank=4, - adapter_ranks=[4, 4], - switch_head_dim=16, - control_dims=0, - # No hiding groups - hiding_groups=None, - hiding_policy=None, - adapter_third_party=None, - ) - - -@pytest.fixture -def mode_b_config(): - """Mode B (mixed): 1 external + 1 built-in adapter, control_dims=8.""" - return GraniteSwitchConfig( - vocab_size=300, - hidden_size=64, - intermediate_size=128, - num_hidden_layers=3, # 1 switch + 2 decoder - num_attention_heads=4, - num_key_value_heads=4, - num_adapters=2, - adapter_token_ids=[250, 251], - adapter_names=["external_rag", "router"], - hiding_groups={"all_controls": ["external_rag", "router"]}, - hiding_policy={ - "base": ["all_controls"], - "external_rag": ["all_controls"], - "router": ["all_controls"], - }, - adapter_third_party=["external_rag"], # Only external is third-party - max_lora_rank=4, - adapter_ranks=[4, 4], - switch_head_dim=16, - control_dims=8, - ) - - -# ── Mode A Config Tests ────────────────────────────────────────────── - - -class TestModeAConfig: - """Config-level checks for Mode A (native, control_dims=0).""" - - def test_control_dims_zero_allowed(self, mode_a_config): - 
"""control_dims=0 should be accepted by config validation.""" - assert mode_a_config.control_dims == 0 - - def test_no_hiding_groups(self, mode_a_config): - """Mode A has no hiding groups.""" - assert mode_a_config.num_hiding_groups == 0 - assert mode_a_config.hiding_group_names == [] - assert mode_a_config.get_hiding_group_token_ids() == {} - - def test_no_third_party(self, mode_a_config): - """Mode A has no third-party adapters.""" - assert mode_a_config.adapter_third_party is None - mask = mode_a_config.get_third_party_adapter_mask() - assert all(v is False for v in mask) - - def test_adapters_present(self, mode_a_config): - """Mode A still has adapters with LoRA.""" - assert mode_a_config.num_adapters == 2 - assert mode_a_config.adapter_ranks == [4, 4] - - -# ── Mode A Model Tests ─────────────────────────────────────────────── - - -class TestModeAModel: - """Model construction and forward pass with control_dims=0.""" - - def test_model_creates_successfully(self, mode_a_config): - """GraniteSwitchForCausalLM should construct with control_dims=0.""" - model = GraniteSwitchForCausalLM(mode_a_config) - assert model is not None - assert model.config.control_dims == 0 - - def test_attention_no_expansion(self, mode_a_config): - """Decoder attention layers should NOT expand control dims.""" - model = GraniteSwitchForCausalLM(mode_a_config) - for layer in model.model.layers: - attn = layer.self_attn - assert not attn.expand_control_dims, ( - "expand_control_dims should be False when control_dims=0" - ) - assert attn.expanded_head_dim == attn.head_dim, ( - "expanded_head_dim should equal head_dim when control_dims=0" - ) - - def test_forward_pass(self, mode_a_config): - """Forward pass should work with control_dims=0.""" - model = GraniteSwitchForCausalLM(mode_a_config).eval() - model.model.adapter_token_ids.data = torch.tensor( - mode_a_config.adapter_token_ids, dtype=torch.long - ) - - input_ids = torch.tensor([[10, 250, 20, 30, 40]]) - with torch.no_grad(): - output 
= model(input_ids=input_ids) - assert output.logits.shape == (1, 5, mode_a_config.vocab_size) - - def test_no_hiding_buffers(self, mode_a_config): - """Model should have no hiding-related buffers when control_dims=0.""" - model = GraniteSwitchForCausalLM(mode_a_config) - assert model.model.token_to_group_mask is None - assert model.model.adapter_hiding_matrix is None - - def test_lora_shapes_correct(self, mode_a_config): - """LoRA weight shapes should reflect num_adapters.""" - model = GraniteSwitchForCausalLM(mode_a_config) - layer = model.model.layers[0] # First decoder layer - attn = layer.self_attn - # QKV has LoRA with 2 adapters, rank 4 - if hasattr(attn.qkv_proj, "lora_A_slices"): - for lora_a in attn.qkv_proj.lora_A_slices: - assert lora_a.shape[0] == 2, "num_adapters should be 2" - assert lora_a.shape[2] == 4, "max_lora_rank should be 4" - - def test_adapter_routing_works(self, mode_a_config): - """Adapter routing should still work with control_dims=0.""" - model = GraniteSwitchForCausalLM(mode_a_config).eval() - model.model.adapter_token_ids.data = torch.tensor( - mode_a_config.adapter_token_ids, dtype=torch.long - ) - - # Set non-zero lora_B to make adapter effect visible - with torch.no_grad(): - for layer in model.model.layers: - if hasattr(layer.self_attn.o_proj, "lora_B"): - layer.self_attn.o_proj.lora_B.data = ( - torch.randn_like(layer.self_attn.o_proj.lora_B) * 0.1 - ) - - # All base tokens - base_ids = torch.tensor([[10, 20, 30, 40, 50]]) - # With adapter control token - adapter_ids = torch.tensor([[250, 20, 30, 40, 50]]) - - with torch.no_grad(): - out_base = model(input_ids=base_ids) - out_adapter = model(input_ids=adapter_ids) - - # Logits should differ when adapter is active - # (tokens after control token see different LoRA) - diff = (out_base.logits[0, -1] - out_adapter.logits[0, -1]).abs().max() - assert diff > 1e-6, "Adapter should produce different logits than base" - - def test_control_token_logits_finite(self, mode_a_config): - 
"""Control token logits should be finite.""" - model = GraniteSwitchForCausalLM(mode_a_config).eval() - model.model.adapter_token_ids.data = torch.tensor( - mode_a_config.adapter_token_ids, dtype=torch.long - ) - - input_ids = torch.tensor([[250, 20, 30]]) - with torch.no_grad(): - output = model(input_ids=input_ids) - - control_token_logits = output.logits[:, :, mode_a_config.adapter_token_ids] - assert torch.isfinite(control_token_logits).all(), ( - "Control token logits should be finite" - ) - - -# ── Mode B Config Tests ────────────────────────────────────────────── - - -class TestModeBConfig: - """Config-level checks for Mode B (mixed, control_dims>0).""" - - def test_control_dims_positive(self, mode_b_config): - """Mode B should have control_dims > 0.""" - assert mode_b_config.control_dims == 8 - - def test_only_external_is_third_party(self, mode_b_config): - """Only external adapter should be third-party.""" - assert mode_b_config.adapter_third_party == ["external_rag"] - mask = mode_b_config.get_third_party_adapter_mask() - # [base=False, external_rag=True, router=False] - assert mask == [False, True, False] - - def test_hiding_groups_present(self, mode_b_config): - """Mode B should have hiding groups.""" - assert mode_b_config.num_hiding_groups == 1 - - def test_third_party_mask(self, mode_b_config): - """Third-party mask marks only external adapter.""" - mask = mode_b_config.get_third_party_adapter_mask() - assert mask == [False, True, False] - - -# ── Negative Tests ──────────────────────────────────────────────────── - - -class TestNegative: - """Validation errors that should be raised.""" - - def test_control_dims_negative_rejected(self): - """control_dims < 0 should still be rejected.""" - with pytest.raises(ValueError, match="control_dims must be >= 0"): - GraniteSwitchConfig( - vocab_size=256, - hidden_size=64, - intermediate_size=128, - num_hidden_layers=2, - num_attention_heads=4, - num_key_value_heads=4, - num_adapters=0, - control_dims=-1, - ) - 
- def test_hiding_groups_require_control_dims(self): - """Hiding groups with control_dims=0 should be rejected.""" - with pytest.raises(ValueError, match="control_dims.*must be >= number of hiding groups"): - GraniteSwitchConfig( - vocab_size=300, - hidden_size=64, - intermediate_size=128, - num_hidden_layers=3, - num_attention_heads=4, - num_key_value_heads=4, - num_adapters=2, - adapter_token_ids=[250, 251], - adapter_names=["a", "b"], - hiding_groups={"all_controls": ["a", "b"]}, - max_lora_rank=4, - adapter_ranks=[4, 4], - control_dims=0, # Too few for 1 hiding group - ) diff --git a/tests/composer/test_chat_template.py b/tests/composer/test_chat_template.py index 52f600e..e363afd 100644 --- a/tests/composer/test_chat_template.py +++ b/tests/composer/test_chat_template.py @@ -49,7 +49,15 @@ def _render(tokenizer, **kwargs): class TestConfigureChatTemplate: def test_lora_prefix_path(self): - """LoRA: activation token emitted at the very start of the sequence.""" + """LoRA: activation token emitted at the very start of the sequence. + + The skip-once flag set by lora_prefix_insertion suppresses the very + next <|start_of_role|>, so the rendered output is + '<|ctx_rel|>user<|end_of_role|>...', not + '<|ctx_rel|><|start_of_role|>user<|end_of_role|>...'. This keeps the + runtime embedding-swap from producing two identical consecutive + embeddings (see tokenizer_setup.py lora_prefix_insertion comment). + """ tokenizer = _make_tokenizer() configure_chat_template(tokenizer, [("/path/a", "ctx_rel", "lora")]) @@ -59,14 +67,23 @@ def test_lora_prefix_path(self): add_generation_prompt=True, adapter_name="ctx_rel", ) - assert result.startswith("<|ctx_rel|>") + assert result.startswith("<|ctx_rel|>user<|end_of_role|>"), ( + f"expected <|ctx_rel|> followed by 'user<|end_of_role|>' " + f"(skip-once suppressed <|start_of_role|>), got {result[:80]!r}" + ) + # Exactly one <|start_of_role|> should survive: the assistant turn. 
+ assert result.count("<|start_of_role|>") == 1 def test_alora_pass1_pass2_path(self): - """ALoRA Pass 1+2: token inserted in last user message before invocation text. + """ALoRA Pass 1+2: token inserted in last user message, first char of + invocation text dropped. Pass 1 finds the user message containing '' and sets - ns.alora_target_idx. Pass 2 splits content.val on '' - and rejoins with the control token before the last occurrence. + ns.alora_target_idx. Pass 2 splits content.val on '' + and rejoins with the control token followed by the invocation text + MINUS its first character ('<' is dropped). The runtime swap + replaces the control token's embedding with '<'s embedding, so the + sequence tokenizes the same as '' with no duplicate. The fallback block does NOT fire (alora_target_idx >= 0). """ with patch(_PATCH_TARGET, return_value=""): @@ -81,9 +98,13 @@ def test_alora_pass1_pass2_path(self): add_generation_prompt=True, adapter_name="req_check", ) - # Token immediately precedes the invocation text inside the user turn + # Token immediately precedes the invocation text (minus first char) + # inside the user turn: "<|req_check|>requirements>" (no '<'). user_turn_header = "<|start_of_role|>user<|end_of_role|>" - assert user_turn_header + "<|req_check|>" in result + assert user_turn_header + "<|req_check|>requirements>" in result + # And the literal "<|req_check|>" should NOT appear — + # the leading '<' must have been dropped. + assert "<|req_check|>" not in result # Fallback did not fire: token is not immediately before generation prompt gen_prompt = "<|start_of_role|>assistant<|end_of_role|>" last_gen_pos = result.rindex(gen_prompt) @@ -109,12 +130,22 @@ def test_alora_fallback_path(self): adapter_name="answerability", ) assert "<|answerability|>" in result - # Token appears immediately before the generation prompt + # Token appears immediately before what would have been the generation + # prompt's <|start_of_role|>. 
The skip-once flag set by alora_insertion + # suppresses that <|start_of_role|>, so the rendered output has + # "<|answerability|>assistant<|end_of_role|>" — no role marker between + # the control token and the role name. Prevents a duplicate-embedding + # OOD at position 1 after the runtime swap (see tokenizer_setup.py + # alora_insertion comment). token = "<|answerability|>" - gen_prompt = "<|start_of_role|>assistant<|end_of_role|>" token_pos = result.index(token) - gen_pos = result.index(gen_prompt, token_pos) - assert result[token_pos + len(token):gen_pos].strip() == "" + after = result[token_pos + len(token):] + assert after.startswith("assistant<|end_of_role|>"), ( + f"expected 'assistant<|end_of_role|>' immediately after " + f"{token!r}, got {after[:60]!r}" + ) + # Only one <|start_of_role|> should survive: the one before the user turn. + assert result.count("<|start_of_role|>") == 1 def test_alora_pass1_pass2_iterable_content(self): """ALoRA Pass 1+2: token inserted correctly when message content is a list of parts. @@ -145,14 +176,43 @@ def test_alora_pass1_pass2_iterable_content(self): add_generation_prompt=True, adapter_name="req_check", ) - # Token must appear immediately before invocation text inside the user turn - assert "<|req_check|>" in result + # Token appears before the invocation text, and the invocation + # text's first character ('<') has been dropped. 
+ assert "<|req_check|>requirements>" in result + assert "<|req_check|>" not in result assert result.index("<|req_check|>") > result.index("<|start_of_role|>user<|end_of_role|>") # Fallback must NOT also fire gen_prompt = "<|start_of_role|>assistant<|end_of_role|>" last_gen_pos = result.rindex(gen_prompt) assert result[last_gen_pos - len("<|req_check|>"):last_gen_pos] != "<|req_check|>" + def test_skip_once_is_single_shot(self): + """Skip-once flag consumes itself: only the first <|start_of_role|> + after a LoRA control token is suppressed; later role markers emit.""" + tokenizer = _make_tokenizer() + configure_chat_template(tokenizer, [("/path/a", "my_lora", "lora")]) + + # Two user turns so the template emits <|start_of_role|> three times: + # once per user turn + once for the generation prompt. Only the very + # first one should be suppressed. + result = _render( + tokenizer, + messages=[ + {"role": "user", "content": "first"}, + {"role": "assistant", "content": "reply"}, + {"role": "user", "content": "second"}, + ], + add_generation_prompt=True, + adapter_name="my_lora", + ) + assert result.startswith("<|my_lora|>user<|end_of_role|>"), ( + f"first <|start_of_role|> should be suppressed; got {result[:80]!r}" + ) + # Four role markers would be emitted normally (first user, assistant, + # second user, assistant-generation-prompt). Skip-once removes the + # first → exactly three survive. + assert result.count("<|start_of_role|>") == 3 + def test_no_adapter_no_tokens(self): """Without adapter_name the rendered output is identical to the original template.""" messages = [{"role": "user", "content": "Hello"}] @@ -169,6 +229,55 @@ def test_no_adapter_no_tokens(self): assert modified == original +class TestInvocationFirstCharDropProperty: + """Standalone property test on a real Granite tokenizer: dropping the first + character of an ALoRA invocation text yields the same tail-token sequence + as tokenizing the full invocation text and dropping its first token. 
This + is the BPE-level invariant the Pass-2 edit relies on — if a future + tokenizer change breaks it, the template-level drop would silently corrupt + the tail of the invocation. + """ + + _INVOCATIONS = [ + "", + "", + "", + "", + ] + + def _get_tokenizer(self): + from transformers import AutoTokenizer + try: + return AutoTokenizer.from_pretrained("ibm-granite/granite-4.1-3b") + except Exception as e: + import pytest + pytest.skip(f"could not fetch Granite tokenizer: {e}") + + def test_first_char_drop_equals_first_token_drop(self): + tok = self._get_tokenizer() + for invocation in self._INVOCATIONS: + full_ids = tok(invocation, add_special_tokens=False).input_ids + dropped_ids = tok(invocation[1:], add_special_tokens=False).input_ids + assert full_ids[1:] == dropped_ids, ( + f"invocation {invocation!r}: dropping first char of the " + f"string produced tokens {dropped_ids} but the tail of the " + f"full tokenization is {full_ids[1:]}" + ) + + def test_first_token_is_single_character(self): + """Sanity: the first token of each invocation must be exactly one + character (the leading '<'). Otherwise dropping invocation_text[1:] + in Jinja would drop the wrong number of characters.""" + tok = self._get_tokenizer() + for invocation in self._INVOCATIONS: + first_id = tok(invocation, add_special_tokens=False).input_ids[0] + first_str = tok.decode([first_id]) + assert first_str == invocation[0], ( + f"invocation {invocation!r}: first token decodes to " + f"{first_str!r}, expected {invocation[0]!r}" + ) + + class _FixtureTokenizer: """Tokenizer with a decode map for fixture adapter token IDs.""" @@ -210,16 +319,27 @@ def test_alora_fallback_from_adapter_config(self): add_generation_prompt=True, adapter_name="answerability", ) - # Fallback: token immediately before generation prompt + # Fallback: token immediately before generation prompt, with the + # generation-prompt <|start_of_role|> suppressed by the skip-once flag + # armed in alora_insertion. 
Output is "<|answerability|>assistant<|end_of_role|>". token = "<|answerability|>" - gen_prompt = "<|start_of_role|>assistant<|end_of_role|>" assert token in result token_pos = result.index(token) - gen_pos = result.index(gen_prompt, token_pos) - assert result[token_pos + len(token):gen_pos].strip() == "" + after = result[token_pos + len(token):] + assert after.startswith("assistant<|end_of_role|>"), ( + f"expected 'assistant<|end_of_role|>' immediately after " + f"{token!r}, got {after[:60]!r}" + ) + # Only the user-turn <|start_of_role|> should survive. + assert result.count("<|start_of_role|>") == 1 def test_alora_invocation_at_start_of_user_message(self): - """ALoRA: invocation text is the first thing in the user message.""" + """ALoRA: invocation text is the first thing in the user message. + + Pass 2 drops the first character of the invocation text after + inserting the control token, so "" becomes + "<|context_relevance|>context>" in the rendered output. + """ tokenizer = self._make_tokenizer({(27,): ""}) configure_chat_template(tokenizer, [ (self._CONTEXT_REL, "context_relevance", "alora"), @@ -231,16 +351,21 @@ def test_alora_invocation_at_start_of_user_message(self): add_generation_prompt=True, adapter_name="context_relevance", ) - # Token injected right after the user role header, before + # Token injected right after the user role header; the '<' of + # the invocation text is dropped. 
user_header = "<|start_of_role|>user<|end_of_role|>" - assert user_header + "<|context_relevance|>" in result + assert user_header + "<|context_relevance|>context>" in result + assert "<|context_relevance|>" not in result # Fallback must NOT fire gen_prompt = "<|start_of_role|>assistant<|end_of_role|>" last_gen_pos = result.rindex(gen_prompt) assert result[last_gen_pos - len("<|context_relevance|>"):last_gen_pos] != "<|context_relevance|>" def test_alora_invocation_mid_user_message(self): - """ALoRA: invocation text appears in the middle of the user message.""" + """ALoRA: invocation text appears in the middle of the user message. + + Same first-character drop as the start-of-message case. + """ tokenizer = self._make_tokenizer({(27,): ""}) configure_chat_template(tokenizer, [ (self._CONTEXT_REL, "context_relevance", "alora"), @@ -252,8 +377,9 @@ def test_alora_invocation_mid_user_message(self): add_generation_prompt=True, adapter_name="context_relevance", ) - # Token injected mid-message, before - assert "Please review: <|context_relevance|>" in result + # Token injected mid-message, invocation text's '<' dropped. + assert "Please review: <|context_relevance|>context>" in result + assert "<|context_relevance|>" not in result user_header = "<|start_of_role|>user<|end_of_role|>" assert result.index("<|context_relevance|>") > result.index(user_header) # Fallback must NOT fire @@ -265,7 +391,8 @@ def test_alora_multiple_occurrences_targets_last(self): """ALoRA: invocation text appears twice — token injected before the last occurrence. rsplit(..., 1) splits on the last occurrence, so the control token must - land before the second , not the first. + land before the second , not the first. First occurrence + remains intact with its '<'; only the second has its '<' dropped. 
""" tokenizer = self._make_tokenizer({(27,): ""}) configure_chat_template(tokenizer, [ @@ -281,8 +408,9 @@ def test_alora_multiple_occurrences_targets_last(self): add_generation_prompt=True, adapter_name="context_relevance", ) - # The first must NOT have the control token before it - assert "first batch Also check <|context_relevance|>second batch" in result + # First untouched; second one has the control token + # inserted with its '<' dropped. + assert "first batch Also check <|context_relevance|>context>second batch" in result # Only one control token in the entire output assert result.count("<|context_relevance|>") == 1 @@ -300,6 +428,11 @@ def test_lora_prefix_from_adapter_config(self): adapter_name="summarization", ) assert result.startswith("<|summarization|>") + # Skip-once suppresses the user-turn <|start_of_role|>: output is + # "<|summarization|>user<|end_of_role|>...", not + # "<|summarization|><|start_of_role|>user...". Keeps the adapter + # substitute token from duplicating at runtime. + assert result.startswith("<|summarization|>user<|end_of_role|>") def test_mixed_adapters_from_adapter_config(self): """All three adapter types composed together, each activated independently.""" @@ -315,23 +448,24 @@ def test_mixed_adapters_from_adapter_config(self): messages = [{"role": "user", "content": "docs"}] - # Activate context_relevance → Pass 1+2 + # Activate context_relevance → Pass 1+2 (drops first char of invocation). result = _render( tokenizer, messages=messages, add_generation_prompt=True, adapter_name="context_relevance", ) - assert "<|context_relevance|>" in result + assert "<|context_relevance|>context>" in result + assert "<|context_relevance|>" not in result - # Activate answerability → fallback + # Activate answerability → fallback (skip-once suppresses the + # generation-prompt <|start_of_role|>). 
result = _render( tokenizer, messages=messages, add_generation_prompt=True, adapter_name="answerability", ) token = "<|answerability|>" - gen_prompt = "<|start_of_role|>assistant<|end_of_role|>" token_pos = result.index(token) - gen_pos = result.index(gen_prompt, token_pos) - assert result[token_pos + len(token):gen_pos].strip() == "" + after = result[token_pos + len(token):] + assert after.startswith("assistant<|end_of_role|>") # Activate summarization → prefix result = _render( diff --git a/tests/composer/test_lora_substitute_probe.py b/tests/composer/test_lora_substitute_probe.py new file mode 100644 index 0000000..f5c90f1 --- /dev/null +++ b/tests/composer/test_lora_substitute_probe.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for _probe_lora_substitute_token_id. + +The probe derives the LoRA substitute token from the tokenizer's chat +template rather than hardcoding a Granite-4.x-specific token string. These +tests verify: + +1. On real Granite tokenizers, the probe returns <|start_of_role|> (id + 100264) — the token the LoRA prefix insertion places immediately after + the control token in the rendered prompt. +2. On a synthetic tokenizer with a different template, the probe returns + whatever that template emits first for a user turn. +3. The probe raises a clear error when the template is missing, fails to + render, or emits an unknown token. +""" + +from types import SimpleNamespace + +import pytest + +from granite_switch.composer.compose_granite_switch import ( + _probe_lora_substitute_token_id, +) + + +class TestOnRealGraniteTokenizer: + """Exercise the probe on actual Granite tokenizers. 
Network-dependent; + skips cleanly if the model can't be fetched.""" + + def _tok(self, name): + from transformers import AutoTokenizer + try: + return AutoTokenizer.from_pretrained(name) + except Exception as e: + pytest.skip(f"could not fetch tokenizer {name!r}: {e}") + + def test_granite_4_1_3b(self): + tok = self._tok("ibm-granite/granite-4.1-3b") + sub_id = _probe_lora_substitute_token_id(tok) + assert sub_id == 100264 + assert tok.convert_ids_to_tokens([sub_id])[0] == "<|start_of_role|>" + + def test_granite_4_0_micro(self): + tok = self._tok("ibm-granite/granite-4.0-micro") + sub_id = _probe_lora_substitute_token_id(tok) + assert sub_id == 100264 + assert tok.convert_ids_to_tokens([sub_id])[0] == "<|start_of_role|>" + + +class TestOnSyntheticTokenizer: + """Verify the probe is generic — it returns whatever the template emits, + not a Granite-specific hardcoded token.""" + + def test_custom_template_gives_custom_token(self): + """A template whose first emission is a different marker produces + the id of that different marker.""" + + class _FakeTokenizer: + chat_template = "" + unk_token_id = 0 + + def apply_chat_template( + self, messages, tokenize, add_generation_prompt + ): + assert tokenize is False + assert add_generation_prompt is False + return "hello" + + def __call__(self, text, **kwargs): + # Pretend tokenizes as [42], "hello" as [7, 8, 9, 10, 11]. 
+ assert kwargs.get("add_special_tokens") is False + assert text == "hello" + return SimpleNamespace(input_ids=[42, 7, 8, 9, 10, 11]) + + assert _probe_lora_substitute_token_id(_FakeTokenizer()) == 42 + + +class TestErrorPaths: + + def _minimal_tokenizer_without_template(self): + class _T: + chat_template = None + unk_token_id = 0 + def apply_chat_template(self, *a, **kw): + raise AssertionError("should not be called") + def __call__(self, text, **kw): + raise AssertionError("should not be called") + return _T() + + def _tokenizer_whose_template_fails(self): + class _T: + chat_template = "" + unk_token_id = 0 + def apply_chat_template(self, *a, **kw): + raise RuntimeError("template exploded") + def __call__(self, text, **kw): + raise AssertionError("unreachable") + return _T() + + def _tokenizer_emitting_unk(self): + class _T: + chat_template = "" + unk_token_id = 777 + def apply_chat_template(self, messages, tokenize, add_generation_prompt): + return "mystery" + def __call__(self, text, **kw): + return SimpleNamespace(input_ids=[777]) + return _T() + + def _tokenizer_emitting_empty(self): + class _T: + chat_template = "" + unk_token_id = 0 + def apply_chat_template(self, messages, tokenize, add_generation_prompt): + return "" + def __call__(self, text, **kw): + return SimpleNamespace(input_ids=[]) + return _T() + + def test_missing_chat_template_raises(self): + with pytest.raises(ValueError, match="no chat_template"): + _probe_lora_substitute_token_id(self._minimal_tokenizer_without_template()) + + def test_template_render_failure_raises(self): + with pytest.raises(ValueError, match="Failed to render a probe chat"): + _probe_lora_substitute_token_id(self._tokenizer_whose_template_fails()) + + def test_unk_first_token_raises(self): + with pytest.raises(ValueError, match=""): + _probe_lora_substitute_token_id(self._tokenizer_emitting_unk()) + + def test_empty_tokenization_raises(self): + with pytest.raises(ValueError, match="empty id list"): + 
_probe_lora_substitute_token_id(self._tokenizer_emitting_empty()) diff --git a/tests/composer/test_save_load_compose.py b/tests/composer/test_save_load_compose.py index e1008c3..f7e579f 100644 --- a/tests/composer/test_save_load_compose.py +++ b/tests/composer/test_save_load_compose.py @@ -492,14 +492,14 @@ def test_pipeline_metadata_files_exist(self, phase1): ) def test_config_adapter_identity(self, phase1): - """num_adapters, token IDs, names, third_party survive save→load.""" + """num_adapters, token IDs, names, substitute IDs survive save→load.""" built = phase1["built_config"] loaded = phase1["loaded_config"] assert loaded.num_adapters == built.num_adapters assert loaded.adapter_token_ids == built.adapter_token_ids assert loaded.adapter_names == built.adapter_names - assert loaded.adapter_third_party == built.adapter_third_party + assert loaded.adapter_substitute_token_ids == built.adapter_substitute_token_ids def test_config_lora(self, phase1): """adapter_ranks, max_lora_rank, lora_target_modules survive save→load.""" @@ -511,22 +511,13 @@ def test_config_lora(self, phase1): assert loaded.lora_target_modules == built.lora_target_modules def test_config_switch(self, phase1): - """switch head_dim, control_dims, gain survive save→load.""" + """switch_head_dim and control_token_gain survive save→load.""" built = phase1["built_config"] loaded = phase1["loaded_config"] assert loaded.switch_head_dim == built.switch_head_dim - assert loaded.control_dims == built.control_dims assert loaded.control_token_gain == built.control_token_gain - def test_config_hiding(self, phase1): - """hiding_groups and hiding_policy survive save→load.""" - built = phase1["built_config"] - loaded = phase1["loaded_config"] - - assert loaded.hiding_groups == built.hiding_groups - assert loaded.hiding_policy == built.hiding_policy - def test_config_granite_scaling(self, phase1): """Granite-specific scaling parameters survive save→load.""" built = phase1["built_config"] @@ -848,8 +839,7 @@ def 
test_config_matches(self, phase2): assert c2.adapter_ranks == c1.adapter_ranks assert c2.max_lora_rank == c1.max_lora_rank assert c2.lora_target_modules == c1.lora_target_modules - assert c2.hiding_groups == c1.hiding_groups - assert c2.hiding_policy == c1.hiding_policy + assert c2.adapter_substitute_token_ids == c1.adapter_substitute_token_ids assert c2.logits_scaling == c1.logits_scaling assert c2.attention_multiplier == c1.attention_multiplier assert c2.vocab_size == c1.vocab_size diff --git a/tests/conftest.py b/tests/conftest.py index 7fca1da..e261922 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -54,18 +54,11 @@ def tiny_config(): num_key_value_heads=4, num_adapters=2, adapter_token_ids=[250, 251], + adapter_substitute_token_ids=[1, 1], adapter_names=["adapter_a", "adapter_b"], - hiding_groups={"all_controls": ["adapter_a", "adapter_b"]}, - hiding_policy={ - "base": ["all_controls"], - "adapter_a": ["all_controls"], - "adapter_b": ["all_controls"], - }, - adapter_third_party=["adapter_a", "adapter_b"], max_lora_rank=4, adapter_ranks=[4, 4], switch_head_dim=16, - control_dims=8, ) diff --git a/tests/hf/test_granite4_mini.py b/tests/hf/test_granite4_mini.py index c70412d..3de3884 100644 --- a/tests/hf/test_granite4_mini.py +++ b/tests/hf/test_granite4_mini.py @@ -178,7 +178,7 @@ def _make_zero_adapter_pair(cfg_dict): for name in unloaded: assert any(k in name for k in ( "lora_A", "lora_B", "switch", "adapter_token_ids", - "token_to_group_mask", "adapter_hiding_matrix", + "control_to_substitute_lut", )), f"Unexpected unloaded parameter: {name}" # Zero all LoRA weights defensively diff --git a/tests/hf/test_kv_hiding_gap_equivalence.py b/tests/hf/test_kv_hiding_gap_equivalence.py deleted file mode 100644 index 7fc857a..0000000 --- a/tests/hf/test_kv_hiding_gap_equivalence.py +++ /dev/null @@ -1,215 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Verify that a hidden control token creates a transparent gap in attention. 
- -The upstream model processes a contiguous N-token sequence. The switch model -processes the same N content tokens with a hidden control token inserted, -giving N+1 total tokens. With zero LoRA weights and SingleSwitch's hidden_count -closing the RoPE gap, the logits at corresponding visible positions should -match within FP tolerance. - -The hiding mechanism itself is exact on CPU: -- exp(finfo.min) = 0.0 exactly → hidden token gets zero softmax weight -- Control dims: Q_ctrl * K_ctrl = 1.0 * 0.0 = 0.0 → no score change -- V_control = 0.0 → zero contribution to attention output - -The ~1e-7 tolerance comes from different softmax window sizes at -corresponding positions. Switch position k+1 computes softmax over k+2 -entries (including the ~0 hidden token), while upstream position k computes -softmax over k+1 entries. Although the hidden entry contributes exactly 0.0 -to the denominator, SDPA's fused softmax kernel processes different-length -reductions with different FP accumulation order. Positions before the -control token are bit-exact (same causal window in both models). - -Attention-only models only — Mamba layers do not support KV hiding (the hidden -control token would flow through conv1d and SSM state, corrupting subsequent -positions). Only dense (attention-only) configs from GRANITE4_MINI are tested. - -SingleSwitch: hidden_count = (adapter_indices > 0).long() — fires once, -so 0 before control token and 1 at/after (see issue #16). 
-""" - -import pytest -import torch -from transformers.models.granitemoehybrid.configuration_granitemoehybrid import ( - GraniteMoeHybridConfig, -) -from transformers.models.granitemoehybrid.modeling_granitemoehybrid import ( - GraniteMoeHybridForCausalLM, -) - -from granite_switch.config import GraniteSwitchConfig -from granite_switch.hf import GraniteSwitchForCausalLM - -from tests.shared.granite4_equivalence import ( - augment_cfg_with_adapters, - transfer_weights, - zero_lora_weights, - GRANITE4_MINI, -) -from tests.shared.gap_equivalence import ( - ATTN_ONLY_NAMES, - make_gapped_inputs, - extract_visible_batched, -) - -# Softmax window-size tolerance (see module docstring). -# Observed max: ~1e-7 across all configs/positions/seeds. -# Use 5e-7 (≈5x margin) to accommodate variation. -_ATOL = 5e-7 - - -# ── Helpers ──────────────────────────────────────────────────────── - - -def _make_gap_pair(cfg_dict): - """Create upstream + 1-adapter switch model pair with zero LoRA weights. - - SingleSwitch: adapter_token_ids=[101], 101 is adapter_0 (KV-hidden). 
- """ - torch.manual_seed(0) - upstream = GraniteMoeHybridForCausalLM( - GraniteMoeHybridConfig(**cfg_dict) - ).eval() - - switch_cfg_dict = augment_cfg_with_adapters(cfg_dict, num_adapters=1) - switch = GraniteSwitchForCausalLM( - GraniteSwitchConfig(**switch_cfg_dict) - ).eval() - - # Transfer base weights (non-strict: LoRA/switch params left unloaded) - unloaded = transfer_weights(upstream.state_dict(), switch.state_dict()) - - # Verify unloaded params are only LoRA and switch related - for name in unloaded: - assert any(k in name for k in ( - "lora_A", "lora_B", "switch", "adapter_token_ids", - "token_to_group_mask", "adapter_hiding_matrix", - )), f"Unexpected unloaded parameter: {name}" - - # Zero all LoRA weights defensively - zero_lora_weights(switch) - - return upstream, switch - - -def _assert_gap_equivalence(name, upstream, switch, seq_len, ctrl_pos, seed=42): - """Run forward pass and assert visible logits match within tolerance.""" - upstream_ids, switch_ids = make_gapped_inputs(seq_len, ctrl_pos, seed) - - with torch.no_grad(): - upstream_out = upstream(input_ids=upstream_ids, use_cache=False) - switch_out = switch(input_ids=switch_ids, use_cache=False) - - visible = extract_visible_batched(switch_out.logits, ctrl_pos) - - torch.testing.assert_close( - visible, upstream_out.logits, - atol=_ATOL, rtol=0.0, - msg=f"{name}: visible logits diverge (seq={seq_len}, ctrl={ctrl_pos})", - ) - - -# ── Test class: KV Hiding Gap Equivalence ───────────────────────── - - -class TestKVHidingGapEquivalence: - """Verify hidden control token creates a transparent gap. - - The upstream model processes N contiguous tokens. The switch model - processes the same N tokens with a hidden control token inserted (N+1 - total). Visible-position logits match within BLAS gemm tolerance. 
- """ - - @pytest.fixture(params=ATTN_ONLY_NAMES) - def model_pair(self, request): - model_name = request.param - upstream, switch = _make_gap_pair(GRANITE4_MINI[model_name]) - return model_name, upstream, switch - - def test_gap_short(self, model_pair): - """Short sequence (16 tokens), control token at position 2.""" - name, upstream, switch = model_pair - _assert_gap_equivalence(name, upstream, switch, seq_len=16, ctrl_pos=2) - - def test_gap_long(self, model_pair): - """Longer sequence (64 tokens), control token at position 8.""" - name, upstream, switch = model_pair - _assert_gap_equivalence(name, upstream, switch, seq_len=64, ctrl_pos=8) - - def test_ctrl_at_position_1(self, model_pair): - """Control token at position 1. - - With SingleSwitch, position 0 has no special role (no counting - anchor needed). ctrl_pos=0 is tested separately in - test_multiple_ctrl_positions as a NaN regression guard. - """ - name, upstream, switch = model_pair - _assert_gap_equivalence(name, upstream, switch, seq_len=16, ctrl_pos=1) - - def test_ctrl_near_end(self, model_pair): - """Control token near the end of the sequence (pos=seq_len-2).""" - name, upstream, switch = model_pair - _assert_gap_equivalence(name, upstream, switch, seq_len=16, ctrl_pos=14) - - @pytest.mark.parametrize("ctrl_pos", [0, 1, 2, 4, 8, 14]) - def test_multiple_ctrl_positions(self, model_pair, ctrl_pos): - """Sweep control token across multiple positions. - - ctrl_pos=0 is a regression guard for the NaN bug fixed in PR #87: - when the control token sits at position 0 with no other causal key, - softmax([-inf]) = NaN unless q_control is zeroed for group members. - """ - name, upstream, switch = model_pair - _assert_gap_equivalence(name, upstream, switch, seq_len=16, ctrl_pos=ctrl_pos) - - -# ── Test class: Adapter Indices Sanity ──────────────────────────── - - -class TestAdapterIndicesSanity: - """Verify adapter_indices correctness with a single hidden control token. 
- - Uses a single config (4.0-350m) to check that: - - Positions before the control token have adapter_indices=0 (base) - - Positions at and after the control token have adapter_indices=1 - """ - - @pytest.fixture - def model(self): - cfg_dict = GRANITE4_MINI["4.0-350m"] - _, switch = _make_gap_pair(cfg_dict) - return switch - - def _run(self, model, ctrl_pos, seed=42): - """Run forward pass and return adapter_indices.""" - _, switch_ids = make_gapped_inputs(seq_len=16, ctrl_pos=ctrl_pos, seed=seed) - with torch.no_grad(): - model(input_ids=switch_ids, use_cache=False) - return model.model._last_adapter_indices - - def test_adapter_indices_before_ctrl(self, model): - """Positions before control token should be base (0).""" - ctrl_pos = 4 - ai = self._run(model, ctrl_pos) - assert (ai[:, :ctrl_pos] == 0).all(), ( - f"Pre-control positions should be base, got {ai[:, :ctrl_pos]}" - ) - - def test_adapter_indices_at_and_after_ctrl(self, model): - """Positions at and after control token should be adapter_0 (1).""" - ctrl_pos = 4 - ai = self._run(model, ctrl_pos) - assert (ai[:, ctrl_pos:] == 1).all(), ( - f"Post-control positions should be adapter_0 (1), got {ai[:, ctrl_pos:]}" - ) - - def test_adapter_indices_sweep(self, model): - """Sweep ctrl_pos and verify adapter_indices boundary.""" - for ctrl_pos in [1, 2, 4, 8, 14]: - ai = self._run(model, ctrl_pos, seed=ctrl_pos) - assert (ai[:, :ctrl_pos] == 0).all(), ( - f"ctrl_pos={ctrl_pos}: pre-ctrl should be 0, got {ai[:, :ctrl_pos]}" - ) - assert (ai[:, ctrl_pos:] == 1).all(), ( - f"ctrl_pos={ctrl_pos}: post-ctrl should be 1, got {ai[:, ctrl_pos:]}" - ) diff --git a/tests/hf/test_model_forward.py b/tests/hf/test_model_forward.py index 11cdedc..ce065ee 100644 --- a/tests/hf/test_model_forward.py +++ b/tests/hf/test_model_forward.py @@ -47,7 +47,7 @@ def _set_nonzero_lora_B(model, scale=0.1): @pytest.fixture def tiny_single_config(): - """Minimal SingleSwitch config for CPU tests.""" + """Minimal SingleSwitch config for 
CPU tests (token exchange).""" return GraniteSwitchConfig( vocab_size=300, hidden_size=64, @@ -57,37 +57,11 @@ def tiny_single_config(): num_key_value_heads=4, num_adapters=2, adapter_token_ids=[250, 251], + adapter_substitute_token_ids=[1, 1], adapter_names=["adapter_1", "adapter_2"], - hiding_groups={"all_controls": ["adapter_1", "adapter_2"]}, - hiding_policy={"base": ["all_controls"], "adapter_1": ["all_controls"], "adapter_2": ["all_controls"]}, - adapter_third_party=["adapter_1", "adapter_2"], max_lora_rank=4, adapter_ranks=[4, 4], switch_head_dim=16, - control_dims=8, - ) - - -@pytest.fixture -def tiny_basic_mixed_tp_config(): - """SingleSwitch config where only adapter_1 is third-party (adapter_2 is not).""" - return GraniteSwitchConfig( - vocab_size=300, - hidden_size=64, - intermediate_size=128, - num_hidden_layers=3, - num_attention_heads=4, - num_key_value_heads=4, - num_adapters=2, - adapter_token_ids=[250, 251], - adapter_names=["adapter_1", "adapter_2"], - hiding_groups={"all_controls": ["adapter_1", "adapter_2"]}, - hiding_policy={"base": ["all_controls"], "adapter_1": ["all_controls"], "adapter_2": ["all_controls"]}, - adapter_third_party=["adapter_1"], # only adapter_1 is third-party - max_lora_rank=4, - adapter_ranks=[4, 4], - switch_head_dim=16, - control_dims=8, ) @@ -238,101 +212,7 @@ def test_different_adapters_produce_different_post_control_logits(self, tiny_con # ════════════════════════════════════════════════════════════════════ -# 6. 
Control token KV invisibility -# ════════════════════════════════════════════════════════════════════ - -class TestControlTokenKVInvisibility: - """Verify control_dims makes control tokens invisible in KV cache.""" - - def test_control_token_kv_invisible_to_future_positions(self, tiny_config): - """Perturbing a control token's embedding doesn't affect future positions.""" - torch.manual_seed(42) - model = GraniteSwitchForCausalLM(tiny_config).eval() - _set_adapter_token_ids(model, tiny_config.adapter_token_ids) - - # Control token 250 at position 2 - input_ids = torch.tensor([[10, 20, 250, 30, 40, 50, 60, 70]]) - - # Pass A: original embeddings - with torch.no_grad(): - out_a = model(input_ids=input_ids, output_hidden_states=True) - hidden_a = out_a.hidden_states # tuple of [1, 8, hidden_size] - - # Perturb the control token's embedding - with torch.no_grad(): - perturbation = torch.randn(tiny_config.hidden_size) * 10.0 - model.model.embed_tokens.weight.data[250] += perturbation - - # Pass B: perturbed embedding - with torch.no_grad(): - out_b = model(input_ids=input_ids, output_hidden_states=True) - hidden_b = out_b.hidden_states - - # Check each layer's hidden states - for layer_idx in range(len(hidden_a)): - ha = hidden_a[layer_idx][0] # [8, hidden_size] - hb = hidden_b[layer_idx][0] - - # Pre-control (positions 0, 1): identical (causal, can't see pos 2) - torch.testing.assert_close( - ha[:2], hb[:2], - msg=f"Layer {layer_idx}: pre-control hidden states should be identical" - ) - - # At control position (2): must differ (embedding changed) - assert not torch.allclose(ha[2], hb[2]), \ - f"Layer {layer_idx}: control token hidden state should differ after perturbation" - - # Post-control (positions 3+): identical (control token KV is invisible) - torch.testing.assert_close( - ha[3:], hb[3:], - msg=f"Layer {layer_idx}: post-control hidden states should be identical " - f"(control token KV masked by control_dims)" - ) - - -# 
════════════════════════════════════════════════════════════════════ -# 7. Control token KV visibility -# ════════════════════════════════════════════════════════════════════ - -class TestControlTokenKVVisibility: - """Verify control tokens are KV-invisible (hidden from attention via control dimensions).""" - - def _make_model(self, config): - torch.manual_seed(42) - model = GraniteSwitchForCausalLM(config).eval() - _set_adapter_token_ids(model, config.adapter_token_ids) - return model - - def test_adapter_token_kv_invisible(self, tiny_single_config): - """Adapter token (250) is KV-invisible: perturbing doesn't affect future.""" - config = tiny_single_config - model = self._make_model(config) - - input_ids = torch.tensor([[10, 20, 250, 30, 40, 50, 60, 70]]) - - with torch.no_grad(): - out_a = model(input_ids=input_ids, output_hidden_states=True) - hidden_a = out_a.hidden_states - - with torch.no_grad(): - perturbation = torch.randn(config.hidden_size) * 10.0 - model.model.embed_tokens.weight.data[250] += perturbation - - with torch.no_grad(): - out_b = model(input_ids=input_ids, output_hidden_states=True) - hidden_b = out_b.hidden_states - - for layer_idx in range(len(hidden_a)): - ha = hidden_a[layer_idx][0] - hb = hidden_b[layer_idx][0] - torch.testing.assert_close( - ha[3:], hb[3:], - msg=f"Layer {layer_idx}: post-adapter-token hidden states should be identical" - ) - -# ════════════════════════════════════════════════════════════════════ -# 8. Activating tokens: switch behavior (explicit adapter_indices) +# 6. Activating tokens: switch behavior (explicit adapter_indices) # ════════════════════════════════════════════════════════════════════ class TestActivatingTokenSwitch: @@ -355,13 +235,13 @@ def test_activating_adapter_indices_nonzero(self, tiny_single_config): # ════════════════════════════════════════════════════════════════════ -# 9. Native mode: control_dims=0 (no KV hiding) +# 7. 
Token-exchange forward pass tests # ════════════════════════════════════════════════════════════════════ @pytest.fixture def tiny_native_config(): - """Minimal config for native mode (control_dims=0, no hiding).""" + """Minimal config for token-exchange mode.""" return GraniteSwitchConfig( vocab_size=300, hidden_size=64, @@ -371,20 +251,16 @@ def tiny_native_config(): num_key_value_heads=4, num_adapters=2, adapter_token_ids=[250, 251], + adapter_substitute_token_ids=[1, 1], adapter_names=["router", "planner"], max_lora_rank=4, adapter_ranks=[4, 4], switch_head_dim=16, - control_dims=0, - # No hiding - hiding_groups=None, - hiding_policy=None, - adapter_third_party=None, ) class TestNativeModeForward: - """Forward pass tests with control_dims=0 (native mode).""" + """Forward pass tests with token-exchange enabled.""" def test_forward_produces_logits(self, tiny_native_config): """Basic forward pass succeeds and produces correct-shaped logits.""" @@ -399,24 +275,6 @@ def test_forward_produces_logits(self, tiny_native_config): assert output.logits.shape == (1, 5, config.vocab_size) assert torch.isfinite(output.logits).all() - def test_no_expansion_in_attention(self, tiny_native_config): - """Attention layers should not expand control dimensions.""" - config = tiny_native_config - model = GraniteSwitchForCausalLM(config) - - for layer in model.model.layers: - attn = layer.self_attn - assert not attn.expand_control_dims - assert attn.expanded_head_dim == attn.head_dim - - def test_no_hiding_buffers(self, tiny_native_config): - """Model should have no hiding group buffers.""" - config = tiny_native_config - model = GraniteSwitchForCausalLM(config) - - assert model.model.token_to_group_mask is None - assert model.model.adapter_hiding_matrix is None - def test_control_token_logits_finite(self, tiny_native_config): """Control token logits should be finite.""" config = tiny_native_config @@ -430,7 +288,7 @@ def test_control_token_logits_finite(self, tiny_native_config): # All 
control token logits should be finite for tid in config.adapter_token_ids: assert torch.isfinite(output.logits[:, :, tid]).all(), ( - f"Token {tid} logits should be finite in native mode" + f"Token {tid} logits should be finite" ) def test_adapter_effect_visible(self, tiny_native_config): @@ -451,7 +309,7 @@ def test_adapter_effect_visible(self, tiny_native_config): assert diff > 1e-6, "Adapter should produce different logits" def test_batch_forward(self, tiny_native_config): - """Batched forward pass works with control_dims=0.""" + """Batched forward pass works.""" config = tiny_native_config model = GraniteSwitchForCausalLM(config).eval() _set_adapter_token_ids(model, config.adapter_token_ids) diff --git a/tests/hf/test_position_zero_nan.py b/tests/hf/test_position_zero_nan.py deleted file mode 100644 index 5751d9a..0000000 --- a/tests/hf/test_position_zero_nan.py +++ /dev/null @@ -1,188 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""NaN regression tests — control token at sequence position 0 (HF backend). - -HF-specific unit tests for GraniteLoRAEmbeddedAttention._expand_with_control_dimensions -(batch/seq tensor layout: [batch, seq, heads, head_dim]) plus shared SDPANaNCases. - -Note: model-level finiteness tests are not included here — the NaN bug only manifests -in vLLM's FlashAttention path, not in HF's stable softmax. See tests/vllm/ for those. 
-""" - -import types - -import torch - -from granite_switch.hf.core.lora import GraniteLoRAEmbeddedAttention - -from tests.shared.position_zero_nan_cases import SDPANaNCases - - -# ── Helpers ──────────────────────────────────────────────────────── - - -def _stub(num_heads=4, num_kv_heads=1, control_dims=1): - """Minimal namespace satisfying _expand_with_control_dimensions's self usage.""" - return types.SimpleNamespace( - num_heads=num_heads, - num_key_value_heads=num_kv_heads, - control_dims=control_dims, - ) - - -def _expand(stub, q, k, v, membership, suppression): - return GraniteLoRAEmbeddedAttention._expand_with_control_dimensions( - stub, q, k, v, membership, suppression - ) - - - -# ════════════════════════════════════════════════════════════════════ -# 1. HF-specific unit tests: _expand_with_control_dimensions -# Tensor layout: [batch, seq_len, num_heads, head_dim] -# ════════════════════════════════════════════════════════════════════ - - -class TestExpandControlDimensions: - """Direct tests of _expand_with_control_dimensions (HF tensor layout). - - token_group_membership=True marks the control token itself. - query_group_suppression=True marks adapter-generated tokens that suppress - the group — these are NOT group members and must keep q_control=1. - """ - - _HEAD_DIM = 32 - - def _qkv(self, stub, seq_len): - q = torch.randn(1, seq_len, stub.num_heads, self._HEAD_DIM) - k = torch.randn(1, seq_len, stub.num_key_value_heads, self._HEAD_DIM) - v = torch.randn(1, seq_len, stub.num_key_value_heads, self._HEAD_DIM) - return q, k, v - - # ── fix: control token must have q_control = 0 ────────────────── - - def test_control_token_q_hide_zero_at_position_zero(self): - """Core fix: control token at pos 0 must not activate Q-side hiding. - - Before the fix q_control was 1.0 unconditionally, so - softmax([−∞]) = NaN when it had no other causal keys. 
- """ - stub = _stub() - membership = torch.ones(1, 1, 1, dtype=torch.bool) - suppression = torch.ones(1, 1, 1, dtype=torch.bool) - q, k, v = self._qkv(stub, seq_len=1) - - q_exp, _, _ = _expand(stub, q, k, v, membership, suppression) - - q_ctrl = q_exp[0, 0, :, self._HEAD_DIM:] - assert q_ctrl.eq(0).all(), f"Control token at pos 0: q_control must be 0, got {q_ctrl}" - - def test_control_token_q_hide_zero_at_later_position(self): - """Control token q_control is 0 regardless of its sequence position.""" - stub = _stub() - membership = torch.zeros(1, 5, 1, dtype=torch.bool) - membership[0, 3, 0] = True - suppression = torch.ones(1, 5, 1, dtype=torch.bool) - q, k, v = self._qkv(stub, seq_len=5) - - q_exp, _, _ = _expand(stub, q, k, v, membership, suppression) - - assert q_exp[0, 3, :, self._HEAD_DIM:].eq(0).all(), "Control token at pos 3: q_control must be 0" - - # ── adapter-generated tokens must still suppress the control token ── - - def test_adapter_generated_tokens_q_hide_one(self): - """Adapter-generated tokens (non-members) keep q_control=1 to hide the control token.""" - stub = _stub() - membership = torch.zeros(1, 5, 1, dtype=torch.bool) - membership[0, 0, 0] = True # control token at pos 0 - suppression = torch.ones(1, 5, 1, dtype=torch.bool) - q, k, v = self._qkv(stub, seq_len=5) - - q_exp, _, _ = _expand(stub, q, k, v, membership, suppression) - - assert q_exp[0, 0, :, self._HEAD_DIM:].eq(0).all(), "Control token: q_control must be 0" - for pos in range(1, 5): - assert q_exp[0, pos, :, self._HEAD_DIM:].eq(1).all(), ( - f"Adapter-generated token at pos {pos}: q_control must be 1" - ) - - # ── k-side unchanged by fix ────────────────────────────────────── - - def test_k_side_finfo_min_for_control_token(self): - """K-side branding is unaffected by the fix — control token gets finfo.min.""" - stub = _stub() - membership = torch.ones(1, 1, 1, dtype=torch.bool) - q, k, v = self._qkv(stub, seq_len=1) - - _, k_exp, _ = _expand(stub, q, k, v, membership, None) - - 
expected_min = torch.finfo(k.dtype).min - k_ctrl = k_exp[0, 0, :, self._HEAD_DIM:] - torch.testing.assert_close(k_ctrl, torch.full_like(k_ctrl, expected_min)) - - def test_k_side_zero_for_adapter_generated_tokens(self): - """Adapter-generated tokens have k_control=0.""" - stub = _stub() - q, k, v = self._qkv(stub, seq_len=3) - - _, k_exp, _ = _expand(stub, q, k, v, torch.zeros(1, 3, 1, dtype=torch.bool), None) - - assert k_exp[:, :, :, self._HEAD_DIM:].eq(0).all() - - # ── v-side and no-mask baseline ────────────────────────────────── - - def test_v_control_always_zero(self): - """V control dimensions are always zero.""" - stub = _stub() - q, k, v = self._qkv(stub, seq_len=3) - _, _, v_exp = _expand( - stub, q, k, v, - torch.ones(1, 3, 1, dtype=torch.bool), - torch.ones(1, 3, 1, dtype=torch.bool), - ) - assert v_exp[:, :, :, self._HEAD_DIM:].eq(0).all() - - def test_both_none_leaves_all_control_dims_zero(self): - """With both tensors None, all control dims remain zero.""" - stub = _stub(control_dims=2) - q, k, v = self._qkv(stub, seq_len=4) - q_exp, k_exp, v_exp = _expand(stub, q, k, v, None, None) - assert q_exp[..., self._HEAD_DIM:].eq(0).all() - assert k_exp[..., self._HEAD_DIM:].eq(0).all() - assert v_exp[..., self._HEAD_DIM:].eq(0).all() - - # ── multiple groups ────────────────────────────────────────────── - - def test_multiple_groups_independent(self): - """Control token of group 0 only zeroes q_control for group 0.""" - stub = _stub(control_dims=2) - membership = torch.zeros(1, 1, 2, dtype=torch.bool) - membership[0, 0, 0] = True - suppression = torch.ones(1, 1, 2, dtype=torch.bool) - q, k, v = self._qkv(stub, seq_len=1) - - q_exp, _, _ = _expand(stub, q, k, v, membership, suppression) - q_ctrl = q_exp[0, 0, :, self._HEAD_DIM:] - - assert q_ctrl[:, 0].eq(0).all(), "Group 0 dim must be 0 (control token is a member)" - assert q_ctrl[:, 1].eq(1).all(), "Group 1 dim must be 1 (control token is not a member)" - - def 
test_original_qkv_dimensions_preserved(self): - """Original Q/K/V dimensions are unchanged; only control dims appended.""" - stub = _stub(control_dims=3) - q, k, v = self._qkv(stub, seq_len=5) - q_exp, k_exp, v_exp = _expand(stub, q, k, v, None, None) - torch.testing.assert_close(q_exp[..., : self._HEAD_DIM], q) - torch.testing.assert_close(k_exp[..., : self._HEAD_DIM], k) - torch.testing.assert_close(v_exp[..., : self._HEAD_DIM], v) - - -# ════════════════════════════════════════════════════════════════════ -# 2. Shared SDPA cases -# ════════════════════════════════════════════════════════════════════ - - -class TestSDPANaN(SDPANaNCases): - pass - - diff --git a/tests/hf/test_qk_norm.py b/tests/hf/test_qk_norm.py index a603919..557d599 100644 --- a/tests/hf/test_qk_norm.py +++ b/tests/hf/test_qk_norm.py @@ -33,7 +33,6 @@ def _make_config(qk_norm: bool, num_adapters: int = 0) -> GraniteSwitchConfig: adapter_token_ids=[], adapter_names=[], adapter_ranks=[], - control_dims=0, qk_norm=qk_norm, ) config._attn_implementation = "sdpa" @@ -110,12 +109,10 @@ def test_output_differs_with_qk_norm(self): with torch.no_grad(): out_off, _, _ = attn_off( hidden, adapter_indices, - token_group_membership=None, query_group_suppression=None, position_embeddings=pos_emb, ) out_on, _, _ = attn_on( hidden, adapter_indices, - token_group_membership=None, query_group_suppression=None, position_embeddings=pos_emb, ) @@ -137,7 +134,6 @@ def test_output_shape_preserved(self): with torch.no_grad(): out, _, _ = attn( hidden, adapter_indices, - token_group_membership=None, query_group_suppression=None, position_embeddings=pos_emb, ) diff --git a/tests/hf/test_single_switch.py b/tests/hf/test_single_switch.py index 234d9c8..5d186b4 100644 --- a/tests/hf/test_single_switch.py +++ b/tests/hf/test_single_switch.py @@ -135,8 +135,12 @@ def _run(self, seq, num_adapters=NUM_ADAPTERS, control_token_gain=15.0): switch = _make_switch(self._backend, num_adapters, control_token_gain) token_ids = 
torch.tensor(ADAPTER_TOKEN_IDS_LIST[:num_adapters]) input_ids = torch.tensor([seq]) - result = switch.forward(input_ids=input_ids, adapter_token_ids=token_ids) - return result[0].tolist() + # Switch returns (adapter_indices, modified_input_ids); these tests + # only check adapter selection so we drop the rewritten ids here. + adapter_indices, _modified = switch.forward( + input_ids=input_ids, adapter_token_ids=token_ids, + ) + return adapter_indices[0].tolist() # ── Shared test classes (from mixin) ──────────────────────────────── @@ -177,6 +181,8 @@ def test_batch_independence(self, backend): [TEXT_TOKEN, ADAPTER_TOKEN_IDS_LIST[0], TEXT_TOKEN, TEXT_TOKEN, TEXT_TOKEN], [TEXT_TOKEN, ADAPTER_TOKEN_IDS_LIST[3], TEXT_TOKEN, TEXT_TOKEN, TEXT_TOKEN], ]) - result = switch.forward(input_ids=input_ids, adapter_token_ids=token_ids) - assert (result[0, 2:] == 1).all() - assert (result[1, 2:] == 4).all() + adapter_indices, _modified = switch.forward( + input_ids=input_ids, adapter_token_ids=token_ids, + ) + assert (adapter_indices[0, 2:] == 1).all() + assert (adapter_indices[1, 2:] == 4).all() diff --git a/tests/hf/test_single_switch_e2e.py b/tests/hf/test_single_switch_e2e.py index e02856b..89401d4 100644 --- a/tests/hf/test_single_switch_e2e.py +++ b/tests/hf/test_single_switch_e2e.py @@ -11,8 +11,8 @@ GraniteSwitchConfig → create_switch() → SingleSwitch.__init__ → model forward → _last_adapter_indices -Parametrized over both PRODUCTION_ATTENTION_MULTIPLIERS and both control_dims -modes (native and hiding) to catch config-flow regressions in either code path. +Parametrized over PRODUCTION_ATTENTION_MULTIPLIERS to catch config-flow +regressions across the realistic multiplier values. CPU-only. Does not exercise vLLM gain compensation — HF SingleSwitch hardcodes scaling=1.0 regardless of config. 
Compensation is tested in the Tier 2 composer @@ -35,11 +35,6 @@ ) from tests.shared.single_switch_cases import ADAPTER_TOKEN_IDS_LIST, NUM_ADAPTERS -# control_dims=0 → native mode (no KV hiding). control_dims=32 → hiding mode. -# Both take different code paths through SingleSwitch.__init__ (expanded_head_dim) -# and through GraniteSwitchModel.forward (hiding-group mask construction). -CONTROL_DIMS_MODES = [0, 32] - # TEXT_TOKEN matches tests/shared/single_switch_cases.py convention. Any # non-adapter token ID works — 50 is outside ADAPTER_TOKEN_IDS_LIST (1000+). TEXT_TOKEN = 50 @@ -62,35 +57,20 @@ ) -def _build_e2e_overrides(base_cfg, *, num_adapters=NUM_ADAPTERS, control_dims=32): - """Build config overrides for a production-ish E2E test model. - - Three overrides beyond the `single_overrides()` defaults: - - vocab_size: large enough to hold every adapter token ID (derived). - - max_position_embeddings: supports the long-context test matrix (derived). - - control_dims parametrized: native (0) vs hiding (32+). - """ +def _build_e2e_overrides(base_cfg, *, num_adapters=NUM_ADAPTERS): + """Build config overrides for a production-ish E2E test model.""" adapter_names = [f"adapter_{i}" for i in range(num_adapters)] - overrides = { + return { "vocab_size": _E2E_VOCAB_SIZE, "max_position_embeddings": _E2E_MAX_POSITION_EMBEDDINGS, "num_adapters": num_adapters, "adapter_ranks": [8] * num_adapters, "adapter_token_ids": ADAPTER_TOKEN_IDS_LIST[:num_adapters], "adapter_names": adapter_names, - "control_dims": control_dims, + "adapter_substitute_token_ids": [1] * num_adapters, "num_hidden_layers": len(base_cfg["layer_types"]) + 1, "layer_types": ["attention"] + base_cfg["layer_types"], } - if control_dims > 0: - # Hiding mode needs hiding_groups + hiding_policy + adapter_third_party. 
- overrides["hiding_groups"] = {"all_controls": adapter_names} - overrides["hiding_policy"] = { - n: ["all_controls"] for n in ["base"] + adapter_names - } - overrides["adapter_third_party"] = adapter_names - # control_dims == 0 → native mode → no hiding_groups/policy. - return overrides def _make_e2e_model(base_cfg, overrides): @@ -107,29 +87,28 @@ def _make_e2e_model(base_cfg, overrides): # Module scope would save ~19s across the long-context matrix but would require # auditing that no test mutates model state — not worth it. @pytest.fixture( - params=[(m, cd) for m in PRODUCTION_ATTENTION_MULTIPLIERS for cd in CONTROL_DIMS_MODES], - ids=lambda p: f"mult={p[0]}-cd={p[1]}", + params=PRODUCTION_ATTENTION_MULTIPLIERS, + ids=lambda m: f"mult={m}", ) def e2e_model(request): - """GraniteSwitchForCausalLM parametrized over (attention_multiplier, control_dims).""" - multiplier, control_dims = request.param + """GraniteSwitchForCausalLM parametrized over attention_multiplier.""" + multiplier = request.param base_cfg = {**DENSE_CFG, "attention_multiplier": multiplier} - overrides = _build_e2e_overrides(base_cfg, control_dims=control_dims) + overrides = _build_e2e_overrides(base_cfg) model, config = _make_e2e_model(base_cfg, overrides) - return model, config, multiplier, control_dims + return model, config, multiplier @pytest.fixture def e2e_model_32adapter(): """Single-variant fixture for the 32-adapter stress test. - The adapter-ID rounding math is independent of (multiplier, control_dims), - so we don't parametrize this fixture — TestE2EBasicAdapterActivation already - covers the cross-product. Chosen variant: hiding mode (control_dims=32) with - the most common production multiplier (0.0078125, granite-4.0-h-1b/tiny/small/4.1-8b/30b). + The adapter-ID rounding math is independent of multiplier, so we don't + parametrize. Chosen variant: most common production multiplier + (0.0078125, granite-4.0-h-1b/tiny/small/4.1-8b/30b). 
""" base_cfg = {**DENSE_CFG, "attention_multiplier": 0.0078125} - overrides = _build_e2e_overrides(base_cfg, control_dims=32) + overrides = _build_e2e_overrides(base_cfg) model, config = _make_e2e_model(base_cfg, overrides) return model, config @@ -160,10 +139,9 @@ def test_pre_control_is_zero_post_control_matches_adapter(self, e2e_model): from position 2 onward; positions before it remain at 0 (base). Proves the full chain config → create_switch → forward → - _last_adapter_indices works on a production-ish multiplier/control_dims - combination. + _last_adapter_indices works on a production-ish multiplier. """ - model, config, mult, cd = e2e_model + model, config, mult = e2e_model ctrl_token = config.adapter_token_ids[0] # adapter_0 → expected index 1 input_ids = torch.tensor([[10, 20, ctrl_token, 30, 40, 50, 60, 70]]) with torch.no_grad(): @@ -227,7 +205,7 @@ def test_long_context_e2e(self, e2e_model, seq_len, adapter_idx, control_positio Default CI runs seq_len ∈ {10K, 32K}; `-m slow` adds 65K and 131K. 
""" - model, config, mult, cd = e2e_model + model, config, mult = e2e_model ctrl_token = config.adapter_token_ids[adapter_idx] expected_id = adapter_idx + 1 ctrl_pos = _control_position(seq_len, control_position) @@ -244,10 +222,10 @@ def test_long_context_e2e(self, e2e_model, seq_len, adapter_idx, control_positio assert (ai[:ctrl_pos] == 0).all(), ( f"pre-control slice should be all 0; failed at seq_len={seq_len}, " f"adapter_idx={adapter_idx}, position={control_position}, " - f"mult={mult}, cd={cd}" + f"mult={mult}" ) assert (ai[ctrl_pos:] == expected_id).all(), ( f"post-control slice should be all {expected_id}; failed at seq_len={seq_len}, " f"adapter_idx={adapter_idx}, position={control_position}, " - f"mult={mult}, cd={cd}" + f"mult={mult}" ) diff --git a/tests/hf/test_token_exchange.py b/tests/hf/test_token_exchange.py new file mode 100644 index 0000000..ae13392 --- /dev/null +++ b/tests/hf/test_token_exchange.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: Apache-2.0 +"""HF backend tests for token-exchange mode. + +Two properties under test: +1. The embedding at each control-token position equals the embedding of its + substitute token (scaled by embedding_multiplier), not the original + control-token embedding. +2. The KV cache head_dim is the native projection_head_dim — token-exchange + does not expand the KV cache. 
+""" + +import pytest +import torch + +from granite_switch.config import GraniteSwitchConfig +from granite_switch.hf import GraniteSwitchForCausalLM + + +def _build(num_adapters=2, substitute_ids=(1, 7)): + return GraniteSwitchConfig( + vocab_size=200, + hidden_size=32, + num_attention_heads=4, + num_key_value_heads=2, + num_hidden_layers=2, + intermediate_size=64, + shared_intermediate_size=64, + max_position_embeddings=64, + mamba_n_heads=1, + mamba_expand=1, + num_adapters=num_adapters, + adapter_ranks=[4] * num_adapters, + max_lora_rank=4, + adapter_token_ids=[100, 101][:num_adapters], + adapter_names=["a", "b"][:num_adapters], + adapter_substitute_token_ids=list(substitute_ids[:num_adapters]), + torch_dtype=torch.float32, + ) + + +@torch.no_grad() +def _forward(config, input_ids): + model = GraniteSwitchForCausalLM(config).eval() + return model, model(input_ids=input_ids, use_cache=True) + + +class TestTokenExchangeEmbeddingSwap: + """The control position's residual-stream input is the substitute embedding.""" + + def test_swap_picks_substitute_embedding(self): + config = _build(substitute_ids=(5, 7)) + model, _ = _forward( + config, + torch.tensor([[10, 20, 100, 40]], dtype=torch.long), # adapter 0 control at pos 2 + ) + # The LUT lives on the switch (it performs the rewrite during its + # forward); maps control id 100 → substitute 5. + lut = model.model.switch.control_to_substitute_lut + assert lut[100].item() == 5 + assert lut[101].item() == 7 + # Positions without control tokens map to -1. + assert lut[10].item() == -1 + assert lut[40].item() == -1 + + def test_swap_is_not_applied_on_non_control_positions(self): + config = _build(substitute_ids=(5, 7)) + model = GraniteSwitchForCausalLM(config).eval() + # Run once through the model with a control token and once without; + # verify the non-control embedding rows are identical. 
+ raw_a = model.model.embed_tokens(torch.tensor([[10, 20, 30, 40]], dtype=torch.long)) + raw_b = model.model.embed_tokens(torch.tensor([[10, 20, 100, 40]], dtype=torch.long)) + # Positions 0, 1, 3 should match; position 2 is the control token (differs). + assert torch.allclose(raw_a[:, 0], raw_b[:, 0]) + assert torch.allclose(raw_a[:, 1], raw_b[:, 1]) + assert torch.allclose(raw_a[:, 3], raw_b[:, 3]) + + +class TestKVCacheHeadDim: + """The load-bearing correctness property: KV cache head_dim equals + the native projection_head_dim — no expansion.""" + + def test_token_exchange_native_head_dim(self): + config = _build(substitute_ids=(5, 7)) + _, out = _forward( + config, + torch.tensor([[10, 20, 100, 40]], dtype=torch.long), + ) + # layers[0] is the switch; layers[1] is the first decoder layer. + decoder_key = out.past_key_values.layers[1].keys + assert decoder_key.shape[-1] == config.projection_head_dim + + +class TestSwitchStillDetectsAdapter: + """Swap must happen AFTER the switch reads input_ids, so detection is unaffected.""" + + def test_adapter_indices_still_activate(self): + config = _build(substitute_ids=(5, 7)) + model, _ = _forward( + config, + torch.tensor([[10, 20, 100, 40, 50]], dtype=torch.long), + ) + adapter_indices = model.model._last_adapter_indices + # Position 2 is the control token for adapter 0 (1-indexed output). + # Positions after it inherit adapter=1 (SingleSwitch persists once fired). 
+ assert adapter_indices[0, 0].item() == 0 + assert adapter_indices[0, 1].item() == 0 + assert adapter_indices[0, 2].item() == 1 + assert adapter_indices[0, 3].item() == 1 + assert adapter_indices[0, 4].item() == 1 diff --git a/tests/integration/test_hf_to_vllm_weights.py b/tests/integration/test_hf_to_vllm_weights.py index 1daeecd..711652b 100644 --- a/tests/integration/test_hf_to_vllm_weights.py +++ b/tests/integration/test_hf_to_vllm_weights.py @@ -373,17 +373,10 @@ def _config(self): num_adapters=2, adapter_token_ids=[250, 251], adapter_names=["adapter_1", "adapter_2"], - hiding_groups={"all_controls": ["adapter_1", "adapter_2"]}, - hiding_policy={ - "base": ["all_controls"], - "adapter_1": ["all_controls"], - "adapter_2": ["all_controls"], - }, - adapter_third_party=["adapter_1", "adapter_2"], + adapter_substitute_token_ids=[1, 1], max_lora_rank=4, adapter_ranks=[4, 4], switch_head_dim=32, - control_dims=32, max_position_embeddings=512, attention_multiplier=1.0, embedding_multiplier=1.0, diff --git a/tests/shared/generation_models.py b/tests/shared/generation_models.py index 2635c22..6c5e477 100644 --- a/tests/shared/generation_models.py +++ b/tests/shared/generation_models.py @@ -49,20 +49,14 @@ def single_overrides(base_cfg): - """SingleSwitch overrides for the given base config.""" + """SingleSwitch overrides for the given base config (token exchange).""" base_layers = base_cfg["layer_types"] return { "num_adapters": NUM_ADAPTERS, "adapter_ranks": [ADAPTER_RANK] * NUM_ADAPTERS, "adapter_token_ids": [250, 251], + "adapter_substitute_token_ids": [1, 1], "adapter_names": ["adapter_0", "adapter_1"], - "hiding_groups": {"all_controls": ["adapter_0", "adapter_1"]}, - "hiding_policy": { - "base": ["all_controls"], - "adapter_0": ["all_controls"], - "adapter_1": ["all_controls"], - }, - "adapter_third_party": ["adapter_0", "adapter_1"], "num_hidden_layers": len(base_layers) + 1, "layer_types": ["attention"] + base_layers, } diff --git 
a/tests/shared/granite4_equivalence.py b/tests/shared/granite4_equivalence.py index 865a68c..3237c52 100644 --- a/tests/shared/granite4_equivalence.py +++ b/tests/shared/granite4_equivalence.py @@ -122,31 +122,29 @@ def transfer_weights_strict(upstream_sd, switch_sd): # zero LoRA weights the delta is zero for the LoRA path. # # All control tokens (adapter_token_ids) are KV-hidden by default. -# Test sequences use adapter tokens; hidden positions are excluded +# Test sequences use adapter tokens; the control position is excluded # from comparison via get_visible_mask(). # -# KV hiding: control tokens get K=finfo.min masking on control dims, -# zeroing their attention contribution at hidden positions. +# Token-exchange: at the switch, each control token's id is rewritten to +# its configured substitute id before the decoder embeds. Upstream sees +# the original control id; switch sees the substitute. The two embeddings +# differ at the control position only. # -# Error sources (with control_dims=32, exact K=-inf masking): +# Error sources: # -# 1. Hidden token attention contribution removal (primary): -# The upstream model attends to control tokens normally. The switch -# model masks them exactly (K=-inf -> zero attention weight). The diff -# is the value of the removed attention contribution -- fundamental -# and unavoidable. Error scales with hidden_tokens / seq_len. +# 1. Embedding divergence at control positions (primary): +# Upstream embeds the original control id; switch embeds the substitute. +# The control position itself is excluded from comparison. Visible +# positions attend to the control position too — the attention +# contribution from the (substitute vs original) embedding propagates. +# Error scales with control_tokens / seq_len. # -# 2. Expanded tensor FP rounding (secondary): -# control_dims adds extra dimensions to Q/K/V, changing the dot -# product accumulation in the attention kernel (D+32 vs D elements). 
-# This introduces small FP differences even for real token positions. -# -# 3. Mamba conv1d zero-gap (additional, hybrid only): -# Input zeroing writes zeros into conv1d's sliding window, perturbing -# K-1 subsequent real tokens per hidden token (issue #5). +# 2. Mamba conv1d effects (hybrid only): +# Conv1d's sliding window over substituted vs original tokens at the +# control position perturbs K-1 subsequent real tokens. # # Token allocation (within vocab_size=256): -# 101+ = adapter_token_ids (KV-hidden, activate switch) +# 101+ = adapter_token_ids (rewritten to substitute by switch) # Random fill: [0, 100) -- guaranteed no collisions with control tokens # Control token IDs -- low-vocab, valid embeddings in the base model @@ -165,9 +163,9 @@ def augment_cfg_with_adapters(cfg_dict, num_adapters=2, rank=8): - num_hidden_layers += 1 (1 cache slot for SingleSwitch) - layer_types prepended with "attention" (switch layer type) - LoRA adapter config fields - - adapter_token_ids (KV-hidden) + - adapter_token_ids (rewritten to substitute ids by the switch) + - adapter_substitute_token_ids (token-exchange substitutes) - adapter_names for name-to-index mapping - - control_dims=32 (default: exact K=-inf masking, no softmax dilution) """ cfg = dict(cfg_dict) @@ -186,13 +184,10 @@ def augment_cfg_with_adapters(cfg_dict, num_adapters=2, rank=8): cfg["adapter_token_ids"] = [ _ADAPTER_TOKEN_BASE + i for i in range(num_adapters) ] - - # Default hiding config: all adapters in a single group, all hide it. - cfg["hiding_groups"] = {"all_controls": list(adapter_names)} - cfg["hiding_policy"] = { - name: ["all_controls"] for name in ["base"] + list(adapter_names) - } - cfg["adapter_third_party"] = list(adapter_names) + # Token-exchange substitute ids — use a benign shared id (the BOS-or- + # equivalent doesn't matter for these synthetic equivalence tests since + # all LoRA weights are zero, so the embedding is what feeds the decoder). 
+ cfg["adapter_substitute_token_ids"] = [1] * num_adapters return cfg @@ -233,13 +228,14 @@ def zero_lora_weights(model): def get_visible_mask(input_ids): - """Return boolean mask of non-hidden (visible) positions. + """Return boolean mask of non-control (visible) positions. - Positions in a hiding group get K=finfo.min masking on control dims, - making their logits intentionally different from upstream. This mask - identifies positions that should be compared in equivalence tests. + Control positions hold the substitute embedding in the switch model + versus the original control-token embedding upstream — their logits + are intentionally different. This mask identifies positions that + should be compared in equivalence tests. - All adapter tokens (>= _ADAPTER_TOKEN_BASE) are KV-hidden. + All adapter tokens (>= _ADAPTER_TOKEN_BASE) are control positions. Fill tokens from [0, 100) are visible. """ is_adapter = input_ids >= _ADAPTER_TOKEN_BASE @@ -252,38 +248,33 @@ def get_visible_mask(input_ids): def get_tolerances(layer_types, long_sequence=False, has_kv_hidden=False): """Return (atol, rtol) for a given architecture. - Error sources (systematic analysis): - - 1. **No hiding, no adapters**: GraniteSwitch with num_adapters=0 is - bit-exact vs upstream Granite (all configs). Fused QKV matmul is - bit-identical to separate Q/K/V matmuls in float32. + Error sources: - 2. **Hidden token attention contribution removal**: When control tokens - are hidden, the switch model masks them exactly (K=-inf, zero attention - weight via control_dims). Visible tokens lose the attention contribution - that the upstream model gets from those positions. Fundamental and - unavoidable — error scales with hidden_tokens / seq_len. + 1. **No adapters**: GraniteSwitch with num_adapters=0 is bit-exact vs + upstream Granite. Fused QKV matmul is bit-identical to separate + Q/K/V matmuls in float32. - 3. 
**Expanded tensor FP rounding**: control_dims adds extra dimensions - to Q/K/V, changing the attention kernel's dot product accumulation - (D+32 vs D elements). Small FP rounding differences at real positions. + 2. **Token-exchange embedding divergence**: With adapters and a control + token in the input, the switch embeds the substitute id at that + position while upstream embeds the original control id. Visible + positions attending to the control position pick up that delta. Args: layer_types: list of "attention" strings long_sequence: unused (kept for API compatibility) - has_kv_hidden: True when control token hiding is active + has_kv_hidden: True when adapters are active and control tokens + are present (kept name for API compatibility — the parameter + now means "control tokens get substituted"). Returns: (atol, rtol) tuple, or None if bit-exact match expected. """ if not has_kv_hidden: - # No hiding: bit-exact (fused QKV is numerically identical, - # control_dims expansion adds exactly 0 to dot products). + # Pure base-model path: bit-exact (fused QKV numerically identical). return None else: - # Attention-only with hiding (control_dims=32): hidden token - # attention contribution removed. - # Worst observed: ~5.0e-2 (multi 1b, seed-dependent). + # Substitute-embedding propagates through attention to visible + # positions. Worst observed: ~5.0e-2 (multi 1b, seed-dependent). return (6e-2, 6e-2) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 9280a0c..7225958 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 """Config validation tests for GraniteSwitchConfig. -Tests every ValueError path in __init__, default values, and derived properties. +Covers the validators in __init__, default values, and the config +fields that survived the legacy-hiding removal. 
""" import pytest @@ -11,8 +12,9 @@ # ── Helper ──────────────────────────────────────────────────────────── + def _valid_kwargs(num_adapters=2, **overrides): - """Return kwargs for a valid SingleSwitch config, with optional overrides.""" + """Return kwargs for a valid token-exchange config.""" adapter_names = [f"adapter_{i}" for i in range(num_adapters)] base = dict( vocab_size=300, @@ -23,6 +25,7 @@ def _valid_kwargs(num_adapters=2, **overrides): num_key_value_heads=4, num_adapters=num_adapters, adapter_token_ids=list(range(500, 500 + num_adapters)), + adapter_substitute_token_ids=[1] * num_adapters, adapter_names=adapter_names, max_lora_rank=8, adapter_ranks=[8] * num_adapters, @@ -35,6 +38,7 @@ def _valid_kwargs(num_adapters=2, **overrides): # 1. Config validation — every ValueError path # ════════════════════════════════════════════════════════════════════ + class TestConfigValidation: def test_negative_num_adapters_raises(self): @@ -43,137 +47,56 @@ def test_negative_num_adapters_raises(self): def test_adapter_token_ids_wrong_length_raises(self): with pytest.raises(ValueError, match="adapter_token_ids length"): - GraniteSwitchConfig(**_valid_kwargs( - adapter_token_ids=[500, 501, 502], # length 3, expected 2 - )) - - def test_missing_adapter_ranks_raises(self): + GraniteSwitchConfig(**_valid_kwargs(adapter_token_ids=[500])) + + def test_substitute_ids_required_when_adapters_present(self): + with pytest.raises(ValueError, match="adapter_substitute_token_ids is required"): + GraniteSwitchConfig( + **_valid_kwargs(adapter_substitute_token_ids=None) + ) + + def test_substitute_ids_wrong_length_raises(self): + with pytest.raises(ValueError, match="adapter_substitute_token_ids length"): + GraniteSwitchConfig( + **_valid_kwargs(adapter_substitute_token_ids=[1]) + ) + + def test_substitute_ids_negative_raises(self): + with pytest.raises(ValueError, match=">= 0"): + GraniteSwitchConfig( + **_valid_kwargs(adapter_substitute_token_ids=[-1, 1]) + ) + + def 
test_duplicate_adapter_token_ids_raises(self): + with pytest.raises(ValueError, match="adapter_token_ids must be unique"): + GraniteSwitchConfig(**_valid_kwargs(adapter_token_ids=[500, 500])) + + def test_adapter_ranks_required(self): with pytest.raises(ValueError, match="adapter_ranks must be provided"): GraniteSwitchConfig(**_valid_kwargs(adapter_ranks=None)) - def test_adapter_ranks_wrong_length_raises(self): + def test_adapter_ranks_wrong_length(self): with pytest.raises(ValueError, match="adapter_ranks length"): - GraniteSwitchConfig(**_valid_kwargs( - adapter_ranks=[8], # length 1, expected 2 - )) + GraniteSwitchConfig(**_valid_kwargs(adapter_ranks=[8])) - def test_max_adapter_ranks_mismatch_raises(self): - with pytest.raises(ValueError, match="max.*adapter_ranks.*must equal max_lora_rank"): - GraniteSwitchConfig(**_valid_kwargs( - adapter_ranks=[4, 4], # max=4, but max_lora_rank=8 - )) + def test_max_lora_rank_must_match(self): + with pytest.raises(ValueError, match="max_lora_rank"): + GraniteSwitchConfig(**_valid_kwargs(max_lora_rank=4)) # ════════════════════════════════════════════════════════════════════ -# 2. Config defaults and derived properties +# 2. Defaults # ════════════════════════════════════════════════════════════════════ + class TestConfigDefaults: - def test_zero_adapters_no_validation(self): - """Config with 0 adapters should not require adapter_ranks or token_ids.""" - cfg = GraniteSwitchConfig( - vocab_size=256, hidden_size=64, intermediate_size=128, - num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4, - num_adapters=0, - ) + def test_zero_adapter_default(self): + cfg = GraniteSwitchConfig(num_adapters=0) assert cfg.num_adapters == 0 - assert cfg.adapter_ranks is None - - -# ════════════════════════════════════════════════════════════════════ -# 3. 
Hiding groups and policy -# ════════════════════════════════════════════════════════════════════ - -class TestHidingConfig: - - def test_hiding_groups_none_by_default(self): - """No hiding groups when not specified.""" - cfg = GraniteSwitchConfig(**_valid_kwargs()) - assert cfg.num_hiding_groups == 0 - assert cfg.hiding_group_names == [] - - def test_hiding_groups_count(self): - """num_hiding_groups reflects configured groups.""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - hiding_groups={ - "group_a": ["adapter_0"], - "group_b": ["adapter_1"], - }, - )) - assert cfg.num_hiding_groups == 2 - assert cfg.hiding_group_names == ["group_a", "group_b"] - - def test_control_dims_less_than_groups_raises(self): - """control_dims must be >= number of hiding groups.""" - with pytest.raises(ValueError, match="control_dims.*must be >= number of hiding groups"): - GraniteSwitchConfig(**_valid_kwargs( - control_dims=1, - hiding_groups={ - "g1": ["adapter_0"], - "g2": ["adapter_1"], - }, - )) - - def test_get_hiding_group_token_ids(self): - """Token IDs resolved correctly for SingleSwitch.""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - hiding_groups={"all_controls": ["adapter_0", "adapter_1"]}, - )) - group_tokens = cfg.get_hiding_group_token_ids() - # SingleSwitch: no base offset, adapter_0 → token 500, adapter_1 → token 501 - assert group_tokens == {0: [500, 501]} - - def test_get_hiding_group_token_ids_multiple_groups(self): - """Multiple groups with different adapter assignments.""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - hiding_groups={ - "group_a": ["adapter_0"], - "group_b": ["adapter_1"], - }, - )) - group_tokens = cfg.get_hiding_group_token_ids() - assert group_tokens == {0: [500], 1: [501]} - - def test_get_adapter_hiding_policy_matrix(self): - """Policy matrix built correctly from named config.""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - hiding_groups={ - "group_a": ["adapter_0"], - "group_b": ["adapter_1"], - }, - hiding_policy={ - "base": ["group_a", 
"group_b"], - "adapter_0": ["group_b"], - "adapter_1": ["group_a"], - }, - )) - matrix = cfg.get_adapter_hiding_policy_matrix() - # [base, adapter_0, adapter_1] x [group_a, group_b] - assert matrix == [ - [True, True], # base hides both - [False, True], # adapter_0 hides group_b only - [True, False], # adapter_1 hides group_a only - ] - - def test_get_adapter_hiding_policy_matrix_no_policy(self): - """Empty matrix when no policy configured.""" - cfg = GraniteSwitchConfig(**_valid_kwargs()) - assert cfg.get_adapter_hiding_policy_matrix() == [] - - -# ════════════════════════════════════════════════════════════════════ -# 4. Third-party adapter config -# ════════════════════════════════════════════════════════════════════ - -class TestAdapterThirdParty: - - def test_adapter_third_party_stored(self): - cfg = GraniteSwitchConfig(**_valid_kwargs( - adapter_third_party=["adapter_0"], - )) - assert cfg.adapter_third_party == ["adapter_0"] + assert cfg.adapter_token_ids is None + assert cfg.adapter_substitute_token_ids is None - def test_adapter_third_party_none_by_default(self): + def test_projection_head_dim_inferred_from_hidden_size(self): cfg = GraniteSwitchConfig(**_valid_kwargs()) - assert cfg.adapter_third_party is None + assert cfg.projection_head_dim == 64 // 4 diff --git a/tests/unit/test_config_edge_cases.py b/tests/unit/test_config_edge_cases.py index b161920..917d933 100644 --- a/tests/unit/test_config_edge_cases.py +++ b/tests/unit/test_config_edge_cases.py @@ -1,13 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -"""Additional config edge case tests for GraniteSwitchConfig. 
- -These tests cover edge cases not covered by the main test_config.py, -specifically targeting previously uncovered code paths: -- Line 99: shared_intermediate_size default from intermediate_size -- Line 119: negative control_dims validation -- Lines 220, 222: get_hiding_group_token_ids with missing configs -- Lines 250-259: get_third_party_adapter_mask functionality -""" +"""Additional config edge case tests for GraniteSwitchConfig.""" import pytest @@ -15,7 +7,7 @@ def _valid_kwargs(num_adapters=2, **overrides): - """Return kwargs for a valid SingleSwitch config, with optional overrides.""" + """Return kwargs for a valid token-exchange config.""" adapter_names = [f"adapter_{i}" for i in range(num_adapters)] base = dict( vocab_size=300, @@ -26,6 +18,7 @@ def _valid_kwargs(num_adapters=2, **overrides): num_key_value_heads=4, num_adapters=num_adapters, adapter_token_ids=list(range(500, 500 + num_adapters)), + adapter_substitute_token_ids=[1] * num_adapters, adapter_names=adapter_names, max_lora_rank=8, adapter_ranks=[8] * num_adapters, @@ -35,195 +28,54 @@ def _valid_kwargs(num_adapters=2, **overrides): class TestSharedIntermediateSize: - """Tests for shared_intermediate_size default handling (line 99). - - Note: The parent GraniteMoeHybridConfig may have a non-None default, - so line 99 (the None check) may not always be hit. We test the - explicit case and verify the config has a sensible value. - """ + """The parent GraniteMoeHybridConfig may have a non-None default for + shared_intermediate_size. 
Verify our config has a sensible value.""" def test_shared_intermediate_size_has_value(self): - """shared_intermediate_size has a value (either explicit or parent default).""" cfg = GraniteSwitchConfig(**_valid_kwargs()) - # Should have a sensible value (not None) assert cfg.shared_intermediate_size is not None assert cfg.shared_intermediate_size > 0 def test_explicit_shared_intermediate_size_preserved(self): - """Explicit shared_intermediate_size is preserved.""" cfg = GraniteSwitchConfig(**_valid_kwargs( shared_intermediate_size=256, )) assert cfg.shared_intermediate_size == 256 -class TestControlDimsValidation: - """Tests for control_dims validation (line 119).""" - - def test_negative_control_dims_raises(self): - """Negative control_dims should raise ValueError.""" - with pytest.raises(ValueError, match="control_dims must be >= 0"): - GraniteSwitchConfig(**_valid_kwargs(control_dims=-1)) - - def test_zero_control_dims_valid(self): - """Zero control_dims is valid (native mode, no KV hiding).""" - cfg = GraniteSwitchConfig(**_valid_kwargs(control_dims=0)) - assert cfg.control_dims == 0 - - def test_positive_control_dims_valid(self): - """Positive control_dims is valid.""" - cfg = GraniteSwitchConfig(**_valid_kwargs(control_dims=64)) - assert cfg.control_dims == 64 - - -class TestGetHidingGroupTokenIds: - """Tests for get_hiding_group_token_ids edge cases (lines 220, 222).""" - - def test_no_hiding_groups_returns_empty(self): - """Empty dict when hiding_groups is None (line 219).""" - cfg = GraniteSwitchConfig(**_valid_kwargs(hiding_groups=None)) - result = cfg.get_hiding_group_token_ids() - assert result == {} - - def test_no_adapter_names_returns_empty(self): - """Empty dict when adapter_names is None (line 219).""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - adapter_names=None, - hiding_groups={"all": ["adapter_0"]}, - )) - result = cfg.get_hiding_group_token_ids() - assert result == {} - - def test_no_adapter_token_ids_returns_empty(self): - """Empty dict 
when adapter_token_ids is None (line 222).""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - adapter_token_ids=None, - hiding_groups={"all": ["adapter_0"]}, - )) - result = cfg.get_hiding_group_token_ids() - assert result == {} - - def test_partial_adapter_name_match(self): - """Only matching adapter names are included in result.""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - hiding_groups={"all": ["adapter_0", "nonexistent_adapter"]}, - )) - result = cfg.get_hiding_group_token_ids() - # Only adapter_0 should be in the result (token 500) - assert result == {0: [500]} - - -class TestGetThirdPartyAdapterMask: - """Tests for get_third_party_adapter_mask (lines 250-259).""" - - def test_no_third_party_returns_all_false(self): - """All-False mask when adapter_third_party is not configured.""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - adapter_third_party=None, - )) - mask = cfg.get_third_party_adapter_mask() - # Length = num_adapters + 1 (base + adapters) - assert len(mask) == 3 # base + 2 adapters - assert mask == [False, False, False] - - def test_empty_third_party_returns_all_false(self): - """All-False mask when adapter_third_party is empty list.""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - adapter_third_party=[], - )) - mask = cfg.get_third_party_adapter_mask() - assert mask == [False, False, False] - - def test_no_adapter_names_returns_all_false(self): - """All-False mask when adapter_names is None.""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - adapter_names=None, - adapter_third_party=["adapter_0"], - )) - mask = cfg.get_third_party_adapter_mask() - assert mask == [False, False, False] - - def test_single_third_party_adapter(self): - """Mask correctly identifies single third-party adapter.""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - adapter_third_party=["adapter_0"], - )) - mask = cfg.get_third_party_adapter_mask() - # Index 0 = base (never third-party) - # Index 1 = adapter_0 (third-party) - # Index 2 = adapter_1 (not third-party) - assert 
mask == [False, True, False] - - def test_multiple_third_party_adapters(self): - """Mask correctly identifies multiple third-party adapters.""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - adapter_third_party=["adapter_0", "adapter_1"], - )) - mask = cfg.get_third_party_adapter_mask() - # Both adapters are third-party - assert mask == [False, True, True] - - def test_base_never_third_party(self): - """Base adapter (index 0) is never marked as third-party.""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - adapter_third_party=["adapter_0", "adapter_1"], - )) - mask = cfg.get_third_party_adapter_mask() - assert mask[0] is False # Base is never third-party - - def test_mask_length_matches_num_adapters_plus_one(self): - """Mask length is num_adapters + 1 (includes base slot).""" - for num_adapters in [0, 1, 4, 10]: - cfg = GraniteSwitchConfig(**_valid_kwargs( - num_adapters=num_adapters, - adapter_token_ids=list(range(500, 500 + num_adapters)), - adapter_names=[f"adapter_{i}" for i in range(num_adapters)], - adapter_ranks=[8] * num_adapters if num_adapters > 0 else None, - )) - mask = cfg.get_third_party_adapter_mask() - assert len(mask) == num_adapters + 1 - - class TestLayerTypesDefault: - """Tests for layer_types default handling.""" + """layer_types defaults to all-attention with length == num_hidden_layers.""" - def test_layer_types_defaults_to_attention(self): - """layer_types defaults to all 'attention' when None.""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - layer_types=None, - num_hidden_layers=4, - )) - # Should have 4 attention layers (adapters add a switch layer at index 0, - # but the config has num_hidden_layers=4 which becomes 5 with switch) - # The default is set before parent init adds the switch layer - assert cfg.layer_types is not None + def test_default_layer_types_when_omitted(self): + cfg = GraniteSwitchConfig(num_adapters=0, num_hidden_layers=4) + assert cfg.layer_types == ["attention"] * 4 def test_explicit_layer_types_preserved(self): - 
"""Explicit layer_types are preserved.""" - cfg = GraniteSwitchConfig(**_valid_kwargs( - layer_types=["attention", "attention"], - num_hidden_layers=2, - )) - assert cfg.layer_types == ["attention", "attention"] + cfg = GraniteSwitchConfig( + num_adapters=0, + num_hidden_layers=3, + layer_types=["attention", "attention", "attention"], + ) + assert cfg.layer_types == ["attention", "attention", "attention"] class TestLoraTargetModulesDefault: - """Tests for lora_target_modules default handling.""" + """lora_target_modules defaults to qkv_proj/o_proj + shared_mlp pair + when num_adapters > 0; empty when num_adapters == 0.""" - def test_lora_target_modules_empty_when_no_adapters(self): - """lora_target_modules defaults to empty when num_adapters=0.""" - cfg = GraniteSwitchConfig( - vocab_size=256, hidden_size=64, intermediate_size=128, - num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4, - num_adapters=0, - ) + def test_no_adapters_empty_target_modules(self): + cfg = GraniteSwitchConfig(num_adapters=0) assert cfg.lora_target_modules == [] - def test_lora_target_modules_populated_with_adapters(self): - """lora_target_modules defaults to standard modules when adapters present.""" + def test_adapters_populate_target_modules(self): cfg = GraniteSwitchConfig(**_valid_kwargs()) - # Should include attention and MLP modules assert "qkv_proj" in cfg.lora_target_modules assert "o_proj" in cfg.lora_target_modules assert "shared_input_linear" in cfg.lora_target_modules assert "shared_output_linear" in cfg.lora_target_modules + + def test_explicit_target_modules_preserved(self): + cfg = GraniteSwitchConfig( + **_valid_kwargs(lora_target_modules=["qkv_proj"]) + ) + assert cfg.lora_target_modules == ["qkv_proj"] diff --git a/tests/unit/test_hiding_constant.py b/tests/unit/test_hiding_constant.py deleted file mode 100644 index 00c7724..0000000 --- a/tests/unit/test_hiding_constant.py +++ /dev/null @@ -1,118 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Tests for 
the K-side hiding constant used in control-dimension masking. - -The hiding mechanism sets K[d_g] = finfo(dtype).min for tokens in hiding group g. -This test validates that this constant behaves correctly across all supported -floating-point types: - -1. exp(constant) == 0 (softmax produces zero weight) -2. Accumulation of multiple constants (multiple groups) also exponentiates to zero -3. Adding realistic finite attention scores to the constant still exponentiates to zero -4. 0 * constant does NOT produce NaN (critical: Q[d_g]=0 for non-hiding adapters) - -Safety margin reporting (how large a positive value must be added before exp produces -a nonzero result) is part of the builder's verbose output, not tested here. -""" - -import pytest -import torch - -DTYPES = [torch.float16, torch.bfloat16, torch.float32, torch.float64] -DTYPE_IDS = ["float16", "bfloat16", "float32", "float64"] - -# Maximum number of hiding groups we might realistically accumulate in a dot product -MAX_GROUPS = 32 - - - -def hiding_constant(dtype: torch.dtype) -> torch.Tensor: - """The K-side hiding constant: the most negative finite value in the dtype.""" - return torch.tensor(torch.finfo(dtype).min, dtype=dtype) - - -@pytest.fixture(params=DTYPES, ids=DTYPE_IDS) -def dtype(request): - return request.param - - -class TestHidingConstantExponentiation: - """exp(hiding_constant) must be exactly zero — this is what makes softmax - assign zero attention weight to hidden tokens.""" - - def test_exp_of_constant_is_zero(self, dtype): - c = hiding_constant(dtype) - assert torch.exp(c).item() == 0.0 - - def test_exp_of_sum_of_constants_is_zero(self, dtype): - """A token in multiple groups: dot product accumulates multiple constants.""" - c = hiding_constant(dtype) - for n in [2, 4, MAX_GROUPS]: - accum = torch.zeros(1, dtype=dtype) - for _ in range(n): - accum = accum + c - assert torch.exp(accum).item() == 0.0, f"Failed for {n} groups" - - def test_exp_of_constant_plus_finite_is_zero(self, dtype): - 
"""Normal attention score added to the constant must still exponentiate to zero.""" - c = hiding_constant(dtype) - for score in [0.0, 10.0, 100.0, 1000.0]: - s = torch.tensor(score, dtype=dtype) - result = torch.exp(c + s) - assert result.item() == 0.0, f"Failed for score={score}" - - -class TestHidingConstantNoNaN: - """0 * hiding_constant must NOT produce NaN. This is the scenario where - Q[d_g] = 0 (adapter does not hide group g) and K[d_g] = constant.""" - - def test_zero_times_constant_is_not_nan(self, dtype): - c = hiding_constant(dtype) - zero = torch.tensor(0.0, dtype=dtype) - result = zero * c - assert not result.isnan().item() - - def test_zero_times_constant_does_not_corrupt_dot_product(self, dtype): - """In a realistic dot product, 0 * constant contributions must not - change the result compared to a clean dot product without control dims.""" - torch.manual_seed(42) - head_dim = 128 - control_dims = 4 - total_dim = head_dim + control_dims - - Q = torch.randn(total_dim, dtype=dtype) - K = torch.randn(total_dim, dtype=dtype) - - c = hiding_constant(dtype) - # Token is in groups 0 and 2 - K[head_dim + 0] = c - K[head_dim + 1] = 0.0 - K[head_dim + 2] = c - K[head_dim + 3] = 0.0 - - # Query does NOT hide any group - Q[head_dim:] = 0.0 - - dot_with_ctrl = torch.dot(Q, K) - dot_clean = torch.dot(Q[:head_dim], K[:head_dim]) - - assert not dot_with_ctrl.isnan().item() - assert torch.isclose(dot_with_ctrl, dot_clean, atol=1e-2) - - -class TestHidingConstantSoftmax: - """End-to-end: softmax assigns exactly zero weight to hidden positions.""" - - def test_softmax_zero_weight_for_hidden(self, dtype): - scores = torch.tensor([5.0, 3.0, 7.0], dtype=dtype) - c = hiding_constant(dtype) - scores_with_hidden = scores.clone() - scores_with_hidden[1] = scores_with_hidden[1] + c # hide position 1 - - sm = torch.softmax(scores_with_hidden, dim=0) - assert sm[1].item() == 0.0 - # Non-hidden positions should get all the probability mass - assert sm[0].item() > 0.0 - assert 
sm[2].item() > 0.0 - assert abs(sm.sum().item() - 1.0) < 1e-3 - - diff --git a/tests/unit/test_token_exchange.py b/tests/unit/test_token_exchange.py new file mode 100644 index 0000000..d24e968 --- /dev/null +++ b/tests/unit/test_token_exchange.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Unit tests for the token-exchange config path. + +Verifies the validators and required-field semantics on +GraniteSwitchConfig, now that token-exchange is the only mode. +""" + +import pytest + +from granite_switch.config import GraniteSwitchConfig + + +def _base(num_adapters=2, **overrides): + names = [f"a{i}" for i in range(num_adapters)] + base = dict( + vocab_size=300, + hidden_size=64, + intermediate_size=128, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=4, + num_adapters=num_adapters, + adapter_token_ids=list(range(500, 500 + num_adapters)), + adapter_substitute_token_ids=[1] * num_adapters, + adapter_names=names, + max_lora_rank=8, + adapter_ranks=[8] * num_adapters, + ) + base.update(overrides) + return base + + +class TestDefaults: + def test_no_adapters_no_validation(self): + cfg = GraniteSwitchConfig(num_adapters=0) + assert cfg.adapter_substitute_token_ids is None + + +class TestValidation: + def test_substitute_ids_required_when_adapters_present(self): + with pytest.raises(ValueError, match="adapter_substitute_token_ids is required"): + GraniteSwitchConfig(**_base(adapter_substitute_token_ids=None)) + + def test_substitute_wrong_length_raises(self): + with pytest.raises(ValueError, match="adapter_substitute_token_ids length"): + GraniteSwitchConfig(**_base(adapter_substitute_token_ids=[1])) + + def test_duplicate_adapter_token_ids_raises(self): + with pytest.raises(ValueError, match="adapter_token_ids must be unique"): + GraniteSwitchConfig(**_base(adapter_token_ids=[100, 100])) + + def test_negative_substitute_id_raises(self): + with pytest.raises(ValueError, match=">= 0"): + 
GraniteSwitchConfig(**_base(adapter_substitute_token_ids=[-1, 1])) + + +class TestProjectionHeadDim: + def test_inferred_from_hidden_size(self): + cfg = GraniteSwitchConfig(**_base()) + assert cfg.projection_head_dim == cfg.hidden_size // cfg.num_attention_heads diff --git a/tests/vllm/_generation_equivalence_worker.py b/tests/vllm/_generation_equivalence_worker.py index 0609b8f..17c4e40 100644 --- a/tests/vllm/_generation_equivalence_worker.py +++ b/tests/vllm/_generation_equivalence_worker.py @@ -9,8 +9,8 @@ python worker.py compare --work-dir --label **build**: Loads config for dtype/vocab, generates a deterministic 64-token prompt, -builds a GraniteSwitch model with 1 built-in adapter (zero LoRA weights) and -control_dims=32. Saves the switch model and inputs to ``/``. +builds a GraniteSwitch model with 1 built-in adapter (zero LoRA weights). +Saves the switch model and inputs to ``/``. **run**: Loads inputs from ``/inputs.json``, loads model in vLLM, runs greedy autoregressive generation (temperature=0, max_tokens=32), saves generated @@ -83,17 +83,14 @@ def cmd_build(args): print(f" saved inputs to {inputs_path}") # Build switch model with 1 built-in adapter - print(f"\nBuilding GraniteSwitch (1 built-in adapter, control_dims=32)...") + print(f"\nBuilding GraniteSwitch (1 built-in adapter)...") skin_dir = os.path.join(work_dir, "switch") model = GraniteSwitchComposer.from_base_and_adapters( model_name, built_in_adapter_names=["test"], adapter_names=["test"], adapter_token_ids=[adapter_token_id], - control_dims=32, - hiding_groups={"all_controls": ["test"]}, - hiding_policy={"base": ["all_controls"], "test": ["all_controls"]}, - adapter_third_party=["test"], + adapter_substitute_token_ids=[1], torch_dtype=dtype, ) diff --git a/tests/vllm/_kv_hiding_gap_tests.py b/tests/vllm/_kv_hiding_gap_tests.py deleted file mode 100644 index be29394..0000000 --- a/tests/vllm/_kv_hiding_gap_tests.py +++ /dev/null @@ -1,71 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""KV 
hiding gap equivalence tests (inner file — run by test_kv_hiding_gap_equivalence.py). - -Verify that a hidden control token creates a transparent gap in vLLM. -Requires CUDA GPU and vLLM installed. -""" - -import pytest -import torch - -from tests.shared.granite4_equivalence import GRANITE4_MINI -from tests.shared.gap_equivalence import extract_visible_flat - - -_CUDA_AVAILABLE = torch.cuda.is_available() - - -def _try_import_vllm(): - try: - from vllm import LLM # noqa: F401 - return True - except ImportError: - return False - - -_VLLM_AVAILABLE = _try_import_vllm() if _CUDA_AVAILABLE else False - -_CFG_NAME = "4.0-350m" - - -@pytest.mark.skipif( - not _CUDA_AVAILABLE or not _VLLM_AVAILABLE, - reason="requires CUDA GPU and vLLM installed", -) -class TestKVHidingGapEquivalence: - - @pytest.fixture - def gap_runner(self, tmp_path): - cfg_dict = GRANITE4_MINI[_CFG_NAME] - - def run(seq_len, ctrl_pos): - from tests.shared.vllm_equivalence import run_gap_equivalence - return run_gap_equivalence( - cfg_dict, - seq_len=seq_len, ctrl_pos=ctrl_pos, - tmpdir=tmp_path, - ) - - return run - - def _assert_gap(self, run, seq_len, ctrl_pos, atol=0, rtol=0): - upstream_lp, switch_lp = run(seq_len, ctrl_pos) - visible_lp = extract_visible_flat(switch_lp, ctrl_pos) - - torch.testing.assert_close( - visible_lp, upstream_lp, - atol=atol, rtol=rtol, - msg=( - f"{_CFG_NAME}: visible logprobs diverge " - f"(seq={seq_len}, ctrl={ctrl_pos})" - ), - ) - - def test_gap_short(self, gap_runner): - self._assert_gap(gap_runner, seq_len=16, ctrl_pos=2) - - def test_gap_ctrl_at_1(self, gap_runner): - self._assert_gap(gap_runner, seq_len=16, ctrl_pos=1) - - def test_gap_near_end(self, gap_runner): - self._assert_gap(gap_runner, seq_len=16, ctrl_pos=14) diff --git a/tests/vllm/_model_forward_tests.py b/tests/vllm/_model_forward_tests.py index faf4cf4..4caeebb 100644 --- a/tests/vllm/_model_forward_tests.py +++ b/tests/vllm/_model_forward_tests.py @@ -63,40 +63,10 @@ def _tiny_vllm_config(): 
num_adapters=2, adapter_token_ids=[250, 251], adapter_names=["adapter_1", "adapter_2"], - hiding_groups={"all_controls": ["adapter_1", "adapter_2"]}, - hiding_policy={"base": ["all_controls"], "adapter_1": ["all_controls"], "adapter_2": ["all_controls"]}, - adapter_third_party=["adapter_1", "adapter_2"], + adapter_substitute_token_ids=[1, 1], max_lora_rank=4, adapter_ranks=[4, 4], switch_head_dim=32, - control_dims=32, - max_position_embeddings=512, - attention_multiplier=1.0, - embedding_multiplier=1.0, - residual_multiplier=1.0, - logits_scaling=1.0, - ) - - -def _tiny_vllm_mixed_tp_config(): - """SingleSwitch config where only adapter_1 is third-party.""" - return GraniteSwitchConfig( - vocab_size=300, - hidden_size=64, - intermediate_size=128, - num_hidden_layers=3, - num_attention_heads=2, - num_key_value_heads=2, - num_adapters=2, - adapter_token_ids=[250, 251], - adapter_names=["adapter_1", "adapter_2"], - hiding_groups={"all_controls": ["adapter_1", "adapter_2"]}, - hiding_policy={"base": ["all_controls"], "adapter_1": ["all_controls"], "adapter_2": ["all_controls"]}, - adapter_third_party=["adapter_1"], - max_lora_rank=4, - adapter_ranks=[4, 4], - switch_head_dim=32, - control_dims=32, max_position_embeddings=512, attention_multiplier=1.0, embedding_multiplier=1.0, @@ -449,46 +419,7 @@ def test_different_adapters_produce_different_logits(self): # ════════════════════════════════════════════════════════════════════ -# 5. 
Control token KV invisibility -# ════════════════════════════════════════════════════════════════════ - -class TestControlTokenKVInvisibility(_VLLMModelTestBase): - - def test_control_token_invisible_to_future_positions(self): - torch.manual_seed(SEED) - self.model.eval() - - seq = [10, 20, 250, 30, 40, 50, 60, 70] - - with torch.no_grad(): - logits_a = self._run_forward_and_logits(seq) - - with torch.no_grad(): - perturbation = torch.randn( - self.config.hidden_size, device=self.device, dtype=torch.bfloat16 - ) * 10.0 - self.model.model.embed_tokens.weight.data[250] += perturbation - - with torch.no_grad(): - logits_b = self._run_forward_and_logits(seq) - - torch.testing.assert_close( - logits_a[:2], logits_b[:2], - msg="Pre-control logits should be identical" - ) - - assert not torch.allclose(logits_a[2], logits_b[2]), \ - "Control token logits should differ after perturbation" - - torch.testing.assert_close( - logits_a[3:], logits_b[3:], - msg="Post-control logits should be identical " - "(control token KV masked by control_dims)" - ) - - -# ════════════════════════════════════════════════════════════════════ -# 6. KV visibility tests +# 5. KV visibility tests # ════════════════════════════════════════════════════════════════════ class TestKVVisibility(_VLLMModelTestBase): diff --git a/tests/vllm/_position_zero_nan_tests.py b/tests/vllm/_position_zero_nan_tests.py deleted file mode 100644 index 69b3360..0000000 --- a/tests/vllm/_position_zero_nan_tests.py +++ /dev/null @@ -1,479 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""NaN regression tests — control token at sequence position 0 (vLLM backend). -Inner file — run by test_position_zero_nan.py in subprocess. - -Combines: - - vLLM-specific unit tests for GraniteLoRAEmbeddedAttention._expand_with_control_dimensions - (flat token layout: [num_tokens, num_heads * head_dim]) - - Shared SDPANaNCases and ModelFinitenessCases from tests/shared/position_zero_nan_cases.py - -Requires CUDA GPU and vLLM installed. 
-""" - -import types -import json -import os -import tempfile - -import pytest -import torch - -from tests.shared.position_zero_nan_cases import ModelFinitenessCases, SDPANaNCases - -_CUDA_AVAILABLE = torch.cuda.is_available() - - -def _try_import_vllm(): - try: - from vllm.config import VllmConfig # noqa: F401 - from vllm.model_executor.layers.attention.attention import Attention # noqa: F401 - from vllm.forward_context import ForwardContext, override_forward_context # noqa: F401 - return True - except ImportError: - return False - - -_VLLM_AVAILABLE = _try_import_vllm() if _CUDA_AVAILABLE else False - -pytestmark = pytest.mark.skipif( - not _CUDA_AVAILABLE or not _VLLM_AVAILABLE, - reason="requires CUDA GPU and vLLM installed", -) - -if _VLLM_AVAILABLE: - from vllm.config import VllmConfig, ModelConfig, set_current_vllm_config - from vllm.forward_context import ForwardContext, override_forward_context - from granite_switch.config import GraniteSwitchConfig - from granite_switch.vllm.granite_switch_model import GraniteSwitchForCausalLM - from granite_switch.vllm.core.decoder import GraniteLoRAEmbeddedAttention - -from tests.shared.vllm_distributed import ensure_distributed as _ensure_distributed - -BLOCK_SIZE = 16 -MAX_TOKENS = 512 -SEED = 42 - - -# ── vLLM config + ctrl token ID ──────────────────────────────────── - - -def _make_config(): - return GraniteSwitchConfig( - vocab_size=300, - hidden_size=64, - intermediate_size=128, - num_hidden_layers=3, - num_attention_heads=2, - num_key_value_heads=2, - num_adapters=2, - adapter_token_ids=[250, 251], - adapter_names=["adapter_1", "adapter_2"], - hiding_groups={"all_controls": ["adapter_1", "adapter_2"]}, - hiding_policy={ - "base": ["all_controls"], - "adapter_1": ["all_controls"], - "adapter_2": ["all_controls"], - }, - adapter_third_party=["adapter_1", "adapter_2"], - max_lora_rank=4, - adapter_ranks=[4, 4], - switch_head_dim=32, - control_dims=32, - max_position_embeddings=MAX_TOKENS, - 
attention_multiplier=1.0, - embedding_multiplier=1.0, - residual_multiplier=1.0, - logits_scaling=1.0, - ) - - -_CTRL_TOKEN = 250 - - -# ── vLLM model runner ─────────────────────────────────────────────── - - -def _make_vllm_config(config): - from granite_switch.vllm import register - register() - tmpdir = tempfile.mkdtemp(prefix="gs_nan_test_") - cfg_dict = config.to_dict() - cfg_dict["architectures"] = ["GraniteSwitchForCausalLM"] - with open(os.path.join(tmpdir, "config.json"), "w") as f: - json.dump(cfg_dict, f) - model_config = ModelConfig( - model=tmpdir, - dtype="bfloat16", - max_model_len=config.max_position_embeddings, - enforce_eager=True, - ) - return VllmConfig(model_config=model_config) - - -def _init_weights(model): - torch.manual_seed(SEED) - with torch.no_grad(): - for name, param in model.named_parameters(): - if not param.is_floating_point(): - continue - if "lora_A" in name or "lora_B" in name: - continue - if "layernorm" in name or "norm" in name: - continue - param.data.normal_(0, 0.02) - - -def _setup_kv_caches(model, config, vllm_config, device): - kv_caches = [] - attention_map = {} - num_blocks = (MAX_TOKENS + BLOCK_SIZE - 1) // BLOCK_SIZE + 1 - - def _add(attn, name): - attn.kv_cache_torch_dtype = torch.bfloat16 - shape = attn.attn_backend.get_kv_cache_shape( - num_blocks, BLOCK_SIZE, attn.num_kv_heads, attn.head_size, - ) - kv = torch.zeros(shape, device=device, dtype=torch.bfloat16) - attn.kv_cache = kv - kv_caches.append(kv) - attention_map[name] = attn - - sw = model.model.switch - _add(sw.attn, "switch.layers.0") - num_decoder = config.num_hidden_layers - sw.num_cache_layers - for i in range(num_decoder): - _add(model.model.layers[i].self_attn.attn, f"model.layers.{i}.self_attn.attn") - - return kv_caches, attention_map - - -def _build_metadata(attention_map, seq_len, device): - slot_mapping = torch.arange(seq_len, dtype=torch.int64, device=device) - num_blocks = (seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE - block_table = 
torch.arange(num_blocks, dtype=torch.int32, device=device).unsqueeze(0) - query_start_loc = torch.tensor([0, seq_len], dtype=torch.int32, device=device) - seq_lens = torch.tensor([seq_len], dtype=torch.int32, device=device) - - backend_name = list(attention_map.values())[0].attn_backend.get_name() - if backend_name == "FLASH_ATTN": - from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata - - # scheduler_metadata is FA3-only — passing it on FA2 (Ampere/A100) - # forces FA3 kernel dispatch and crashes with "no kernel image - # is available". Only compute it when get_flash_attn_version() == 3 - # (Hopper SM90+). - scheduler_metadata = None - try: - from vllm.v1.attention.backends.fa_utils import ( - get_flash_attn_version, - get_scheduler_metadata, - ) - if get_flash_attn_version() == 3: - first_attn = list(attention_map.values())[0] - scheduler_metadata = get_scheduler_metadata( - batch_size=1, - max_seqlen_q=seq_len, - max_seqlen_k=seq_len, - num_heads_q=first_attn.num_heads, - num_heads_kv=first_attn.num_kv_heads, - headdim=first_attn.head_size, - cache_seqlens=seq_lens, - qkv_dtype=torch.bfloat16, - cu_seqlens_q=query_start_loc, - page_size=BLOCK_SIZE, - causal=True, - num_splits=0, - ) - except ImportError: - pass - - metadata = FlashAttentionMetadata( - num_actual_tokens=seq_len, - max_query_len=seq_len, - query_start_loc=query_start_loc, - max_seq_len=seq_len, - seq_lens=seq_lens, - block_table=block_table, - slot_mapping=slot_mapping, - use_cascade=False, - common_prefix_len=0, - cu_prefix_query_lens=None, - prefix_kv_lens=None, - suffix_kv_lens=None, - causal=True, - scheduler_metadata=scheduler_metadata, - ) - else: - pytest.skip(f"Backend {backend_name}: metadata not implemented for this test") - - return metadata, slot_mapping - - -def _run_vllm_forward_is_finite(ctrl_pos, seq_len, seed): - """Create a vLLM switch model and check for finite output at the given ctrl_pos.""" - _ensure_distributed() - device = torch.device("cuda") - config = 
_make_config() - - old_dtype = torch.get_default_dtype() - torch.set_default_dtype(torch.bfloat16) - try: - vllm_config = _make_vllm_config(config) - with set_current_vllm_config(vllm_config): - model = GraniteSwitchForCausalLM(vllm_config=vllm_config).to(device) - finally: - torch.set_default_dtype(old_dtype) - - _init_weights(model) - kv_caches, attention_map = _setup_kv_caches(model, config, vllm_config, device) - - # Build input: ctrl_token at ctrl_pos, random content elsewhere - torch.manual_seed(seed) - ctrl_id = _CTRL_TOKEN - content = torch.randint(0, 100, (seq_len,)).tolist() - ids_list = content[:ctrl_pos] + [ctrl_id] + content[ctrl_pos:] - total_len = len(ids_list) - - input_ids = torch.tensor(ids_list, dtype=torch.long, device=device) - positions = torch.arange(total_len, dtype=torch.long, device=device) - metadata, slot_mapping = _build_metadata(attention_map, total_len, device) - - layer_names = list(attention_map.keys()) - attn_metadata = {n: metadata for n in layer_names} - slot_mapping_dict = {n: slot_mapping for n in layer_names} - - forward_ctx = ForwardContext( - no_compile_layers=vllm_config.compilation_config.static_forward_context, - attn_metadata=attn_metadata, - slot_mapping=slot_mapping_dict, - ) - - old_direct = {n: attention_map[n].use_direct_call for n in layer_names} - for n in layer_names: - attention_map[n].use_direct_call = True - - try: - for kv in kv_caches: - kv.zero_() - with override_forward_context(forward_ctx): - hidden = model.forward(input_ids=input_ids, positions=positions) - logits = model.compute_logits(hidden) - finally: - for n in layer_names: - attention_map[n].use_direct_call = old_direct[n] - - sfc = vllm_config.compilation_config.static_forward_context - for n in layer_names: - sfc.pop(n, None) - - return bool(logits.isfinite().all()) - - -# ════════════════════════════════════════════════════════════════════ -# 1. 
vLLM-specific unit tests: _expand_with_control_dimensions -# Tensor layout: [num_tokens, num_heads * head_dim] (flat) -# ════════════════════════════════════════════════════════════════════ - - -def _vllm_stub(num_heads=2, num_kv_heads=2, head_dim=32, control_dims=1): - """Minimal namespace for vLLM _expand_with_control_dimensions.""" - return types.SimpleNamespace( - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - control_dims=control_dims, - expanded_head_dim=head_dim + control_dims, - ) - - -def _vllm_expand(stub, q, k, v, membership, suppression): - return GraniteLoRAEmbeddedAttention._expand_with_control_dimensions( - stub, q, k, v, membership, suppression - ) - - -def _vllm_qkv(stub, num_tokens): - """Create flat vLLM-layout Q/K/V tensors.""" - q = torch.randn(num_tokens, stub.num_heads * stub.head_dim) - k = torch.randn(num_tokens, stub.num_kv_heads * stub.head_dim) - v = torch.randn(num_tokens, stub.num_kv_heads * stub.head_dim) - return q, k, v - - -class TestExpandControlDimensions: - """Direct tests of _expand_with_control_dimensions (vLLM flat tensor layout). 
- - Input shape: [num_tokens, num_heads * head_dim] - Output shape: [num_tokens, num_heads * expanded_head_dim] - """ - - _HEAD_DIM = 32 - _CTRL_DIMS = 1 - - def test_control_token_q_hide_zero_at_position_zero(self): - """Core fix: control token at pos 0 must not activate Q-side hiding.""" - stub = _vllm_stub(control_dims=self._CTRL_DIMS) - membership = torch.ones(1, 1, dtype=torch.bool) - suppression = torch.ones(1, 1, dtype=torch.bool) - q, k, v = _vllm_qkv(stub, num_tokens=1) - - q_exp, _, _ = _vllm_expand(stub, q, k, v, membership, suppression) - - q_reshaped = q_exp.view(1, stub.num_heads, stub.expanded_head_dim) - q_ctrl = q_reshaped[0, :, self._HEAD_DIM:] - assert q_ctrl.eq(0).all(), f"Control token at pos 0: q_control must be 0, got {q_ctrl}" - - def test_adapter_generated_tokens_q_hide_one(self): - """Adapter-generated tokens (non-members) keep q_control=1.""" - stub = _vllm_stub(control_dims=self._CTRL_DIMS) - num_tokens = 5 - membership = torch.zeros(num_tokens, 1, dtype=torch.bool) - membership[0, 0] = True # control token at pos 0 - suppression = torch.ones(num_tokens, 1, dtype=torch.bool) - q, k, v = _vllm_qkv(stub, num_tokens) - - q_exp, _, _ = _vllm_expand(stub, q, k, v, membership, suppression) - q_reshaped = q_exp.view(num_tokens, stub.num_heads, stub.expanded_head_dim) - - assert q_reshaped[0, :, self._HEAD_DIM:].eq(0).all(), "Control token: q_control must be 0" - for pos in range(1, num_tokens): - assert q_reshaped[pos, :, self._HEAD_DIM:].eq(1).all(), ( - f"Adapter-generated token at pos {pos}: q_control must be 1" - ) - - def test_k_side_finfo_min_for_control_token(self): - """K-side branding is unaffected by the fix.""" - stub = _vllm_stub(control_dims=self._CTRL_DIMS) - membership = torch.ones(1, 1, dtype=torch.bool) - q, k, v = _vllm_qkv(stub, num_tokens=1) - - _, k_exp, _ = _vllm_expand(stub, q, k, v, membership, None) - - k_reshaped = k_exp.view(1, stub.num_kv_heads, stub.expanded_head_dim) - k_ctrl = k_reshaped[0, :, self._HEAD_DIM:] - 
expected_min = torch.finfo(k.dtype).min - torch.testing.assert_close(k_ctrl, torch.full_like(k_ctrl, expected_min)) - - def test_both_none_leaves_all_control_dims_zero(self): - """With both tensors None, all control dims remain zero.""" - stub = _vllm_stub(control_dims=2) - q, k, v = _vllm_qkv(stub, num_tokens=4) - q_exp, k_exp, v_exp = _vllm_expand(stub, q, k, v, None, None) - - exp_head = stub.expanded_head_dim - assert q_exp.view(4, stub.num_heads, exp_head)[:, :, self._HEAD_DIM:].eq(0).all() - assert k_exp.view(4, stub.num_kv_heads, exp_head)[:, :, self._HEAD_DIM:].eq(0).all() - assert v_exp.view(4, stub.num_kv_heads, exp_head)[:, :, self._HEAD_DIM:].eq(0).all() - - def test_original_dimensions_preserved(self): - """Original head dims are unchanged; only control dims appended.""" - stub = _vllm_stub(control_dims=2) - q, k, v = _vllm_qkv(stub, num_tokens=3) - q_exp, k_exp, v_exp = _vllm_expand(stub, q, k, v, None, None) - - exp_head = stub.expanded_head_dim - torch.testing.assert_close( - q_exp.view(3, stub.num_heads, exp_head)[:, :, :self._HEAD_DIM], - q.view(3, stub.num_heads, stub.head_dim), - ) - - -# ════════════════════════════════════════════════════════════════════ -# 2. Shared SDPA cases -# ════════════════════════════════════════════════════════════════════ - - -class TestSDPANaN(SDPANaNCases): - pass - - -# ════════════════════════════════════════════════════════════════════ -# 3. Shared model finiteness cases — vLLM backend -# ════════════════════════════════════════════════════════════════════ - - -class TestModelFiniteness(ModelFinitenessCases): - def _assert_no_nan(self, switch_type, ctrl_pos, seq_len, seed): - is_finite = _run_vllm_forward_is_finite(ctrl_pos, seq_len, seed) - assert is_finite, ( - f"[vLLM] ctrl_pos={ctrl_pos}: logits contain NaN/Inf" - ) - - -# ════════════════════════════════════════════════════════════════════ -# 4. 
Mutation test — proves TestModelFiniteness is sensitive to the fix -# ════════════════════════════════════════════════════════════════════ - - -def _buggy_expand( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - token_group_membership, - query_group_suppression, -) -> tuple: - """Pre-fix version: omits q_hide *= (1 - membership), causing NaN at ctrl_pos=0.""" - num_tokens = q.size(0) - device = q.device - dtype = q.dtype - - q = q.view(num_tokens, self.num_heads, self.head_dim) - k = k.view(num_tokens, self.num_kv_heads, self.head_dim) - v = v.view(num_tokens, self.num_kv_heads, self.head_dim) - - q_control = torch.zeros(num_tokens, self.num_heads, self.control_dims, device=device, dtype=dtype) - k_control = torch.zeros(num_tokens, self.num_kv_heads, self.control_dims, device=device, dtype=dtype) - v_control = torch.zeros(num_tokens, self.num_kv_heads, self.control_dims, device=device, dtype=dtype) - - if token_group_membership is not None: - num_groups = token_group_membership.shape[-1] - hiding_constant = torch.finfo(dtype).min - k_control[:, :, :num_groups] = ( - token_group_membership.unsqueeze(1) - .expand(-1, self.num_kv_heads, -1) - .to(dtype) * hiding_constant - ) - - if query_group_suppression is not None: - num_groups = query_group_suppression.shape[-1] - q_hide = query_group_suppression.to(dtype) - # BUG: missing `q_hide *= (1 - token_group_membership)` — control token - # at position 0 gets q_ctrl=1, causing softmax([-inf]) = NaN. 
- q_control[:, :, :num_groups] = q_hide.unsqueeze(1).expand(-1, self.num_heads, -1) - - q = torch.cat([q, q_control], dim=-1) - k = torch.cat([k, k_control], dim=-1) - v = torch.cat([v, v_control], dim=-1) - - q = q.view(num_tokens, self.num_heads * self.expanded_head_dim) - k = k.view(num_tokens, self.num_kv_heads * self.expanded_head_dim) - v = v.view(num_tokens, self.num_kv_heads * self.expanded_head_dim) - - return q, k, v - - -class TestFixSensitivity: - """Mutation test: revert the fix and confirm NaN is produced. - - Patches _expand_with_control_dimensions with the pre-fix (buggy) version. - If _run_vllm_forward_is_finite still returns True, the model-level tests - are not actually sensitive to the fix and must be reconsidered. - """ - - def test_buggy_expand_produces_nan_at_ctrl_pos_zero(self): - """Without the fix, ctrl_pos=0 must produce non-finite logits in vLLM.""" - from granite_switch.vllm.core.decoder import GraniteLoRAEmbeddedAttention - from unittest.mock import patch - - with patch.object( - GraniteLoRAEmbeddedAttention, - "_expand_with_control_dimensions", - _buggy_expand, - ): - is_finite = _run_vllm_forward_is_finite(ctrl_pos=0, seq_len=8, seed=99) - - assert not is_finite, ( - "[vLLM] Expected NaN with buggy expand at ctrl_pos=0, " - "but output was finite — test is not sensitive to the fix" - ) diff --git a/tests/vllm/_single_switch_worker.py b/tests/vllm/_single_switch_worker.py index d876dc4..4f7e632 100644 --- a/tests/vllm/_single_switch_worker.py +++ b/tests/vllm/_single_switch_worker.py @@ -48,14 +48,23 @@ def _setup(): MAX_TOKENS = 131_072 NUM_ADAPTERS = 32 ADAPTER_TOKEN_IDS_LIST = list(range(1000, 1000 + NUM_ADAPTERS)) + # Deterministic substitute mapping for token-exchange tests: + # control id 1000+i → substitute id i+1 (i.e. 1, 2, ..., NUM_ADAPTERS). + # Substitute ids must be < vocab_size and != any control id. 
+ ADAPTER_SUBSTITUTE_TOKEN_IDS_LIST = [i + 1 for i in range(NUM_ADAPTERS)] # Mock config with realistic backbone geometry (GQA: 4Q/2KV, head_dim=64) # so unit tests exercise the multi-head path, not the fallback. + # adapter_token_ids + adapter_substitute_token_ids enable the + # control_to_substitute_lut path, which production configs always have. mock_config = SimpleNamespace( num_attention_heads=4, num_key_value_heads=2, - expanded_head_dim=64, + projection_head_dim=64, attention_multiplier=0.125, + vocab_size=2000, + adapter_token_ids=ADAPTER_TOKEN_IDS_LIST, + adapter_substitute_token_ids=ADAPTER_SUBSTITUTE_TOKEN_IDS_LIST, ) device = torch.device("cuda") @@ -72,6 +81,14 @@ def _setup(): control_token_gain=15.0, config=mock_config, ) + # Move buffers to CUDA so the LUT (registered as a CPU buffer in + # SingleSwitch.__init__) can index the CUDA input_ids during forward. + # The Q/K/V tensors in SingleSwitch.forward() are constructed directly + # on CUDA so they don't otherwise force a .to() call. 
+ if switch.control_to_substitute_lut is not None: + switch.control_to_substitute_lut = ( + switch.control_to_substitute_lut.to(device) + ) finally: torch.set_default_dtype(old_dtype) @@ -98,6 +115,7 @@ def _setup(): "kv_cache": kv_cache, "device": device, "layer_name": layer_name, + "adapter_substitute_token_ids_list": ADAPTER_SUBSTITUTE_TOKEN_IDS_LIST, "backend_name": backend_name, "block_size": BLOCK_SIZE, "adapter_token_ids_list": ADAPTER_TOKEN_IDS_LIST, @@ -213,7 +231,7 @@ def _run(harness, seq, num_adapters, control_token_gain): try: with override_forward_context(forward_ctx): - result = switch.forward( + adapter_indices, _modified_input_ids = switch.forward( input_ids=input_ids, adapter_token_ids=adapter_token_ids, ) @@ -223,7 +241,67 @@ def _run(harness, seq, num_adapters, control_token_gain): switch.effective_gain = orig_effective_gain switch.num_adapters = orig_num_adapters - return result.cpu().tolist() + return adapter_indices.cpu().tolist() + + +def _run_with_modified(harness, seq, num_adapters, control_token_gain): + """Execute SingleSwitch.forward and return BOTH outputs as lists. + + Used by token-exchange tests that need to inspect modified_input_ids + in addition to adapter_indices. Same setup as _run, but the response + is a dict {"adapter_indices": [...], "modified_input_ids": [...]}. 
+ """ + from vllm.forward_context import ForwardContext, override_forward_context + + switch = harness["switch"] + vllm_config = harness["vllm_config"] + kv_cache = harness["kv_cache"] + device = harness["device"] + layer_name = harness["layer_name"] + adapter_token_ids_list = harness["adapter_token_ids_list"] + + seq_len = len(seq) + kv_cache.zero_() + + orig_gain = switch.control_token_gain + orig_effective_gain = switch.effective_gain + orig_num_adapters = switch.num_adapters + switch.control_token_gain = control_token_gain + switch.effective_gain = control_token_gain / switch.scaling + switch.num_adapters = num_adapters + + input_ids = torch.tensor(seq, dtype=torch.long, device=device) + adapter_token_ids = torch.tensor( + adapter_token_ids_list[:num_adapters], dtype=torch.long, device=device, + ) + + metadata, slot_mapping = _build_metadata(harness, seq_len) + + forward_ctx = ForwardContext( + no_compile_layers=vllm_config.compilation_config.static_forward_context, + attn_metadata={layer_name: metadata}, + slot_mapping={layer_name: slot_mapping}, + ) + + old_direct = switch.attn.use_direct_call + switch.attn.use_direct_call = True + + try: + with override_forward_context(forward_ctx): + adapter_indices, modified_input_ids = switch.forward( + input_ids=input_ids, + adapter_token_ids=adapter_token_ids, + ) + finally: + switch.attn.use_direct_call = old_direct + switch.control_token_gain = orig_gain + switch.effective_gain = orig_effective_gain + switch.num_adapters = orig_num_adapters + + return { + "adapter_indices": adapter_indices.cpu().tolist(), + "modified_input_ids": modified_input_ids.cpu().tolist(), + } def _query_geometry(harness): @@ -316,6 +394,17 @@ def main(): control_token_gain=req.get("control_token_gain", 15.0), ) resp = {"result": result} + elif command == "forward_with_modified": + result = _run_with_modified( + harness, + seq=req["seq"], + num_adapters=req.get("num_adapters", 32), + control_token_gain=req.get("control_token_gain", 15.0), + ) + 
resp = {"result": result} + elif command == "query_lut": + lut = harness["switch"].control_to_substitute_lut + resp = {"result": lut.cpu().tolist() if lut is not None else None} else: resp = {"error": f"Unknown command: {command}"} except Exception: diff --git a/tests/vllm/_tp_integration_worker.py b/tests/vllm/_tp_integration_worker.py index 6ebf2ca..1729c01 100644 --- a/tests/vllm/_tp_integration_worker.py +++ b/tests/vllm/_tp_integration_worker.py @@ -47,12 +47,9 @@ def cmd_build(args): built_in_adapter_names=["test"], adapter_names=["test"], adapter_token_ids=[adapter_token_id], + adapter_substitute_token_ids=[1], muted_adapter_token_ids=[muted_token_id], - control_dims=32, switch_type="single", - hiding_groups={"all_controls": ["test"]}, - hiding_policy={"base": ["all_controls"], "test": ["all_controls"]}, - adapter_third_party=["test"], ) model.save_pretrained(output_dir) diff --git a/tests/vllm/test_generation_equivalence.py b/tests/vllm/test_generation_equivalence.py index 8f64e1f..d967107 100644 --- a/tests/vllm/test_generation_equivalence.py +++ b/tests/vllm/test_generation_equivalence.py @@ -2,13 +2,10 @@ """Verify greedy generation equivalence: upstream model vs zero-adapter switch model. Tests that autoregressive generation produces identical token sequences when a -GraniteSwitch model has a single built-in adapter with zero LoRA weights and -control_dims=32 (KV hiding infrastructure active, standard third-party mode). +GraniteSwitch model has a single built-in adapter with zero LoRA weights. 
No control tokens appear in the prompt, so: -- Switch layer → adapter_indices=0 everywhere -- hidden_count=0 → no RoPE gap correction -- K control dims = 0 for all tokens → QK dot product unchanged +- Switch layer → adapter_indices=0 everywhere, no token rewrite - LoRA delta = 0 → decoder layers produce identical output Each model runs in its own set of subprocesses so CUDA context is fully torn diff --git a/tests/vllm/test_granite4_mini.py b/tests/vllm/test_granite4_mini.py index 363edaf..6f34688 100644 --- a/tests/vllm/test_granite4_mini.py +++ b/tests/vllm/test_granite4_mini.py @@ -95,7 +95,7 @@ def test_weight_transfer(self, model_name): for name in unloaded: assert any(k in name for k in ( "lora_A", "lora_B", "switch", "adapter_token_ids", - "token_to_group_mask", "adapter_hiding_matrix", + "control_to_substitute_lut", )), f"Unexpected unloaded parameter: {name}" assert len(unloaded) > 0, "Expected LoRA/switch params to be unloaded" diff --git a/tests/vllm/test_kv_hiding_gap_equivalence.py b/tests/vllm/test_kv_hiding_gap_equivalence.py deleted file mode 100644 index 13e6f34..0000000 --- a/tests/vllm/test_kv_hiding_gap_equivalence.py +++ /dev/null @@ -1,39 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""KV hiding gap equivalence tests (subprocess wrapper). - -Runs _kv_hiding_gap_tests.py in a subprocess so the parent pytest process -never creates a CUDA context. 
-""" - -import importlib.util -import subprocess -import sys -from pathlib import Path - -import pytest - -_VLLM_AVAILABLE = importlib.util.find_spec("vllm") is not None - -pytestmark = pytest.mark.skipif( - not _VLLM_AVAILABLE, - reason="requires vLLM installed (GPU checked by inner tests)", -) - -_INNER = Path(__file__).parent / "_kv_hiding_gap_tests.py" -_TIMEOUT = 600 - - -def _run_inner_class(class_name): - cmd = [sys.executable, "-m", "pytest", str(_INNER), - "-v", "-s", "--tb=short", "-k", class_name] - result = subprocess.run(cmd, capture_output=True, text=True, timeout=_TIMEOUT) - if result.stdout: - print(result.stdout[-4000:]) - if result.stderr: - print("STDERR:", result.stderr[-2000:]) - assert result.returncode == 0, f"Inner tests failed (exit {result.returncode})" - - -class TestKVHidingGapEquivalence: - def test_suite(self): - _run_inner_class("TestKVHidingGapEquivalence") diff --git a/tests/vllm/test_model_forward.py b/tests/vllm/test_model_forward.py index b70daca..17f98be 100644 --- a/tests/vllm/test_model_forward.py +++ b/tests/vllm/test_model_forward.py @@ -50,11 +50,6 @@ def test_suite(self): _run_inner_class("TestAdapterIndicesWiring") -class TestControlTokenKVInvisibility: - def test_suite(self): - _run_inner_class("TestControlTokenKVInvisibility") - - class TestKVVisibility: def test_suite(self): _run_inner_class("TestKVVisibility") diff --git a/tests/vllm/test_position_zero_nan.py b/tests/vllm/test_position_zero_nan.py deleted file mode 100644 index 0bb0644..0000000 --- a/tests/vllm/test_position_zero_nan.py +++ /dev/null @@ -1,54 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""NaN regression tests — control token at sequence position 0 (vLLM backend). - -Runs _position_zero_nan_tests.py in a subprocess so the parent pytest process -never creates a CUDA context. 
-""" - -import importlib.util -import subprocess -import sys -from pathlib import Path - -import pytest - -_VLLM_AVAILABLE = importlib.util.find_spec("vllm") is not None - -pytestmark = pytest.mark.skipif( - not _VLLM_AVAILABLE, - reason="requires vLLM installed (GPU checked by inner tests)", -) - -_INNER = Path(__file__).parent / "_position_zero_nan_tests.py" -_TIMEOUT = 600 - - -def _run_inner_class(class_name): - cmd = [sys.executable, "-m", "pytest", str(_INNER), - "-v", "-s", "--tb=short", "-k", class_name] - result = subprocess.run(cmd, capture_output=True, text=True, timeout=_TIMEOUT) - if result.stdout: - print(result.stdout[-4000:]) - if result.stderr: - print("STDERR:", result.stderr[-2000:]) - assert result.returncode == 0, f"Inner tests failed (exit {result.returncode})" - - -class TestExpandControlDimensions: - def test_suite(self): - _run_inner_class("TestExpandControlDimensions") - - -class TestSDPANaN: - def test_suite(self): - _run_inner_class("TestSDPANaN") - - -class TestModelFiniteness: - def test_suite(self): - _run_inner_class("TestModelFiniteness") - - -class TestFixSensitivity: - def test_suite(self): - _run_inner_class("TestFixSensitivity") diff --git a/tests/vllm/test_token_exchange.py b/tests/vllm/test_token_exchange.py new file mode 100644 index 0000000..faac66f --- /dev/null +++ b/tests/vllm/test_token_exchange.py @@ -0,0 +1,254 @@ +# SPDX-License-Identifier: Apache-2.0 +"""vLLM backend tests for token-exchange mode. + +Mirrors tests/hf/test_token_exchange.py — verifies that on the vLLM +SingleSwitch path: + +1. The control_to_substitute_lut tensor maps each adapter control token id + to its configured substitute id, and leaves all other ids at -1. +2. Non-control positions in modified_input_ids are unchanged from the + original input_ids tensor. +3. Control positions in modified_input_ids are rewritten to the + substitute id from the LUT. 
+ +Tests #2 and #3 require a forward pass, so they go through the long-lived +SingleSwitch worker subprocess (the same one used by test_single_switch.py) +via two new commands: 'query_lut' and 'forward_with_modified'. The worker's +mock config now populates adapter_token_ids + adapter_substitute_token_ids +so the LUT path is exercised — see _single_switch_worker.py:_setup. + +Requires CUDA GPU and vLLM installed. All tests skipped otherwise. +All GPU work happens in the subprocess worker — the parent pytest process +never creates a CUDA context (required for Exclusive_Process GPU mode). +""" + +import atexit +import importlib.util +import json +import subprocess +import sys +import threading +from pathlib import Path + +import pytest + +_VLLM_AVAILABLE = importlib.util.find_spec("vllm") is not None + +pytestmark = pytest.mark.skipif( + not _VLLM_AVAILABLE, + reason="requires vLLM installed (GPU checked by worker)", +) + +from tests.shared.single_switch_cases import ( + NUM_ADAPTERS, + TEXT_TOKEN, + ADAPTER_TOKEN_IDS_LIST, +) + +# Worker's deterministic substitute mapping: control_id (1000+i) → sub_id (i+1). +# Matches ADAPTER_SUBSTITUTE_TOKEN_IDS_LIST in _single_switch_worker.py:_setup. +ADAPTER_SUBSTITUTE_TOKEN_IDS_LIST = [i + 1 for i in range(NUM_ADAPTERS)] + + +# ── Worker management ───────────────────────────────────────────── +# Same pattern as test_single_switch.py — own module-private worker so +# pytest can run the two files independently or together. 
+ +_WORKER_PATH = Path(__file__).parent / "_single_switch_worker.py" +_worker_proc = None +_worker_lock = threading.Lock() +_fatal_startup_error = None + + +def _ensure_worker(): + global _worker_proc, _fatal_startup_error + if _fatal_startup_error is not None: + pytest.fail(_fatal_startup_error, pytrace=False) + if _worker_proc is not None and _worker_proc.poll() is None: + return + proc = subprocess.Popen( + [sys.executable, str(_WORKER_PATH)], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, + ) + ready_line = proc.stdout.readline() + if not ready_line: + stderr = proc.stderr.read() + raise RuntimeError(f"Worker failed to start:\n{stderr}") + ready = json.loads(ready_line) + if "fatal" in ready: + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + proc.kill() + stderr_tail = (proc.stderr.read() or "")[-2000:] + backend = ready.get("backend_name", "unknown") + _fatal_startup_error = ( + f"vLLM worker cannot start: {ready['fatal']}\n" + f"Backend: {backend}\n" + f"Hint: {ready.get('hint', '')}\n" + f"--- worker stderr (tail) ---\n{stderr_tail}" + ) + pytest.fail(_fatal_startup_error, pytrace=False) + assert ready.get("ready"), f"Unexpected ready message: {ready}" + _worker_proc = proc + atexit.register(_shutdown_worker) + + +def _shutdown_worker(): + global _worker_proc + if _worker_proc is not None and _worker_proc.poll() is None: + _worker_proc.stdin.close() + _worker_proc.wait(timeout=30) + _worker_proc = None + + +def _send_command(req): + """Send a JSON request to the worker and return its 'result' field.""" + _ensure_worker() + with _worker_lock: + _worker_proc.stdin.write(json.dumps(req) + "\n") + _worker_proc.stdin.flush() + resp_line = _worker_proc.stdout.readline() + if not resp_line: + stderr = _worker_proc.stderr.read() + raise RuntimeError(f"Worker died unexpectedly:\n{stderr}") + resp = json.loads(resp_line) + if "error" in resp: + raise RuntimeError(f"Worker 
error:\n{resp['error']}") + return resp["result"] + + +@pytest.fixture(autouse=True, scope="module") +def _worker_lifecycle(): + yield + _shutdown_worker() + + +# ── Tests ───────────────────────────────────────────────────────── + + +class TestLUTMapping: + """control_to_substitute_lut is the canonical control→substitute table. + + It is built once at SingleSwitch construction from + config.adapter_token_ids + config.adapter_substitute_token_ids; tested + here against the worker's mock config (control 1000+i → substitute i+1). + """ + + def test_lut_maps_control_to_substitute(self): + lut = _send_command({"command": "query_lut"}) + assert lut is not None, ( + "control_to_substitute_lut was None — adapter_substitute_token_ids " + "missing from worker mock config?" + ) + for ctrl_id, sub_id in zip( + ADAPTER_TOKEN_IDS_LIST, ADAPTER_SUBSTITUTE_TOKEN_IDS_LIST + ): + assert lut[ctrl_id] == sub_id, ( + f"lut[{ctrl_id}]={lut[ctrl_id]}, expected substitute {sub_id}" + ) + + def test_lut_marks_non_control_with_sentinel(self): + lut = _send_command({"command": "query_lut"}) + assert lut is not None + # TEXT_TOKEN (50) and a few arbitrary non-control ids should be -1. + for non_control in [TEXT_TOKEN, 0, 51, 52, 999]: + assert lut[non_control] == -1, ( + f"lut[{non_control}]={lut[non_control]}, expected -1 sentinel" + ) + + +class TestInputRewrite: + """SingleSwitch.forward returns (adapter_indices, modified_input_ids). + + modified_input_ids must equal input_ids at non-control positions and + equal lut[ctrl_id] (the substitute) at control positions. The decoder + embeds modified_input_ids; the switch itself reads the original + input_ids so adapter detection is unaffected. + """ + + def test_non_control_positions_unchanged(self): + # Mix of non-control tokens with one control token in the middle. 
+ ctrl_id = ADAPTER_TOKEN_IDS_LIST[0] + seq = [TEXT_TOKEN, 51, ctrl_id, 53, 54] + result = _send_command( + { + "command": "forward_with_modified", + "seq": seq, + "num_adapters": 4, + "control_token_gain": 15.0, + } + ) + modified = result["modified_input_ids"] + # Positions 0, 1, 3, 4 are non-control — must be unchanged. + assert modified[0] == seq[0] + assert modified[1] == seq[1] + assert modified[3] == seq[3] + assert modified[4] == seq[4] + + def test_control_positions_rewritten_to_substitute(self): + # Control token at position 2 — must be rewritten to its substitute. + ctrl_id = ADAPTER_TOKEN_IDS_LIST[0] + expected_sub = ADAPTER_SUBSTITUTE_TOKEN_IDS_LIST[0] + seq = [TEXT_TOKEN, 51, ctrl_id, 53, 54] + result = _send_command( + { + "command": "forward_with_modified", + "seq": seq, + "num_adapters": 4, + "control_token_gain": 15.0, + } + ) + modified = result["modified_input_ids"] + assert modified[2] == expected_sub, ( + f"control position rewrite failed: got {modified[2]}, " + f"expected substitute {expected_sub}" + ) + + def test_multiple_control_tokens_each_rewritten(self): + # Two distinct control tokens; each must map to its own substitute. + ctrl0 = ADAPTER_TOKEN_IDS_LIST[0] + ctrl1 = ADAPTER_TOKEN_IDS_LIST[1] + sub0 = ADAPTER_SUBSTITUTE_TOKEN_IDS_LIST[0] + sub1 = ADAPTER_SUBSTITUTE_TOKEN_IDS_LIST[1] + seq = [TEXT_TOKEN, ctrl0, TEXT_TOKEN, ctrl1, TEXT_TOKEN] + result = _send_command( + { + "command": "forward_with_modified", + "seq": seq, + "num_adapters": 4, + "control_token_gain": 15.0, + } + ) + modified = result["modified_input_ids"] + assert modified[0] == TEXT_TOKEN + assert modified[1] == sub0 + assert modified[2] == TEXT_TOKEN + assert modified[3] == sub1 + assert modified[4] == TEXT_TOKEN + + def test_switch_still_detects_adapter_after_rewrite(self): + # The rewrite must NOT confuse adapter detection — the switch reads + # the original input_ids before the rewrite happens. 
+ ctrl_id = ADAPTER_TOKEN_IDS_LIST[2] + seq = [TEXT_TOKEN, ctrl_id, TEXT_TOKEN, TEXT_TOKEN] + result = _send_command( + { + "command": "forward_with_modified", + "seq": seq, + "num_adapters": 4, + "control_token_gain": 15.0, + } + ) + adapter_indices = result["adapter_indices"] + # Position 0 fires before any control: adapter 0 (base). + # Position 1 is the control for adapter index 3 (1-indexed: ctrl_idx 2 → adapter 3). + # SingleSwitch persists adapter id once fired → positions 1+ all 3. + assert adapter_indices[0] == 0 + assert adapter_indices[1] == 3 + assert adapter_indices[2] == 3 + assert adapter_indices[3] == 3 diff --git a/tutorials/scripts/reference/run_adapter_generation_direct.py b/tutorials/scripts/reference/run_adapter_generation_direct.py index 9998e5c..2d33b6a 100644 --- a/tutorials/scripts/reference/run_adapter_generation_direct.py +++ b/tutorials/scripts/reference/run_adapter_generation_direct.py @@ -62,7 +62,17 @@ def load_model(model_dir: str): def _generate(model, tokenizer, text: str, max_new_tokens: int) -> str: - """Generate text and return only the new tokens.""" + """Generate text and return only the new tokens. + + When ``_PROMPT_CAPTURE`` is active (set by build_demo_prompts), skip + generation entirely, capture the prompt on the module-level list, and + return an empty string so each demo's subsequent logic (score parsing, + etc.) can no-op harmlessly. + """ + if _PROMPT_CAPTURE is not None: + _PROMPT_CAPTURE.append(text) + return "" + device = model.device inputs = tokenizer(text, return_tensors="pt").to(device) @@ -75,11 +85,77 @@ def _generate(model, tokenizer, text: str, max_new_tokens: int) -> str: return tokenizer.decode(generated_ids, skip_special_tokens=True).strip() +# Module-level capture switch. Populated by build_demo_prompts; None means +# the normal generate path runs. 
+_PROMPT_CAPTURE: Optional[list] = None + + +def build_demo_prompts( + tokenizer, available_adapters: Optional[set[str]] = None, +) -> list[tuple[str, str]]: + """Render every registered demo's prompt as a string, without generation. + + Returns a list of ``(demo_key, prompt_text)`` pairs for all demos whose + base adapter is present in ``available_adapters`` (or every registered + demo when the filter is None). The prompts are exactly what the demo + script would feed to ``model.generate`` — chat-template-rendered and + adapter-token-injected by the composed tokenizer. + + Used by the token-exchange parity eval + (tests/integration/test_token_exchange_parity.py) to compare legacy + hiding vs. token-exchange on realistic adapter inputs. + """ + global _PROMPT_CAPTURE + results: list[tuple[str, str]] = [] + _PROMPT_CAPTURE = [] + try: + for base_adapter, demo_fn in _DEMOS: + if available_adapters is not None and base_adapter not in available_adapters: + continue + demo_key = demo_fn.__name__.removeprefix("demo_") + _PROMPT_CAPTURE.clear() + try: + demo_fn(model=None, tokenizer=tokenizer, max_new_tokens=1) + except Exception as e: + # Some demos parse the (empty) output and may raise. Capture + # the prompt we already collected and move on; partial prompts + # are still useful for parity comparison. + if not _PROMPT_CAPTURE: + print(f"[build_demo_prompts] {demo_key}: {e}") + continue + for prompt_text in _PROMPT_CAPTURE: + results.append((demo_key, prompt_text)) + finally: + _PROMPT_CAPTURE = None + return results + + # --------------------------------------------------------------------------- # Activation helper — uses the composed model's chat template # --------------------------------------------------------------------------- +def _build_prompt( + tokenizer, + adapter_name: str, + messages: list[dict], + documents: Optional[list[dict]] = None, +) -> str: + """Render an adapter prompt using the composed model's chat template. 
+ + Separated from _invoke so callers (e.g. the parity eval) can obtain the + exact prompt text without running generation. + """ + tmpl_kwargs: dict = { + "tokenize": False, + "add_generation_prompt": True, + "adapter_name": adapter_name, + } + if documents is not None: + tmpl_kwargs["documents"] = documents + return tokenizer.apply_chat_template(messages, **tmpl_kwargs) + + def _invoke( model, tokenizer, @@ -96,26 +172,8 @@ def _invoke( position for that adapter's technology (LoRA prefix vs aLoRA splice). See ``composer/tokenizer_setup.py`` for the template machinery. - - Args: - adapter_name: Name of the adapter to activate; must be one of - the composed model's ``adapter_names``. - messages: List of ``{"role", "content"}`` dicts. - documents: Optional list of ``{"doc_id", "text"}`` dicts, as - documented in the granite-switch README. - max_new_tokens: Generation budget. - - Returns: - The generated adapter output (new tokens only, decoded). """ - tmpl_kwargs: dict = { - "tokenize": False, - "add_generation_prompt": True, - "adapter_name": adapter_name, - } - if documents is not None: - tmpl_kwargs["documents"] = documents - prompt = tokenizer.apply_chat_template(messages, **tmpl_kwargs) + prompt = _build_prompt(tokenizer, adapter_name, messages, documents=documents) return _generate(model, tokenizer, prompt, max_new_tokens)