# Patch notes (free text ahead of the first "diff --git" header).
#
# bitsandbytes/nn/modules.py
#   * _save_to_state_dict: the "SCB" attribute is now read from self.weight
#     with getattr(..., None), so a weight object without that attribute
#     yields None instead of raising AttributeError.
#   * _load_from_state_dict: "SCB" is fetched once from self.weight with
#     getattr(..., None) into weight_scb. A None result takes the existing
#     RuntimeError branch, and copy_() is called on weight_scb rather than
#     on self.weight.SCB.
#
# tests/test_linear8bitlt.py
#   * Adds two tests that replace linear.weight with a plain
#     torch.nn.Parameter (presumably an object with no SCB attribute --
#     confirm against the parameter class used by Linear8bitLt):
#       - state_dict() must contain neither "SCB" nor "weight_format";
#       - load_state_dict(..., strict=False) on a state dict carrying an
#         "SCB" entry must raise RuntimeError with the
#         "Loading a quantized checkpoint into non-quantized Linear8bitLt"
#         message.
#
# NOTE(review): this copy was restored from a whitespace-collapsed version of
# the patch. Indentation and the position of blank context lines are
# reconstructed and must be confirmed against the target files.
# NOTE(review): the last hunk header (-228,6 +228,26) implies 20 added lines,
# but 21 "+" markers are present; all markers are kept as found. Re-check
# with `git apply --check` before use.
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index ff8ce9c3d..1bd53858b 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -1054,7 +1054,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         scb_name = "SCB"
 
         # case 1: .cuda was called, SCB is in self.weight
-        param_from_weight = getattr(self.weight, scb_name)
+        param_from_weight = getattr(self.weight, scb_name, None)
         # case 2: self.init_8bit_state was called, SCB is in self.state
         param_from_state = getattr(self.state, scb_name)
 
@@ -1095,7 +1095,8 @@ def _load_from_state_dict(
         for key in unexpected_copy:
             input_name = key[len(prefix) :]
             if input_name == "SCB":
-                if self.weight.SCB is None:
+                weight_scb = getattr(self.weight, "SCB", None)
+                if weight_scb is None:
                     # buffers not yet initialized, can't access them directly without quantizing first
                     raise RuntimeError(
                         "Loading a quantized checkpoint into non-quantized Linear8bitLt is "
@@ -1103,7 +1104,7 @@ def _load_from_state_dict(
                     )
 
                 input_param = state_dict[key]
-                self.weight.SCB.copy_(input_param)
+                weight_scb.copy_(input_param)
 
                 if self.state.SCB is not None:
                     self.state.SCB = self.weight.SCB
diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py
index dc4ff4741..b7574f855 100644
--- a/tests/test_linear8bitlt.py
+++ b/tests/test_linear8bitlt.py
@@ -228,6 +228,26 @@ def test_linear8bit_serialization(linear8bit):
     assert (linear8bit.weight.CB == deserialized.weight.CB).all()
 
 
+def test_linear8bit_state_dict_skips_scb_for_tied_weight():
+    linear = Linear8bitLt(8, 8, bias=False, has_fp16_weights=False)
+    linear.weight = torch.nn.Parameter(torch.randn_like(linear.weight))
+
+    state_dict = linear.state_dict()
+
+    assert "SCB" not in state_dict
+    assert "weight_format" not in state_dict
+
+
+def test_linear8bit_load_state_dict_raises_runtime_for_tied_weight():
+    linear = Linear8bitLt(8, 8, bias=False, has_fp16_weights=False)
+    linear.weight = torch.nn.Parameter(torch.randn_like(linear.weight))
+    state_dict = linear.state_dict()
+    state_dict["SCB"] = torch.ones(linear.out_features)
+
+    with pytest.raises(RuntimeError, match="Loading a quantized checkpoint into non-quantized Linear8bitLt"):
+        linear.load_state_dict(state_dict, strict=False)
+
+
+
 @pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("threshold", [0.0, 6.0], ids=id_formatter("threshold"))
 @pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))