From c43301516c88fec766173d7a8611a8d291744eb8 Mon Sep 17 00:00:00 2001 From: 0hujun <96733800+0hujun@users.noreply.github.com> Date: Thu, 23 Apr 2026 17:29:34 +0800 Subject: [PATCH 1/2] fix: model dtype not same as lora dtype in FSDP train --- .../model/transformers/transformers.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 235737c9..d2fe4b00 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -963,6 +963,27 @@ def _load_optimizer(self, checkpoint_dir, **kwargs): state_dict = torch.load(scheduler_path, map_location='cpu') optimizer_config.lr_scheduler.load_state_dict(state_dict) + def _ensure_lora_dtype(self, model): + """Force LoRA parameters to use the same dtype as base model for FSDP2 compatibility.""" + try: + base_model = model.get_base_model() + except Exception: + base_model = model + + base_dtype = None + for param in base_model.parameters(): + if param.dtype in (torch.float32, torch.float16, torch.bfloat16): + base_dtype = param.dtype + break + if base_dtype is None: + return + + # Convert all LoRA parameters to the base model dtype + for name, param in model.named_parameters(): + if 'lora_' in name.lower(): + if param.dtype != base_dtype: + param.data = param.data.to(base_dtype) + @remote_function(collect='first') def get_state_dict(self, **kwargs): return self._get_trainable_parameters(kwargs.pop('adapter_name', self._get_default_group())) @@ -1019,6 +1040,7 @@ def _patch_adapter(self, adapter_name: str, config_or_dir: Union[PeftConfig, str else: unwrapped_model.add_adapter(adapter_name, config) + self._ensure_lora_dtype(self.model) self.optimizer_group[adapter_name] = self._construct_default_optimizer_group() self.optimizer_group[adapter_name].adapter_name = adapter_name self.optimizer_group[adapter_name].adapter_config = config From c120c06aa7bb71cbefacb7d4013d51803c534bbf Mon Sep 17 00:00:00 2001 From: 0hujun <96733800+0hujun@users.noreply.github.com> Date: Thu, 23 Apr 2026 19:28:13 +0800 Subject: [PATCH 2/2] fix: model dtype is not same as lora dtype in FSDP train --- src/twinkle/model/transformers/transformers.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index d2fe4b00..281e3f3e 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -965,23 +965,18 @@ def _load_optimizer(self, checkpoint_dir, **kwargs): def _ensure_lora_dtype(self, model): """Force LoRA parameters to use the same dtype as base model for FSDP2 compatibility.""" - try: - base_model = model.get_base_model() - except Exception: - base_model = model - base_dtype = None - for param in base_model.parameters(): - if param.dtype in (torch.float32, torch.float16, torch.bfloat16): + for param in model.parameters(): + if param.dtype in (torch.float16, torch.bfloat16, torch.float32): base_dtype = param.dtype break if base_dtype is None: return # Convert all LoRA parameters to the base model dtype - for name, param in model.named_parameters(): - if 'lora_' in name.lower(): - if param.dtype != base_dtype: + with torch.no_grad(): + for name, param in model.named_parameters(): + if 'lora_' in name.lower() and param.dtype != base_dtype: param.data = param.data.to(base_dtype) @remote_function(collect='first')