Implementation for soft offline distillation using saved top-k teacher logits#3382
Open
ajkv-google wants to merge 3 commits intomainfrom
Open
Implementation for soft offline distillation using saved top-k teacher logits#3382ajkv-google wants to merge 3 commits intomainfrom
ajkv-google wants to merge 3 commits intomainfrom
Conversation
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
entrpn
reviewed
Mar 11, 2026
| def __init__(self, data_dir: str, epochs: int = 100): | ||
| # Check if the user passed a directory or a direct file path | ||
| if tf.io.gfile.isdir(data_dir): | ||
| self.filepath = os.path.join(data_dir, "teacher_top_k_global.array_record") |
Collaborator
There was a problem hiding this comment.
is it ok to hardcode this file as teacher_top_k_global.array_record?
|
|
||
| if __name__ == "__main__": | ||
| app.run(main) | ||
| parser = argparse.ArgumentParser() |
Collaborator
There was a problem hiding this comment.
I think these should go inside types.py to add them as part of the config.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
This PR introduces an end-to-end offline distillation training pipeline. Previously, the distillation loop executed in an "online" mode, which required both the frozen Teacher model and the learning Student model to be loaded and executed simultaneously during training. This change allows the trainer to load pre-computed, top-K Teacher logits from .array_record files, which allows us to bybass the forward pass for the teacher model during the training loop.
Tests
Tested this code change by running the following command:
python3 src/maxtext/trainers/post_train/distillation/train_distill.py src/maxtext/configs/post_train/distillation.yml steps=100 tokenizer_path="/mnt/ajkv/disks/codebase/maxtext/src/maxtext/assets/tokenizers/tokenizer_llama3.tiktoken" --offline_distillation --offline_data_dir="/mnt/ajkv/disks/teacher_logits_output/teacher_top_k_global.array_record"Truncated output showing the successful run: https://paste.googleplex.com/6342987127848960#l=8.
Verified that the training happened sucessfully and finished the distillation run.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.