diff --git a/README.md b/README.md index 9bf4bacd..0e8fb1f8 100644 --- a/README.md +++ b/README.md @@ -89,25 +89,76 @@ Benchmark results for `gemma-4-26b-a4b-it-4bit` (26B MoE, 4-bit) on M5 Pro 64 GB --- -## ๐Ÿง  Supported Models & Methodologies +## ๐Ÿ“ก Supported Models & Methodologies -`SwiftLM` dynamically maps Apple MLX primitives to standard HuggingFace architectures, enabling complete support for the latest frontier open-weights models across modalities (Text, Vision, Audio). +`SwiftLM` dynamically maps Apple MLX primitives to standard HuggingFace architectures, enabling native Metal inference across the latest frontier open-weights models. -### Text (LLMs) -- **Gemma 4**: Fully supports both Dense (`gemma-4-e4b`) and Sparse Mixture of Experts (MoE) architectures (`gemma-4-26b`, `gemma-4-31b`). -- **Qwen 2.5 & 3**: Robust support for sliding window attention limits and custom RoPE scaling. -- **Mistral & Mixtral**: Out-of-the-box structural mappings. -- **Phi-3 & Phi-3.5**: Full 128k context parsing via Swift chunked-prefill. +### ๐Ÿ’ฌ Text (LLMs) -### Vision (VLMs) +| Family | Models | Notes | +|---|---|---| +| **Gemma 4** | `gemma-4-e2b`, `gemma-4-e4b` (dense) ยท `gemma-4-26b-a4b`, `gemma-4-31b` (MoE) | Interleaved local + global attention; KV sharing; native quantized KV cache (issue #71 fix) | +| **Gemma 3 / 3n** | `gemma-3-*`, `gemma-3n-*` | Google Gemma 3 and nano variants | +| **Gemma / Gemma 2** | `gemma-*`, `gemma-2-*` | Original Gemma family | +| **Qwen 3.5** | `Qwen3.5-7B`, `Qwen3.5-27B`, `Qwen3.5-122B-A10B`, `Qwen3.5-397B-A22B` | Dense + MoE; SSD streaming at 10ร— for 122B/397B | +| **Qwen 3** | `Qwen3-*` (dense + MoE) | Sliding window + hybrid attention | +| **Qwen 2.5** | `Qwen2.5-7B`, `Qwen2.5-14B`, `Qwen2.5-72B` | Robust RoPE scaling | +| **Qwen 2** | `Qwen2-*` | Linear RoPE variants | +| **Phi 4 / PhiMoE** | `phi-4-mlx`, `Phi-3.5-MoE` | Microsoft Phi family incl. 
MoE |
+| **Phi 3 / Phi** | `Phi-3`, `Phi-3.5-mini` | 128k context via chunked prefill |
+| **Mistral / Mixtral** | `Mistral-7B`, `Mistral-4`, `Mixtral-*` | GQA + sliding window variants |
+| **Llama / Llama 3** | `Llama-3.1-*`, `Llama-3.2-*`, `Llama-3.3-*` | YaRN + dynamic NTK RoPE scaling |
+| **GLM 4** | `GLM-4-*` | THUDM GLM-4 dense + MoE-Lite variants |
+| **DeepSeek V3** | `DeepSeek-V3-*` | MLA attention architecture |
+| **Falcon H1** | `Falcon-H1-*` | Falcon hybrid SSM+attention |
+| **LFM 2** | `LFM2-*`, `LFM2-MoE-*` | Liquid AI dense + MoE |
+| **OLMo 2 / OLMo 3 / OLMoE** | `OLMo-2-*`, `OLMo-3-*` | AllenAI open language models |
+| **Granite / GraniteMoE** | `Granite-*`, `GraniteMoE-Hybrid-*` | IBM Granite hybrid Mamba+attention |
+| **SmolLM 3** | `SmolLM3-*` | HuggingFace compact LM |
+| **MiniCPM** | `MiniCPM-*` | Lightweight efficient LM |
+| **InternLM 2** | `InternLM2-*` | Shanghai AI Lab series |
+| **Cohere / Command-R** | `Command-R-*`, `c4ai-*` | Cohere retrieval-tuned models |
+| **Jamba** | `Jamba-v0.1` | AI21 hybrid Mamba+attention |
+| **Exaone 4** | `EXAONE-4.0-*` | LG AI Research |
+| **MiMo / MiMo V2** | `MiMo-7B-*` | Xiaomi reasoning model |
+| **Ernie 4.5** | `ERNIE-4.5-*` | Baidu ERNIE series |
+| **Baichuan M1** | `Baichuan-M1-*` | Baichuan multimodal base |
+| **Bailing MoE** | `Ling-*` | Bailing/Ling MoE family |
+| **NemotronH** | `Nemotron-H-*` | NVIDIA Nemotron hybrid |
+| **Starcoder 2** | `starcoder2-*` | Code generation |
+| **OpenELM** | `OpenELM-*` | Apple on-device efficient LM |
+| **Apertus / AfMoE** | `Apertus-*` | Sparse MoE research models |
+| **BitNet** | `bitnet-*` | 1-bit weight quantization |
+| **MiniMax** | `MiniMax-Text-*` | Lightning attention architecture |
+
+### 👁️ Vision (VLMs)
 *Run with `--vision` flag.*
-- **Qwen2-VL & Qwen3-VL**: Real-time positional bounding and Metal image scaling.
-- **PaliGemma / LFM2-VL / Pixtral**: Base64 spatial decomposition. 
-### Audio (ALMs)
-*Run with `--audio` flag.*
-- **Qwen2-Audio (7B-Instruct)**: Deep multi-modal spectrogram processing via Swift audio interleaving.
-- **Gemma-4 Audio Pipelines**: Ready for Audio-in/Text-out variants mapping `.audio_tower` extraction parameters natively off NVMe.
+| Family | Models | Notes |
+|---|---|---|
+| **Gemma 4** | `gemma-4-*` (VLM mode) | Native image tower via MLXVLM |
+| **Gemma 3** | `gemma-3-*` (VLM mode) | PaliGemma-style image projection |
+| **Qwen3-VL / Qwen3.5-VL** | `Qwen3-VL-*`, `Qwen3.5-VL-*` | Dynamic resolution with native RoPE |
+| **Qwen2-VL / Qwen2.5-VL** | `Qwen2-VL-2B/7B`, `Qwen2.5-VL-*` | Real-time positional bounding + Metal image scaling |
+| **LFM2-VL** | `LFM2-VL-1.6B` | Liquid AI multimodal |
+| **Pixtral** | `pixtral-12b` | Mistral vision model |
+| **PaliGemma** | `paligemma-*` | Google vision-language |
+| **Idefics 3** | `Idefics3-*` | HuggingFace multimodal |
+| **Mistral 3** | `Mistral-Small-3.1-*` | Mistral vision variant |
+| **FastVLM** | `FastVLM-*` | Apple on-device VLM |
+| **SmolVLM 2** | `SmolVLM2-*` | HuggingFace compact VLM |
+| **GLM OCR** | `glm-4v-*` | THUDM vision+OCR |
+| **QwenVL** | `Qwen-VL-*` | Original Qwen VL |
+
+### 🎧 Audio (ALMs)
+*Run with `--audio` flag. 
Only `gemma-4-e4b` variants include an audio tower.* + +| Family | Models | Notes | +|---|---|---| +| **Gemma 4 Omni** | `gemma-4-e4b-it-4bit`, `gemma-4-e4b-it-8bit` | Audio-in via vDSP STFT โ†’ Mel spectrogram (16kHz, 128 bins); text-out | + + --- @@ -352,10 +403,46 @@ curl http://localhost:5413/v1/chat/completions \ | `--min-p` | `0.0` | Default min-p sampling threshold relative to the highest probability token (0 disables) | | `--gpu-layers` | `model_default`| Restrict the amount of layers allocated to GPU hardware | | `--stream-experts` | `false` | Enable SSD expert streaming for MoE models (10x speedup) | -| `--turbo-kv` | `false` | Enable TurboQuant 3-bit KV cache compression | +| `--turbo-kv` | `false` | Enable TurboQuant 3-bit KV cache compression (activates after 2048 tokens, server-wide) | | `--draft-model` | (none) | Draft model path/ID for speculative decoding (in-RAM models only) | | `--num-draft-tokens` | `4` | Number of draft tokens per speculation round | +## ๐Ÿ”ง Per-Request API Parameters + +In addition to the standard OpenAI fields (`temperature`, `top_p`, `max_tokens`, etc.), SwiftLM accepts the following **SwiftLM-specific** fields on `POST /v1/chat/completions`: + +| Field | Type | Description | +|---|---|---| +| `kv_bits` | `int` (4 or 8) | Enable **MLX-native quantized KV cache** for this request. Uses `QuantizedKVCache` (standard group quantization) instead of `KVCacheSimple`. Separate from `--turbo-kv`. Reduces KV memory ~2โ€“4ร— at mild quality cost. | +| `enable_thinking` | `bool` | Force-enable or disable chain-of-thought thinking blocks for Gemma-4 / Qwen3. | +| `kv_group_size` | `int` | Group size for `kv_bits` quantization (default: `64`). | +| `top_k` | `int` | Per-request top-k sampling override (0 = disabled). | +| `min_p` | `float` | Per-request min-p sampling threshold (0 = disabled). | +| `repetition_penalty` | `float` | Token repetition penalty (e.g. `1.15`). | + +### `kv_bits` vs `--turbo-kv` โ€” What's the difference? 
+
+| | `kv_bits` (per-request) | `--turbo-kv` (server flag) |
+|---|---|---|
+| **Scope** | Per-request, sent in JSON body | Server-wide, set at startup |
+| **Algorithm** | MLX-native group quantization (4-bit / 8-bit) | Custom 3-bit PolarQuant + QJL Walsh-Hadamard |
+| **Activation** | From token 0 | After 2048 tokens |
+| **Memory savings** | ~2–4× vs FP16 | ~3.5× vs FP16 |
+| **Use case** | Targeted memory reduction per conversation | Extreme long-context (100K+) compression |
+
+### Example: Enable 4-bit KV cache per request
+```bash
+curl http://localhost:5413/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "gemma-4-26b-a4b-it-4bit",
+    "kv_bits": 4,
+    "messages": [
+      {"role": "user", "content": "Summarize the history of computing in 3 sentences."}
+    ]
+  }'
+```
+
 ## 📦 Requirements
 
 - macOS 14.0+
diff --git a/Sources/SwiftLM/Server.swift b/Sources/SwiftLM/Server.swift
index 17d68d37..c6e416d3 100644
--- a/Sources/SwiftLM/Server.swift
+++ b/Sources/SwiftLM/Server.swift
@@ -1048,9 +1048,20 @@ func handleChatCompletion(
         // These are accepted but may not affect generation if MLX doesn't support them
     }
 
+    // ── Validate kv_bits: only nil, 4, and 8 are supported ──
+    if let kb = chatReq.kvBits, kb != 4 && kb != 8 {
+        let errBody = "{\"error\":{\"message\":\"Invalid kv_bits value \(kb). Supported values are 4 and 8.\",\"type\":\"invalid_request_error\",\"code\":\"invalid_kv_bits\"}}"
+        return Response(
+            status: .badRequest,
+            headers: jsonHeaders(),
+            body: .init(byteBuffer: ByteBuffer(string: errBody))
+        )
+    }
+
     let params = GenerateParameters(
         maxTokens: tokenLimit,
         maxKVSize: config.ctxSize,
+        kvBits: chatReq.kvBits,
         temperature: temperature,
         topP: topP,
         topK: topK,
@@ -1200,9 +1211,13 @@ func handleChatCompletion(
         // raw <|image|>/<|audio|> token embeddings instead of the projected features. 
let isMultimodalRequest = lmInput.image != nil || lmInput.audio != nil - // Try to restore via token-by-token prefix match (llama-server style) + // Try to restore via token-by-token prefix match (llama-server style). + // Skip for quantized-KV requests: the prompt cache stores KV state produced + // with KVCacheSimple; restoring it into a QuantizedKVCache (or vice-versa) + // is unsafe and produces incorrect results or runtime failures. + let skipPromptCache = isMultimodalRequest || params.kvBits != nil var stream: AsyncStream - if !isMultimodalRequest, let cachedCount = await promptCache.restore(newTokens: promptTokens, into: cache) { + if !skipPromptCache, let cachedCount = await promptCache.restore(newTokens: promptTokens, into: cache) { // Cache hit: KV state is pre-populated up to cachedCount tokens. // Only compute the remaining (new) tokens. var startIndex = cachedCount @@ -1251,6 +1266,10 @@ func handleChatCompletion( let onPrefillDone: (() async -> Void)? = { if turboHasCompressed { print("[SwiftLM] ๐Ÿง  Skipping prompt cache save โ€” TurboQuant has compressed \(cache.compactMap { ($0 as? KVCacheSimple)?.compressedOffset }.max() ?? 0) tokens. Saving would decode ~37 GB back to fp16.") + } else if params.kvBits != nil { + // kv_bits is set: the cache contains QuantizedKVCache layers whose token + // format is incompatible with the FP16 KVCacheSimple format expected by + // promptCache.save. Skip saving to prevent unsafe mixed-format restores. } else { await promptCache.save(tokens: promptTokens, cache: cache) } @@ -2305,6 +2324,10 @@ struct ChatCompletionRequest: Decodable { let chatTemplateKwargs: [String: Bool]? /// Top-level thinking override emitted by Aegis-AI gateway let enableThinking: Bool? + /// Number of bits for native MLX quantized KV cache (nil = no quantization). + /// Only 4 and 8 are supported by the underlying MLX QuantizedKVCache. + /// Enables `QuantizedKVCache` instead of `KVCacheSimple`. Separate from `--turbo-kv`. + let kvBits: Int? 
enum CodingKeys: String, CodingKey { case model, messages, stream, temperature, tools, stop, seed @@ -2319,6 +2342,7 @@ struct ChatCompletionRequest: Decodable { case responseFormat = "response_format" case chatTemplateKwargs = "chat_template_kwargs" case enableThinking = "enable_thinking" + case kvBits = "kv_bits" } } diff --git a/mlx-swift-lm b/mlx-swift-lm index 71a77e07..63707c0c 160000 --- a/mlx-swift-lm +++ b/mlx-swift-lm @@ -1 +1 @@ -Subproject commit 71a77e07b4936599cc40c4a423458c2bc834a0cc +Subproject commit 63707c0ccde78daa63ceb0575af52edc9d941c07 diff --git a/run_benchmark.sh b/run_benchmark.sh index 8ad40921..88b1dc86 100755 --- a/run_benchmark.sh +++ b/run_benchmark.sh @@ -102,8 +102,9 @@ echo "5) Test 5: ALM Audio End-to-End Evaluation" echo "6) Test 6: Omni End-to-End Evaluation" echo "7) Model Maintain List and Delete" echo "8) Test 8: Tool-Call Degeneration Regression (Gemma-4 vague-query bug)" -echo "9) Quit" -read -p "Option (0-9): " suite_opt +echo "9) Test 9: Quantized KV Cache Regression (Gemma-4 issue #71 โ€” native kv_bits)" +echo "q) Quit" +read -p "Option (0-9/q): " suite_opt if [ "$suite_opt" == "0" ]; then echo "==============================================" @@ -131,12 +132,13 @@ if [ "$suite_opt" == "0" ]; then exit 0 fi -if [ "$suite_opt" == "9" ] || [ "$suite_opt" == "8" ] || [ -z "$suite_opt" ]; then - # 9 = Quit (old 8), 8 = Test 8 โ€” only exit on 9 or blank - if [ "$suite_opt" == "9" ] || [ -z "$suite_opt" ]; then - echo "Exiting." - exit 0 - fi +if [ "$suite_opt" == "q" ] || [ -z "$suite_opt" ]; then + echo "Exiting." 
+ exit 0 +fi + +if [ "$suite_opt" == "9" ] || [ "$suite_opt" == "8" ]; then + : # handled below โ€” fall through fi if [ "$suite_opt" == "7" ]; then @@ -969,6 +971,180 @@ EOF exit 0 fi +# โ”€โ”€ Test 9: QuantizedKVCache Regression (issue #71) โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ +# Verifies that Gemma-4 text models can decode with native MLX QuantizedKVCache +# (kv_bits=4 and kv_bits=8) without triggering the: +# fatalError: `update` was called on `QuantizedKVCache`. Use `updateQuantized`. +# crash fixed in PR #29 of mlx-swift-lm. +# +# Pass criteria: +# - 4-bit run: server does not crash, returns non-empty text response (โ‰ฅ3 tokens) +# - 8-bit run: same +# - Longer prompt run: exercises the last-20-layer KV-sharing path, same pass criteria +# - Baseline (no kv_bits): regression guard that the non-quantized path still works +if [ "$suite_opt" == "9" ]; then + echo "" + echo "=> Test 9: Quantized KV Cache Regression (issue #71) on $FULL_MODEL" + echo " Tests MLX native QuantizedKVCache (kv_bits=4, kv_bits=8) โ€” NOT TurboKV" + echo " This exercises the fix in mlx-swift-lm PR #29." + + echo "Starting server on port 5431..." + killall SwiftLM 2>/dev/null + mkdir -p tmp + # No --turbo-kv flag: we want the vanilla KVCacheSimple path that will be + # upgraded to QuantizedKVCache by the per-request kv_bits field. + $BIN --model "$FULL_MODEL" --port 5431 --stream-experts --ctx-size 8192 > ./tmp/kvcache_regression.log 2>&1 & + SERVER_PID=$! + + SERVER_READY=0 + for i in {1..180}; do + if ! kill -0 $SERVER_PID 2>/dev/null; then + echo "โŒ Server died early. Logs:" + print_server_log ./tmp/kvcache_regression.log + exit 1 + fi + if curl -sf http://127.0.0.1:5431/health > /dev/null 2>&1; then + echo "Server ready (${i}s)" + SERVER_READY=1 + break + fi + sleep 1 + done + if [ $SERVER_READY -eq 0 ]; then + echo "โŒ Server not ready after 180s. 
Logs:" + print_server_log ./tmp/kvcache_regression.log + kill $SERVER_PID 2>/dev/null + exit 1 + fi + + echo "" + echo "Running QuantizedKVCache regression suite..." + + python3 - << 'KVBITS_EOF' +import json, urllib.request, time, sys, re + +BASE = "http://127.0.0.1:5431" + +FAILS = [] + +def call(messages, kv_bits=None, max_tokens=60, temperature=0.0): + payload = { + "messages": messages, + "max_tokens": max_tokens, + "temperature": temperature, + "stream": False, + } + if kv_bits is not None: + payload["kv_bits"] = kv_bits + req = urllib.request.Request( + f"{BASE}/v1/chat/completions", + data=json.dumps(payload).encode(), + headers={"Content-Type": "application/json"}, + ) + t0 = time.time() + try: + with urllib.request.urlopen(req, timeout=180) as r: + d = json.loads(r.read()) + except Exception as e: + return None, str(e), time.time() - t0 + elapsed = time.time() - t0 + content = d["choices"][0]["message"].get("content") or "" + # Strip Gemma-4 thinking blocks โ€” handle both <|channel|>thought and <|channel>thought variants + content = re.sub(r"<\|channel\|?>thought.*?", "", content, flags=re.DOTALL).strip() + return d, content, elapsed + +MSGS_SHORT = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Name the three primary colours. Be brief."}, +] + +# Longer prompt to exercise the KV sharing layers (last 20 of Gemma-4 share KV +# from earlier layers โ€” the bug manifests at those layers on multi-token prefills). +MSGS_LONG = [ + {"role": "system", "content": "You are a knowledgeable AI assistant. Answer concisely."}, + {"role": "user", "content": "Explain in two sentences why the sky appears blue during the day and red at sunset. 
Use physics terminology."}, +] + +# โ”€โ”€ [1] 4-bit quantized KV cache โ”€โ”€ +print("\nโ”€โ”€โ”€ [1/4] kv_bits=4, short prompt โ”€โ”€โ”€") +d, content, t = call(MSGS_SHORT, kv_bits=4) +if d is None: + print(f" โŒ CRASHED: {content}") + FAILS.append("kv_bits=4 short: server crash or timeout") +else: + gen_toks = d["usage"]["completion_tokens"] + ok = len(content.strip()) > 5 and gen_toks >= 3 + print(f" {'โœ…' if ok else 'โŒ'} [{t:.1f}s, {gen_toks} tokens]: {content[:100]}") + if not ok: + FAILS.append(f"kv_bits=4 short: too few tokens or empty ({gen_toks} tokens)") + +# โ”€โ”€ [2] 8-bit quantized KV cache โ”€โ”€ +print("\nโ”€โ”€โ”€ [2/4] kv_bits=8, short prompt โ”€โ”€โ”€") +d, content, t = call(MSGS_SHORT, kv_bits=8) +if d is None: + print(f" โŒ CRASHED: {content}") + FAILS.append("kv_bits=8 short: server crash or timeout") +else: + gen_toks = d["usage"]["completion_tokens"] + ok = len(content.strip()) > 5 and gen_toks >= 3 + print(f" {'โœ…' if ok else 'โŒ'} [{t:.1f}s, {gen_toks} tokens]: {content[:100]}") + if not ok: + FAILS.append(f"kv_bits=8 short: too few tokens or empty ({gen_toks} tokens)") + +# โ”€โ”€ [3] 4-bit, longer prompt (exercises KV-sharing layers) โ”€โ”€ +print("\nโ”€โ”€โ”€ [3/4] kv_bits=4, longer prompt (exercises KV-sharing path) โ”€โ”€โ”€") +d, content, t = call(MSGS_LONG, kv_bits=4, max_tokens=120) +if d is None: + print(f" โŒ CRASHED: {content}") + FAILS.append("kv_bits=4 long: server crash or timeout") +else: + gen_toks = d["usage"]["completion_tokens"] + ok = len(content.strip()) > 10 and gen_toks >= 5 + print(f" {'โœ…' if ok else 'โŒ'} [{t:.1f}s, {gen_toks} tokens]: {content[:120]}") + if not ok: + FAILS.append(f"kv_bits=4 long: too few tokens or empty ({gen_toks} tokens)") + +# โ”€โ”€ [4] Baseline without kv_bits (must still work โ€” regression guard) โ”€โ”€ +print("\nโ”€โ”€โ”€ [4/4] kv_bits=None baseline (no quantization) โ”€โ”€โ”€") +d, content, t = call(MSGS_SHORT, kv_bits=None) +if d is None: + print(f" โŒ CRASHED: {content}") 
+ FAILS.append("baseline (no kv_bits): server crash or timeout") +else: + gen_toks = d["usage"]["completion_tokens"] + ok = len(content.strip()) > 5 and gen_toks >= 3 + print(f" {'โœ…' if ok else 'โŒ'} [{t:.1f}s, {gen_toks} tokens]: {content[:100]}") + if not ok: + FAILS.append(f"baseline: too few tokens or empty ({gen_toks} tokens)") + +print("\n" + "โ”€" * 60) +if not FAILS: + print("โœ… REGRESSION PASSED โ€” QuantizedKVCache dispatches correctly.") + print(" kv_bits=4 โœ“ | kv_bits=8 โœ“ | KV-sharing path โœ“ | baseline โœ“") + sys.exit(0) +else: + print("โŒ REGRESSION FAILED:") + for f in FAILS: + print(f" โ€ข {f}") + print("\n Root cause (if kv_bits runs crash): unconditional `cache.update()` call") + print(" in Gemma4TextAttention.callAsFunction โ€” see mlx-swift-lm PR #29.") + sys.exit(1) +KVBITS_EOF + TEST9_EXIT=$? + + echo "" + echo "Cleaning up..." + kill $SERVER_PID 2>/dev/null + wait $SERVER_PID 2>/dev/null + + if [ $TEST9_EXIT -eq 0 ]; then + echo "โœ… Test 9 PASSED" + else + echo "โŒ Test 9 FAILED โ€” see output above." + fi + exit $TEST9_EXIT +fi + # Fallback to Test 1 for anything else echo "" read -p "Enter context lengths to test [default: 512,40000,100000]: " CONTEXTS
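For client code that targets the new field, the server-side contract introduced in this PR (only `kv_bits` values 4 and 8 are accepted; anything else returns a 400 `invalid_kv_bits` error) can be mirrored before the network round-trip. A minimal sketch in Python; the `build_chat_payload` helper is illustrative and not part of SwiftLM:

```python
import json

# Client-side mirror of the Server.swift guard: handleChatCompletion rejects
# any kv_bits other than 4 or 8 with a 400 "invalid_kv_bits" error, so we can
# fail fast locally instead of waiting for the server to reject the request.
VALID_KV_BITS = (4, 8)

def build_chat_payload(messages, kv_bits=None, **sampling):
    """Build a JSON body for POST /v1/chat/completions.

    kv_bits is the SwiftLM-specific field: None keeps the default FP16
    KVCacheSimple path; 4 or 8 selects the MLX-native QuantizedKVCache.
    """
    if kv_bits is not None and kv_bits not in VALID_KV_BITS:
        raise ValueError(
            f"Invalid kv_bits value {kv_bits}. Supported values are 4 and 8."
        )
    payload = {"messages": messages, "stream": False, **sampling}
    if kv_bits is not None:
        payload["kv_bits"] = kv_bits  # omitted entirely when not requested
    return json.dumps(payload)

# 4-bit quantized KV cache for this request only:
body = build_chat_payload(
    [{"role": "user", "content": "Name the three primary colours."}],
    kv_bits=4,
    max_tokens=60,
    temperature=0.0,
)
print(body)
```

Posting the same body with `kv_bits` omitted exercises the unquantized baseline path that Test 9 keeps as a regression guard.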