diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 235737c9..281e3f3e 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -963,6 +963,22 @@ def _load_optimizer(self, checkpoint_dir, **kwargs):
         state_dict = torch.load(scheduler_path, map_location='cpu')
         optimizer_config.lr_scheduler.load_state_dict(state_dict)
 
+    def _ensure_lora_dtype(self, model):
+        """Force LoRA parameters to use the same dtype as the base model for FSDP2 compatibility."""
+        base_dtype = None
+        for param in model.parameters():
+            if param.dtype in (torch.float16, torch.bfloat16, torch.float32):
+                base_dtype = param.dtype
+                break
+        if base_dtype is None:
+            return
+
+        # Convert all LoRA parameters to the base model dtype
+        with torch.no_grad():
+            for name, param in model.named_parameters():
+                if 'lora_' in name.lower() and param.dtype != base_dtype:
+                    param.data = param.data.to(base_dtype)
+
     @remote_function(collect='first')
     def get_state_dict(self, **kwargs):
         return self._get_trainable_parameters(kwargs.pop('adapter_name', self._get_default_group()))
@@ -1019,6 +1035,7 @@ def _patch_adapter(self, adapter_name: str, config_or_dir: Union[PeftConfig, str
         else:
             unwrapped_model.add_adapter(adapter_name, config)
 
+        self._ensure_lora_dtype(self.model)
         self.optimizer_group[adapter_name] = self._construct_default_optimizer_group()
         self.optimizer_group[adapter_name].adapter_name = adapter_name
         self.optimizer_group[adapter_name].adapter_config = config
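
For context, here is a minimal standalone sketch of the same dtype-normalization idea outside the patched class, assuming a Hugging Face PEFT setup. The model name, LoRA hyperparameters, and the `ensure_lora_dtype` helper name are illustrative, not part of the patch. PEFT can keep adapter weights in fp32 even when the base model is loaded in bf16/fp16; that mixed-dtype state is what the patch normalizes away before FSDP2 sharding.

```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM


def ensure_lora_dtype(model: torch.nn.Module) -> None:
    """Cast every LoRA parameter to the first floating-point dtype found in the model."""
    # Take the dtype of the first floating-point parameter as the "base" dtype,
    # mirroring the heuristic used in the patch above.
    base_dtype = next(
        (p.dtype for p in model.parameters()
         if p.dtype in (torch.float16, torch.bfloat16, torch.float32)),
        None,
    )
    if base_dtype is None:
        return
    with torch.no_grad():
        for name, param in model.named_parameters():
            # PEFT names adapter weights 'lora_A', 'lora_B', etc.
            if 'lora_' in name.lower() and param.dtype != base_dtype:
                param.data = param.data.to(base_dtype)


# Usage sketch: 'gpt2' and the LoRA settings are placeholders.
base = AutoModelForCausalLM.from_pretrained('gpt2', torch_dtype=torch.bfloat16)
model = get_peft_model(base, LoraConfig(r=8, lora_alpha=16))
ensure_lora_dtype(model)  # LoRA weights now match bf16 before any FSDP2 wrapping
```

Running the normalization once per adapter (as the patch does at the end of `_patch_adapter`) keeps the parameter dtypes uniform at the moment FSDP2 shards the model, rather than relying on every adapter having been created with the right dtype up front.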