From 4d878fa33738900b1f81fac0217b71439be904be Mon Sep 17 00:00:00 2001
From: ailuntz
Date: Tue, 10 Mar 2026 16:11:10 +0800
Subject: [PATCH 1/2] Guard SCB access in Linear8bitLt

---
 bitsandbytes/nn/modules.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index ff8ce9c3d..11dc13e37 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -1054,7 +1054,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         scb_name = "SCB"
 
         # case 1: .cuda was called, SCB is in self.weight
-        param_from_weight = getattr(self.weight, scb_name)
+        param_from_weight = getattr(self.weight, scb_name, None)
         # case 2: self.init_8bit_state was called, SCB is in self.state
         param_from_state = getattr(self.state, scb_name)
 
@@ -1095,7 +1095,8 @@ def _load_from_state_dict(
         for key in unexpected_copy:
             input_name = key[len(prefix) :]
             if input_name == "SCB":
-                if self.weight.SCB is None:
+                weight_scb = getattr(self.weight, "SCB", None)
+                if weight_scb is None:
                     # buffers not yet initialized, can't access them directly without quantizing first
                     raise RuntimeError(
                         "Loading a quantized checkpoint into non-quantized Linear8bitLt is "
@@ -1103,7 +1104,7 @@ def _load_from_state_dict(
                     )
 
                 input_param = state_dict[key]
-                self.weight.SCB.copy_(input_param)
+                weight_scb.copy_(input_param)
 
                 if self.state.SCB is not None:
                     self.state.SCB = self.weight.SCB
@@ -1133,7 +1134,7 @@ def to(self, *args, **kwargs):
 
     def forward(self, x: torch.Tensor):
         self.state.is_training = self.training
-        if self.weight.CB is not None:
+        if getattr(self.weight, "CB", None) is not None:
             self.init_8bit_state()
 
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
@@ -1142,7 +1143,7 @@ def forward(self, x: torch.Tensor):
 
         out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
 
-        if not self.state.has_fp16_weights and self.state.CB is not None:
+        if not self.state.has_fp16_weights and self.state.CB is not None and hasattr(self.weight, "CB"):
             self.weight.data = self.state.CB
         return out
 

From b301b772035c9febcd49046869df7f5d488a6826 Mon Sep 17 00:00:00 2001
From: ailuntz
Date: Wed, 1 Apr 2026 16:00:55 +0800
Subject: [PATCH 2/2] fix(nn): keep tied-weight SCB guards out of forward

---
 bitsandbytes/nn/modules.py |  4 ++--
 tests/test_linear8bitlt.py | 20 ++++++++++++++++++++
 2 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 11dc13e37..1bd53858b 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -1134,7 +1134,7 @@ def to(self, *args, **kwargs):
 
     def forward(self, x: torch.Tensor):
         self.state.is_training = self.training
-        if getattr(self.weight, "CB", None) is not None:
+        if self.weight.CB is not None:
             self.init_8bit_state()
 
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
@@ -1143,7 +1143,7 @@ def forward(self, x: torch.Tensor):
 
         out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
 
-        if not self.state.has_fp16_weights and self.state.CB is not None and hasattr(self.weight, "CB"):
+        if not self.state.has_fp16_weights and self.state.CB is not None:
             self.weight.data = self.state.CB
         return out
 
diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py
index dc4ff4741..b7574f855 100644
--- a/tests/test_linear8bitlt.py
+++ b/tests/test_linear8bitlt.py
@@ -228,6 +228,26 @@ def test_linear8bit_serialization(linear8bit):
     assert (linear8bit.weight.CB == deserialized.weight.CB).all()
 
 
+def test_linear8bit_state_dict_skips_scb_for_tied_weight():
+    linear = Linear8bitLt(8, 8, bias=False, has_fp16_weights=False)
+    linear.weight = torch.nn.Parameter(torch.randn_like(linear.weight))
+
+    state_dict = linear.state_dict()
+
+    assert "SCB" not in state_dict
+    assert "weight_format" not in state_dict
+
+
+def test_linear8bit_load_state_dict_raises_runtime_for_tied_weight():
+    linear = Linear8bitLt(8, 8, bias=False, has_fp16_weights=False)
+    linear.weight = torch.nn.Parameter(torch.randn_like(linear.weight))
+    state_dict = linear.state_dict()
+    state_dict["SCB"] = torch.ones(linear.out_features)
+
+    with pytest.raises(RuntimeError, match="Loading a quantized checkpoint into non-quantized Linear8bitLt"):
+        linear.load_state_dict(state_dict, strict=False)
+
+
 @pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("threshold", [0.0, 6.0], ids=id_formatter("threshold"))
 @pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))