Extend LoRA for Gemma4 #3969

Merged: copybara-service[bot] merged 1 commit into main from jackyf/gemma4-lora on Jun 25, 2026

@RexBearIU RexBearIU commented May 22, 2026

Collaborator

Description

This PR extends the recent LoRA support to accurately target and process Gemma 4 architectures (including MoE).

Gemma 4 introduces complex nested structures (like scanned_blocks and layers_remainder) and unique chat template behaviors (such as the <|channel>thought block) that are incompatible with standard LoRA targeting and data processing. Furthermore, MoE models require dynamic metadata synchronization during forward passes, which aggressive NNX graph caching breaks.

This PR addresses these challenges by:

  • Adding accurate regex mapping for Gemma 4 standard and MoE LoRA targets in lora_module_path.yml.
  • Dynamically disabling NNX graph caching in train_sft.py specifically for MoE models (where experts > 1) to allow necessary metadata synchronization.
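The caching rule in the second bullet can be sketched as a small predicate. This is only an illustration, not the code in train_sft.py: `ModelConfig`, its `num_experts` field, and `use_nnx_graph_cache` are hypothetical names chosen for the example, and the only thing taken from the description is the condition "experts > 1".

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ModelConfig:
  """Hypothetical stand-in for the trainer config; only the field used here."""

  num_experts: int = 1


def use_nnx_graph_cache(config: ModelConfig) -> bool:
  """Return True when NNX graph caching can stay enabled.

  Assumption taken from the PR description: MoE models (more than one
  expert) need metadata synchronization on each forward pass, so caching
  is turned off for them and left on for dense models.
  """
  return config.num_experts <= 1


dense = ModelConfig(num_experts=1)
moe = ModelConfig(num_experts=8)
print(use_nnx_graph_cache(dense), use_nnx_graph_cache(moe))
```

Keeping the decision in one predicate would also make it straightforward to reuse from other trainers, if the same behavior is wanted there.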

Tests

  • Added unit tests for the Gemma 4 tokenizer bypass in tests/post_training/unit/sft_data_processing_test.py (test_tokenizer_gemma4_thought_channel_bypass).
  • Verified caching behavior changes by running Gemma-4 MoE LoRA tuning on TPU.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov Bot commented May 22, 2026

Codecov Report

❌ Patch coverage is 0% with 1 line in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/maxtext/trainers/post_train/sft/train_sft.py | 0.00% | 1 Missing ⚠️ |

@RexBearIU RexBearIU force-pushed the jackyf/gemma4-lora branch from 2bc8632 to ab61640 Compare May 22, 2026 07:38
Comment thread tests/post_training/unit/sft_data_processing_test.py Outdated
@RexBearIU RexBearIU force-pushed the jackyf/gemma4-lora branch 2 times, most recently from 61626bd to ef50ff7 Compare May 28, 2026 08:41
Comment thread src/maxtext/input_pipeline/input_pipeline_utils.py Outdated
@RexBearIU RexBearIU force-pushed the jackyf/gemma4-lora branch from ef50ff7 to 5fd616b Compare May 29, 2026 07:48
@RexBearIU RexBearIU force-pushed the jackyf/gemma4-lora branch from 5fd616b to 6a64bd0 Compare June 1, 2026 06:59
github-actions Bot commented Jun 2, 2026

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions github-actions Bot left a comment

## 📋 Review Summary

This Pull Request successfully extends LoRA support for Gemma 4 architectures and addresses a critical issue with NNX graph caching in MoE models. However, there is a significant discrepancy between the PR description and the actual changes, as several mentioned files and unit tests are missing from the diff.

🔍 General Feedback

  • Missing Implementation: The PR description mentions a "thought channel bypass" in input_pipeline_utils.py and new unit tests in tests/post_training/unit/sft_data_processing_test.py, but these files are not included in the PR. Please ensure all intended changes are staged and pushed.
  • Consistency across Trainers: The dynamic disabling of NNX graph caching is a great addition for MoE stability; consider applying this same logic to DPO, RL, and Distillation trainers to ensure consistent behavior across the post-training suite.
  • LoRA Targeting: The regex for Gemma 4 LoRA targeting is comprehensive but should be monitored to ensure it doesn't become overly broad as the architecture evolves.

Comment thread src/maxtext/trainers/post_train/sft/train_sft.py Outdated
Comment thread src/maxtext/configs/post_train/lora_module_path.yml
Comment thread src/maxtext/configs/post_train/lora_module_path.yml Outdated
Comment thread src/maxtext/trainers/post_train/sft/train_sft.py Outdated

Collaborator

Are there any tests for Gemma4 Lora?

Collaborator Author

We ran the end-to-end LoRA training loop for Gemma 4 successfully without any issues. (log)

@RexBearIU RexBearIU force-pushed the jackyf/gemma4-lora branch 4 times, most recently from b52137c to bc29e4e Compare June 8, 2026 07:17
@RexBearIU RexBearIU force-pushed the jackyf/gemma4-lora branch from bc29e4e to 090040c Compare June 24, 2026 07:20

@gagika gagika left a comment

Collaborator

Please see my comment; it can also be addressed in a follow-up PR.

deepseek2: "decoder/(dense_layers|moe_stack)/self_attention/(query|out|wkv_a|wkv_b)|decoder/(dense_layers|moe_stack)/(mlp|shared_experts)/(wi_0|wi_1|wo)"
gemma2: "decoder/layers/(self_attention_local|self_attention_global)/(query|key|value|out)|decoder/layers/(mlp_local|mlp_global)/(wi_0|wi_1|wo)"
gemma3: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo|gate|up|down))"
gemma4: "decoder/(scanned_blocks|layers_remainder)/layers.*/.*(self_attention/(query|key|value|out)|mlp/.*(wi_0|wi_1|wo|shared_experts/(wi_0|wi_1|wo)))"

Collaborator

Does the scope of this project cover both scan_layers=true and scan_layers=false, and both Gemma 4 dense and MoE?

Suggest adding a CPU unit test that asserts the regex matches the expected module paths for both scan_layers values and for both dense and MoE models.
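A CPU-only check along these lines could look like the sketch below. The pattern is copied from the gemma4 entry quoted above; everything else is an assumption made for illustration. In particular, the module paths are hypothetical examples rather than paths dumped from a real Gemma 4 model, `select_lora_targets` is a made-up helper, and the use of `re.fullmatch` is a guess about how the pattern is applied, not something this thread confirms.

```python
import re

# Pattern copied from the gemma4 entry quoted in this thread.
GEMMA4_LORA_PATTERN = (
    r"decoder/(scanned_blocks|layers_remainder)/layers.*/.*"
    r"(self_attention/(query|key|value|out)"
    r"|mlp/.*(wi_0|wi_1|wo|shared_experts/(wi_0|wi_1|wo)))"
)

# Hypothetical module paths, for illustration only.
SHOULD_MATCH = [
    "decoder/scanned_blocks/layers_0/self_attention/query",
    "decoder/scanned_blocks/layers_0/mlp/wi_0",
    "decoder/layers_remainder/layers_2/self_attention/out",
    "decoder/layers_remainder/layers_2/mlp/shared_experts/wo",
]
SHOULD_NOT_MATCH = [
    "decoder/scanned_blocks/layers_0/pre_self_attention_norm/scale",
    "decoder/layers/self_attention/query",  # no scanned_blocks/layers_remainder prefix
    "token_embedder/embedding",
]


def select_lora_targets(paths, pattern=GEMMA4_LORA_PATTERN):
  """Return the paths that the pattern fully matches (assumed matching mode)."""
  compiled = re.compile(pattern)
  return [p for p in paths if compiled.fullmatch(p)]


selected = select_lora_targets(SHOULD_MATCH + SHOULD_NOT_MATCH)
print(selected)
```

Note that the pattern requires a scanned_blocks or layers_remainder segment, so a path of the form decoder/layers/... is not selected in this sketch. Whether scan_layers=false actually produces such paths is exactly what a real test would need to establish from the model's parameter tree.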

@gagika gagika left a comment

Collaborator

Please make sure the follow-up items are done, but feel free to merge.

RexBearIU commented Jun 25, 2026

Collaborator Author

Hi @gagika @shralex, to keep this PR's merge path clear and preserve approvals, I have addressed the follow-up items (adding unit tests for Gemma 4 LoRA targeting and SFT trainer model caching) in a dedicated follow-up PR: #4265. I will go ahead and merge this PR now! Thanks again for the review.

@copybara-service copybara-service Bot merged commit ec40b1d into main Jun 25, 2026
82 of 88 checks passed
@copybara-service copybara-service Bot deleted the jackyf/gemma4-lora branch June 25, 2026 05:29