Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Checkpoint Validation Agent

This agent validates MaxText checkpoints to ensure they are compatible with the inference engine.

## How to run
Set the checkpoint directory:
`export MAXTEXT_CHECKPOINT_DIR="/path/to/your/maxtext/folder"`

Run the validation:
`python3 -m src.maxtext.experimental.agent.checkpoint_validation_agent.main <model_key>`

Example:
`python3 -m src.maxtext.experimental.agent.checkpoint_validation_agent.main qwen3-4b-unscanned`

## Adding a new model
1. Open `model_registry.py`.
2. Add the model metadata to the `MODEL_REGISTRY` dictionary.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""
Checkpoint Validation Agent Package.
Used to verify and report the status of converted model checkpoints.
"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import subprocess
import json
import os
import sys
import argparse
from src.maxtext.experimental.agent.checkpoint_validation_agent.model_registry import get_model_config

def validate_checkpoint(model_name, override_length=None):
# Fetch configuration dynamically
config = get_model_config(model_name)
target_length = override_length or config['max_target_length'] #override if provided, else use registry default
print(f"Validating {model_name} with parameters at: {config['load_parameters_path']}")

#run a smoke test using decode.py to check if model can load and initialize layers
command = [
"python3", "src/maxtext/inference/decode.py", "src/maxtext/configs/base.yml",
f"load_parameters_path={config['load_parameters_path']}",
f"model_name={config['maxtext_model_name']}",
f"tokenizer_path={config['tokenizer_path']}",
f"scan_layers={config['scan_layers']}",
f"max_target_length={target_length}"
]

#capture terminal printouts
result = subprocess.run(command, capture_output=True, text=True)

#create the report
report = {
"model": model_name,
"success": result.returncode == 0, #if returncode is 0, command worked
"stderr": result.stderr if result.returncode != 0 else "Success" #store error message if there's a failure
}

report_dir = os.path.join(os.path.dirname(__file__), "reports")
os.makedirs(report_dir, exist_ok=True)
output_path = os.path.join(report_dir, f"report_{model_name}.json")
with open(output_path, "w") as f:
json.dump(report, f, indent=4)
print(f"Report saved to {output_path}")

#script runs only if called directly
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Validate MaxText Checkpoints")
parser.add_argument("model_name", help="Model key in registry")
parser.add_argument("--max_target_length", type=int, help="Override target length")
args = parser.parse_args()

try:
validate_checkpoint(args.model_name, args.max_target_length)
except Exception as e:
print(f"FAILED: {e}")
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import os

# Base path for checkpoints. should be strictly provided by the environment
BASE_PATH = os.getenv("MAXTEXT_CHECKPOINT_DIR")

#stop program if user didn't set the variable
if not BASE_PATH:
raise EnvironmentError("MAXTEXT_CHECKPOINT_DIR not set.")

#store metadata for supported models
MODEL_REGISTRY = {
"qwen3-4b-unscanned": {
"maxtext_model_name": "qwen3-4b",
"checkpoint_dir": "MaxText-Qwen3-4B-Unscanned",
"tokenizer_path": "Qwen/Qwen3-4B",
"max_target_length": 2048,
"scan_layers": False
},
#more will be added as progress is made
}

#used by main.py to get info
def get_model_config(model_name):
if model_name not in MODEL_REGISTRY:
raise ValueError(f"Model '{model_name}' not registered.")

#copy of registry to prevent accidental edits
configuration = MODEL_REGISTRY[model_name].copy()
configuration["load_parameters_path"] = os.path.join(BASE_PATH, configuration["checkpoint_dir"], "0/items") #path to orbax weights
return configuration
Loading