2 changes: 1 addition & 1 deletion 3rdparty/llama.cpp
3,453 changes: 3,453 additions & 0 deletions data/log.txt

275 changes: 275 additions & 0 deletions docs/bitnet-embeddings-qwen3-gguf-conversion.md
@@ -0,0 +1,275 @@
# BitNet Embeddings (Qwen3) GGUF Conversion Implementation

## 1. Background

`bitnet-embeddings-0.6b` is a Qwen3-based embedding model with BitNet per-projection RMSNorm (`BitLinear`). Each linear projection (q/k/v/o/gate/up/down) has a `.norm.weight` that applies RMSNorm to the input **before** the matmul:

```
x → RMSNorm(x, norm.weight) → activation_quant(8bit) → matmul(weight_quant(ternary))
```

This pattern does **not** exist in any standard llama.cpp architecture:
- Standard Qwen3: no per-projection norms
- Standard BitNet: has `attn_sub_norm`/`ffn_sub_norm` at different positions (after attention/gate*up, not before each projection)
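The pipeline above can be sketched numerically in plain Python. This is an illustrative reference only: it assumes the commonly published BitNet b1.58 formulation (per-token absmax scaling for activations, per-tensor absmean scaling for ternary weights), which may differ in detail from the model's actual `BitLinear`. All function names are hypothetical.

```python
import math

def rms_norm(x, weight, eps=1e-6):
    """RMSNorm: divide x by sqrt(mean(x^2) + eps), then scale by the learned weight."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]

def activation_quant(x):
    """Assumed scheme: absmax fake-quantization to the int8 range [-128, 127]."""
    scale = 127.0 / max(max(abs(v) for v in x), 1e-5)
    return [max(-128, min(127, round(v * scale))) / scale for v in x]

def weight_quant(w):
    """Assumed scheme: absmean fake-quantization to ternary levels {-1, 0, +1} / scale."""
    flat = [abs(v) for row in w for v in row]
    scale = 1.0 / max(sum(flat) / len(flat), 1e-5)
    return [[max(-1, min(1, round(v * scale))) / scale for v in row] for row in w]

def bitlinear(x, w, norm_weight):
    """x -> RMSNorm -> activation_quant -> matmul with ternary weights.

    w is a list of rows with shape [out_features][in_features].
    """
    xq = activation_quant(rms_norm(x, norm_weight))
    wq = weight_quant(w)
    return [sum(wi * xi for wi, xi in zip(row, xq)) for row in wq]
```

The norm weight has the length of the projection's *input*, which is why the per-projection norm shapes follow the input dimension of each projection rather than its output dimension.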

### Model Config

- Architecture: `Qwen3Model`
- hidden_size: 1024, num_attention_heads: 16, num_key_value_heads: 8
- head_dim: 128 (note: != hidden_size/num_heads = 64)
- intermediate_size: 3072, num_hidden_layers: 28
- tie_word_embeddings: true
- rope_theta: 1000000, rms_norm_eps: 1e-06

### Per-Layer Tensors (7 extra norm tensors per layer)

| Tensor | Shape |
|--------|-------|
| `self_attn.q_proj.norm.weight` | [1024] |
| `self_attn.k_proj.norm.weight` | [1024] |
| `self_attn.v_proj.norm.weight` | [1024] |
| `self_attn.o_proj.norm.weight` | [2048] |
| `mlp.gate_proj.norm.weight` | [1024] |
| `mlp.up_proj.norm.weight` | [1024] |
| `mlp.down_proj.norm.weight` | [3072] |
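
The shapes in this table follow from the model config: each norm has the length of its projection's input vector. A small sketch of that arithmetic (the helper name is hypothetical; the numbers are the ones listed under Model Config):

```python
# Values copied from the Model Config list above.
CONFIG = {
    "hidden_size": 1024,
    "num_attention_heads": 16,
    "num_key_value_heads": 8,
    "head_dim": 128,
    "intermediate_size": 3072,
    "num_hidden_layers": 28,
}

def per_projection_norm_shapes(cfg):
    """Return the expected length of each per-projection norm weight."""
    hidden = cfg["hidden_size"]
    # o_proj consumes the concatenated attention heads, not the hidden state.
    attn_out = cfg["num_attention_heads"] * cfg["head_dim"]
    return {
        "self_attn.q_proj.norm.weight": hidden,
        "self_attn.k_proj.norm.weight": hidden,
        "self_attn.v_proj.norm.weight": hidden,
        "self_attn.o_proj.norm.weight": attn_out,
        "mlp.gate_proj.norm.weight": hidden,
        "mlp.up_proj.norm.weight": hidden,
        "mlp.down_proj.norm.weight": cfg["intermediate_size"],
    }

shapes = per_projection_norm_shapes(CONFIG)
extra_tensors = len(shapes) * CONFIG["num_hidden_layers"]  # 7 norms x 28 layers
```

Because `head_dim` is 128 rather than 64, the `o_proj` norm is 2048 wide even though `hidden_size` is 1024.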

---

## 2. Implementation Plan

### Step 1: Conversion Script
Create a standalone Python script to convert safetensors → GGUF with proper tensor name mapping for all 7 per-projection norms.

### Step 2: C++ llama.cpp Modifications
Add support for the new tensor types in `llama.cpp`: enums, name mappings, struct fields, loading, and inference graph construction.

### Step 3: Precision Verification
Verify tensor-level and inference-level precision alignment.

---

## 3. GGUF Tensor Name Mapping

| HF Name | GGUF Name | Notes |
|----------|-----------|-------|
| `embed_tokens.weight` | `token_embd.weight` | |
| `norm.weight` | `output_norm.weight` | |
| `layers.{i}.input_layernorm.weight` | `blk.{i}.attn_norm.weight` | |
| `layers.{i}.post_attention_layernorm.weight` | `blk.{i}.ffn_norm.weight` | |
| `layers.{i}.self_attn.q_proj.weight` | `blk.{i}.attn_q.weight` | |
| `layers.{i}.self_attn.k_proj.weight` | `blk.{i}.attn_k.weight` | |
| `layers.{i}.self_attn.v_proj.weight` | `blk.{i}.attn_v.weight` | |
| `layers.{i}.self_attn.o_proj.weight` | `blk.{i}.attn_output.weight` | |
| `layers.{i}.self_attn.q_norm.weight` | `blk.{i}.attn_q_norm.weight` | QK head norm |
| `layers.{i}.self_attn.k_norm.weight` | `blk.{i}.attn_k_norm.weight` | QK head norm |
| `layers.{i}.self_attn.q_proj.norm.weight` | `blk.{i}.attn_q_norm_in.weight` | **NEW** |
| `layers.{i}.self_attn.k_proj.norm.weight` | `blk.{i}.attn_k_norm_in.weight` | **NEW** |
| `layers.{i}.self_attn.v_proj.norm.weight` | `blk.{i}.attn_v_norm_in.weight` | **NEW** |
| `layers.{i}.self_attn.o_proj.norm.weight` | `blk.{i}.attn_output_norm_in.weight` | **NEW** |
| `layers.{i}.mlp.gate_proj.weight` | `blk.{i}.ffn_gate.weight` | |
| `layers.{i}.mlp.up_proj.weight` | `blk.{i}.ffn_up.weight` | |
| `layers.{i}.mlp.down_proj.weight` | `blk.{i}.ffn_down.weight` | |
| `layers.{i}.mlp.gate_proj.norm.weight` | `blk.{i}.ffn_gate_norm_in.weight` | **NEW** |
| `layers.{i}.mlp.up_proj.norm.weight` | `blk.{i}.ffn_up_norm_in.weight` | **NEW** |
| `layers.{i}.mlp.down_proj.norm.weight` | `blk.{i}.ffn_down_norm_in.weight` | **NEW** |
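
A minimal sketch of how the table above could be applied in code. It is not the conversion script itself: the function name is hypothetical, and the optional `model.` prefix on checkpoint names is an assumption (the table lists names without it).

```python
import re

# Tensors that are not part of a transformer block.
TOP_LEVEL = {
    "embed_tokens.weight": "token_embd.weight",
    "norm.weight": "output_norm.weight",
}

# Per-layer suffixes, taken row by row from the mapping table.
PER_LAYER = {
    "input_layernorm.weight": "attn_norm.weight",
    "post_attention_layernorm.weight": "ffn_norm.weight",
    "self_attn.q_proj.weight": "attn_q.weight",
    "self_attn.k_proj.weight": "attn_k.weight",
    "self_attn.v_proj.weight": "attn_v.weight",
    "self_attn.o_proj.weight": "attn_output.weight",
    "self_attn.q_norm.weight": "attn_q_norm.weight",
    "self_attn.k_norm.weight": "attn_k_norm.weight",
    "self_attn.q_proj.norm.weight": "attn_q_norm_in.weight",
    "self_attn.k_proj.norm.weight": "attn_k_norm_in.weight",
    "self_attn.v_proj.norm.weight": "attn_v_norm_in.weight",
    "self_attn.o_proj.norm.weight": "attn_output_norm_in.weight",
    "mlp.gate_proj.weight": "ffn_gate.weight",
    "mlp.up_proj.weight": "ffn_up.weight",
    "mlp.down_proj.weight": "ffn_down.weight",
    "mlp.gate_proj.norm.weight": "ffn_gate_norm_in.weight",
    "mlp.up_proj.norm.weight": "ffn_up_norm_in.weight",
    "mlp.down_proj.norm.weight": "ffn_down_norm_in.weight",
}

_LAYER_RE = re.compile(r"^(?:model\.)?layers\.(\d+)\.(.+)$")

def hf_to_gguf_name(name):
    """Map an HF tensor name to its GGUF name; raise KeyError if unmapped."""
    match = _LAYER_RE.match(name)
    if match:
        layer_index, suffix = match.groups()
        return f"blk.{layer_index}.{PER_LAYER[suffix]}"
    stripped = name[len("model."):] if name.startswith("model.") else name
    return TOP_LEVEL[stripped]
```

Raising on unknown names rather than skipping them makes an unmapped tensor fail loudly at conversion time.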

---

## 4. New Files Created

### `utils/convert-bitnet-embedding-to-gguf.py`

Standalone conversion script. Key features:
- Hardcoded HF→GGUF tensor name mapping (no dependency on llama.cpp's Python converter)
- Supports `--outtype f16` (2D weights as f16, norms as f32) and `--outtype f32`
- Writes `key_length` and `value_length` metadata so that head_dim=128 is honored (critical: the default calculation, hidden_size / num_attention_heads, would give the wrong value, 64)
- GPT-2 BPE tokenizer handling for Qwen3
- Architecture string: `"qwen3"`

Usage:
```bash
python3 utils/convert-bitnet-embedding-to-gguf.py \
/data2/huangxin/model_list/microsoft_release_multilingual_models/bitnet-embeddings-0.6b \
--outfile output-f16.gguf --outtype f16
```
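
The `--outtype` rule described above (2D weights as f16, norms as f32) can be sketched as a small dtype selector. The function name is hypothetical, and keying the decision on tensor rank is an assumption about how the script distinguishes weights from norms.

```python
def choose_output_dtype(shape, outtype):
    """Pick the storage dtype for one tensor.

    Assumption: 2D tensors are projection/embedding weights and everything
    else (1D norm vectors) stays in f32 regardless of --outtype.
    """
    if outtype not in ("f16", "f32"):
        raise ValueError(f"unsupported outtype: {outtype}")
    if outtype == "f16" and len(shape) == 2:
        return "f16"
    return "f32"
```

Norm vectors are small, so keeping them in f32 costs little file size.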

### `scripts/verify_gguf_precision.py`

Two-level precision verification:
- **Level 1**: Tensor-level comparison (safetensors vs GGUF, accounting for bf16→f16 conversion)
- **Level 2**: Inference-level comparison (PyTorch with BitLinear vs llama-embedding binary)

Usage:
```bash
python3 scripts/verify_gguf_precision.py \
--model-dir /data2/.../bitnet-embeddings-0.6b \
--gguf-file output-f16.gguf --level both
```
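
For Level 1, "accounting for bf16→f16 conversion" means the reference value is the checkpoint's bf16 value after rounding to f16, not the raw bf16 value. A stdlib-only sketch of that comparison follows; the helper names are hypothetical and the real script may work on whole arrays instead of Python lists.

```python
import struct

def bf16_bits_to_float(bits):
    """bf16 is the top 16 bits of an IEEE-754 float32, so widening is exact."""
    return struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))[0]

def to_f16(value):
    """Round a float to the nearest float16 and return it as a Python float.

    struct raises OverflowError if the value is outside the float16 range.
    """
    return struct.unpack("<e", struct.pack("<e", value))[0]

def max_abs_error(bf16_bits, gguf_values):
    """Largest |reference - stored|, where reference is bf16 -> f32 -> f16."""
    return max(
        abs(to_f16(bf16_bits_to_float(b)) - g)
        for b, g in zip(bf16_bits, gguf_values)
    )
```

bf16 carries fewer significand bits than f16, so a bf16 value whose magnitude lies within the normal f16 range converts without rounding. That is consistent with an exact match being achievable for well-scaled weights.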

### `scripts/verify_inference_precision.py`

Per-token hidden state comparison with monkey-patched BitLinear (disabling activation/weight quantization for fair f16 comparison).

---

## 5. C++ Modifications (`3rdparty/llama.cpp/src/llama.cpp`)

### 5.1 New Tensor Enums

Added 7 new entries after `LLM_TENSOR_FFN_SUB_NORM`:

```cpp
LLM_TENSOR_ATTN_Q_NORM_IN,
LLM_TENSOR_ATTN_K_NORM_IN,
LLM_TENSOR_ATTN_V_NORM_IN,
LLM_TENSOR_ATTN_OUT_NORM_IN,
LLM_TENSOR_FFN_GATE_NORM_IN,
LLM_TENSOR_FFN_UP_NORM_IN,
LLM_TENSOR_FFN_DOWN_NORM_IN,
```

### 5.2 Tensor Name Mappings

Added to `LLM_ARCH_QWEN3` tensor name map:

```cpp
{ LLM_TENSOR_ATTN_Q_NORM_IN, "blk.%d.attn_q_norm_in" },
{ LLM_TENSOR_ATTN_K_NORM_IN, "blk.%d.attn_k_norm_in" },
{ LLM_TENSOR_ATTN_V_NORM_IN, "blk.%d.attn_v_norm_in" },
{ LLM_TENSOR_ATTN_OUT_NORM_IN, "blk.%d.attn_output_norm_in" },
{ LLM_TENSOR_FFN_GATE_NORM_IN, "blk.%d.ffn_gate_norm_in" },
{ LLM_TENSOR_FFN_UP_NORM_IN, "blk.%d.ffn_up_norm_in" },
{ LLM_TENSOR_FFN_DOWN_NORM_IN, "blk.%d.ffn_down_norm_in" },
```

### 5.3 Layer Struct Fields

Added to `struct llama_layer`:

```cpp
struct ggml_tensor * attn_q_norm_in;
struct ggml_tensor * attn_k_norm_in;
struct ggml_tensor * attn_v_norm_in;
struct ggml_tensor * attn_out_norm_in;
struct ggml_tensor * ffn_gate_norm_in;
struct ggml_tensor * ffn_up_norm_in;
struct ggml_tensor * ffn_down_norm_in;
```

### 5.4 load_tensors (LLM_ARCH_QWEN3)

Added optional loading with `TENSOR_NOT_REQUIRED`:

```cpp
layer.attn_q_norm_in = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM_IN, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
layer.attn_k_norm_in = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM_IN, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
layer.attn_v_norm_in = create_tensor(tn(LLM_TENSOR_ATTN_V_NORM_IN, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
layer.attn_out_norm_in = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM_IN, "weight", i), {n_embd_head_k * n_head}, TENSOR_NOT_REQUIRED);
layer.ffn_gate_norm_in = create_tensor(tn(LLM_TENSOR_FFN_GATE_NORM_IN, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
layer.ffn_up_norm_in = create_tensor(tn(LLM_TENSOR_FFN_UP_NORM_IN, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
layer.ffn_down_norm_in = create_tensor(tn(LLM_TENSOR_FFN_DOWN_NORM_IN, "weight", i), {n_ff}, TENSOR_NOT_REQUIRED);
```

Note: `o_proj.norm` input dimension is `n_embd_head_k * n_head` (=2048), `down_proj.norm` input dimension is `n_ff` (=3072).

### 5.5 build_qwen3() Graph Modifications

The `build_qwen3()` function was modified to apply per-projection RMSNorm conditionally. The change is fully backward compatible: when no `*_norm_in` tensors exist, behavior is identical to the original.

**Attention per-projection norms:**
```cpp
// Before the Q/K/V matmuls:
if (layer.attn_q_norm_in) {
    cur_q = ggml_rms_norm(ctx, cur, hparams.f_norm_rms_eps);
    cur_q = ggml_mul(ctx, cur_q, layer.attn_q_norm_in);
} else {
    cur_q = cur;
}
Qcur = ggml_mul_mat(ctx, layer.wq, cur_q);
// Similarly for K and V
```

**O_proj norm** requires special handling because `llm_build_kv()` normally applies `wo` internally. Solution: pass `wo=NULL` to `llm_build_kv()`, then apply norm + wo manually:

```cpp
cur = llm_build_kv(..., wo=NULL, ...); // returns attention output without o_proj
if (layer.attn_out_norm_in) {
    cur = ggml_rms_norm(ctx, cur, hparams.f_norm_rms_eps);
    cur = ggml_mul(ctx, cur, layer.attn_out_norm_in);
}
cur = ggml_mul_mat(ctx, layer.wo, cur);
```

**FFN per-projection norms:**
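```cpp
// Instead of llm_build_ffn(), build the FFN manually:
struct ggml_tensor * gate_in = cur;
if (layer.ffn_gate_norm_in) {
    gate_in = ggml_rms_norm(ctx, cur, hparams.f_norm_rms_eps);
    gate_in = ggml_mul(ctx, gate_in, layer.ffn_gate_norm_in);
}
struct ggml_tensor * tmp_gate = ggml_mul_mat(ctx, layer.ffn_gate, gate_in);
// Similarly for up_proj, producing tmp_up
struct ggml_tensor * tmp = ggml_mul(ctx, ggml_silu(ctx, tmp_gate), tmp_up);

// down_proj normalizes the gated activation (width n_ff), not the block input
if (layer.ffn_down_norm_in) {
    tmp = ggml_rms_norm(ctx, tmp, hparams.f_norm_rms_eps);
    tmp = ggml_mul(ctx, tmp, layer.ffn_down_norm_in);
}
cur = ggml_mul_mat(ctx, layer.ffn_down, tmp);
```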
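
The FFN logic above can be checked against a small pure-Python reference. This sketch mirrors the control flow (optional norm before each of gate, up, and down) but is not the graph code; all names are hypothetical and quantization is omitted, matching the f16 path.

```python
import math

def rms_norm(x, weight, eps=1e-6):
    """RMSNorm followed by an element-wise learned scale."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]

def maybe_norm(x, weight):
    """Apply the per-projection norm only when its weight tensor exists."""
    return x if weight is None else rms_norm(x, weight)

def matvec(w, x):
    """w has shape [out][in]; returns w @ x."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]

def silu(v):
    return v / (1.0 + math.exp(-v))

def ffn(x, gate_w, up_w, down_w, gate_norm=None, up_norm=None, down_norm=None):
    """SwiGLU FFN with an optional RMSNorm in front of each projection."""
    gate = matvec(gate_w, maybe_norm(x, gate_norm))
    up = matvec(up_w, maybe_norm(x, up_norm))
    hidden = [silu(g) * u for g, u in zip(gate, up)]
    return matvec(down_w, maybe_norm(hidden, down_norm))
```

Calling `ffn` without any norm weights reduces to a plain SwiGLU block, which is the backward-compatible case for standard Qwen3 checkpoints.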

---

## 6. Key Issues Encountered and Solutions

### Issue 1: Missing `output_norm.weight`
The model has `norm.weight` in safetensors but it wasn't being mapped. Added `"norm.weight": "output_norm.weight"` to the mapping.

### Issue 2: Wrong head_dim Calculation
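`head_dim=128`, but `hidden_size/num_attention_heads = 1024/16 = 64`. The C++ loader defaulted to 64, causing a shape mismatch (`expected 1024,1024 got 1024,2048`; these dimensions correspond to the Q projection, whose output size is 16 × 128 = 2048, whereas the K/V projections have output size 8 × 128 = 1024). Fixed by writing `key_length` and `value_length` metadata to the GGUF.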
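
A sketch of the arithmetic behind this issue. The key names follow the GGUF `{arch}.attention.*` convention; the helper names are hypothetical, and the real script presumably writes these values through a GGUF writer rather than a plain dict.

```python
def fallback_head_dim(cfg):
    """What a loader derives when no explicit head size is stored."""
    return cfg["hidden_size"] // cfg["num_attention_heads"]

def attention_metadata(cfg, arch="qwen3"):
    """Metadata that pins the head size explicitly."""
    head_dim = cfg.get("head_dim", fallback_head_dim(cfg))
    return {
        f"{arch}.attention.head_count": cfg["num_attention_heads"],
        f"{arch}.attention.head_count_kv": cfg["num_key_value_heads"],
        f"{arch}.attention.key_length": head_dim,
        f"{arch}.attention.value_length": head_dim,
    }

def q_proj_out_features(cfg, head_dim):
    """Output width of q_proj for a given head size."""
    return cfg["num_attention_heads"] * head_dim

cfg = {
    "hidden_size": 1024,
    "num_attention_heads": 16,
    "num_key_value_heads": 8,
    "head_dim": 128,
}
meta = attention_metadata(cfg)
expected_by_fallback = q_proj_out_features(cfg, fallback_head_dim(cfg))  # 1024
actual = q_proj_out_features(cfg, cfg["head_dim"])                       # 2048
```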

### Issue 3: llama-embedding Output Parsing
The initial parser could not handle `llama-embedding`'s default, truncated output format. Fixed by passing `--embd-output-format array`, which prints each full vector as a JSON array.
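
A minimal parsing sketch for the `array` output. It assumes stdout contains a single nested JSON array (one inner list per prompt), possibly surrounded by other text; the function name is hypothetical.

```python
import json

def parse_embedding_array(stdout_text):
    """Extract the [[...], ...] block from llama-embedding output and decode it."""
    start = stdout_text.find("[[")
    end = stdout_text.rfind("]]")
    if start == -1 or end == -1 or end < start:
        raise ValueError("no embedding array found in output")
    return json.loads(stdout_text[start:end + 2])
```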

### Issue 4: PyTorch Model Ignoring Per-Projection Norms
`AutoModel.from_pretrained` instantiates the standard `Qwen3Model`, which has no parameters for the `.norm.weight` tensors, so they were ignored. Fixed by calling `replace_linear_with_bitlinear()` from the bitnet-embeddings repo and then reloading the weights.

### Issue 5: Inference Precision Mismatch with BitLinear Active
PyTorch with `activation_quant`/`weight_quant` active produces different results from llama.cpp f16. This is expected, because the f16 GGUF path applies no quantization. For a fair comparison, `BitLinear.forward` was monkey-patched to skip quantization.

---

## 7. Verification Results

### Level 1: Tensor Precision
- 506 tensors compared
- **Zero error** across all tensors (exact match after bf16→f16 conversion)

### Level 2: Inference Precision (f16, no activation/weight quant)
- Per-token hidden state cosine similarity: **> 0.9999999** for all test cases
- Test texts: "hello world", "The quick brown fox...", "机器学习是人工智能的一个分支"
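
The Level 2 metric can be sketched as follows: compute cosine similarity between the two implementations' hidden states for each token and report the worst case. The function names are hypothetical; the real script may compute this with tensor operations instead.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)

def min_per_token_cosine(ref_states, test_states):
    """Worst-case similarity across tokens; each argument is a list of vectors."""
    if len(ref_states) != len(test_states):
        raise ValueError("token counts differ")
    return min(cosine_similarity(r, t) for r, t in zip(ref_states, test_states))
```

Reporting the minimum rather than the mean keeps a single badly mismatched token from being averaged away.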

---

## 8. Build and Run

```bash
# 1. Convert to GGUF
python3 utils/convert-bitnet-embedding-to-gguf.py \
/data2/huangxin/model_list/microsoft_release_multilingual_models/bitnet-embeddings-0.6b \
--outfile bitnet-embeddings-0.6b-f16.gguf --outtype f16

# 2. Build llama.cpp
cd /home/huangxin/code_list/BitNet
cmake -B build -DLLAMA_NATIVE=OFF
cmake --build build --target llama-embedding -j$(nproc)

# 3. Verify tensor precision
python3 scripts/verify_gguf_precision.py \
--model-dir /data2/huangxin/model_list/microsoft_release_multilingual_models/bitnet-embeddings-0.6b \
--gguf-file bitnet-embeddings-0.6b-f16.gguf --level 1

# 4. Run embedding inference
build/bin/llama-embedding -m bitnet-embeddings-0.6b-f16.gguf \
-p "hello world" --embd-normalize 2 --embd-output-format array
```