From 7ebfdca6360bf7bedbcfbd365d8b4ca1e0384405 Mon Sep 17 00:00:00 2001 From: Fiyin Benstowe Date: Tue, 23 Jun 2026 18:03:03 +0000 Subject: [PATCH] feat(agent): add checkpoint validation agent framework for qwen3-4b-unscanned --- .../checkpoint_validation_agent/README.md | 17 +++++++ .../checkpoint_validation_agent/__init__.py | 4 ++ .../agent/checkpoint_validation_agent/main.py | 51 +++++++++++++++++++ .../model_registry.py | 30 +++++++++++ .../reports/.gitkeep | 0 5 files changed, 102 insertions(+) create mode 100644 src/maxtext/experimental/agent/checkpoint_validation_agent/README.md create mode 100644 src/maxtext/experimental/agent/checkpoint_validation_agent/__init__.py create mode 100644 src/maxtext/experimental/agent/checkpoint_validation_agent/main.py create mode 100644 src/maxtext/experimental/agent/checkpoint_validation_agent/model_registry.py create mode 100644 src/maxtext/experimental/agent/checkpoint_validation_agent/reports/.gitkeep diff --git a/src/maxtext/experimental/agent/checkpoint_validation_agent/README.md b/src/maxtext/experimental/agent/checkpoint_validation_agent/README.md new file mode 100644 index 0000000000..9ed4f2a293 --- /dev/null +++ b/src/maxtext/experimental/agent/checkpoint_validation_agent/README.md @@ -0,0 +1,17 @@ +# Checkpoint Validation Agent + +This agent validates MaxText checkpoints to ensure they are compatible with the inference engine. + +## How to run +Set the checkpoint directory: +`export MAXTEXT_CHECKPOINT_DIR="/path/to/your/maxtext/folder"` + +Run the validation: +`python3 -m src.maxtext.experimental.agent.checkpoint_validation_agent.main <model_name>` + +Example: +`python3 -m src.maxtext.experimental.agent.checkpoint_validation_agent.main qwen3-4b-unscanned` + +## Adding a new model +1. Open `model_registry.py`. +2. Add the model metadata to the `MODEL_REGISTRY` dictionary. 
\ No newline at end of file
diff --git a/src/maxtext/experimental/agent/checkpoint_validation_agent/__init__.py b/src/maxtext/experimental/agent/checkpoint_validation_agent/__init__.py
new file mode 100644
index 0000000000..fcdbd006ed
--- /dev/null
+++ b/src/maxtext/experimental/agent/checkpoint_validation_agent/__init__.py
@@ -0,0 +1,4 @@
+"""
+Checkpoint Validation Agent Package.
+Used to verify and report the status of converted model checkpoints.
+"""
\ No newline at end of file
diff --git a/src/maxtext/experimental/agent/checkpoint_validation_agent/main.py b/src/maxtext/experimental/agent/checkpoint_validation_agent/main.py
new file mode 100644
index 0000000000..ce2d72d47f
--- /dev/null
+++ b/src/maxtext/experimental/agent/checkpoint_validation_agent/main.py
@@ -0,0 +1,82 @@
+"""Smoke-test a registered MaxText checkpoint by running decode.py on it."""
+
+import argparse
+import json
+import os
+import subprocess
+import sys
+
+from src.maxtext.experimental.agent.checkpoint_validation_agent.model_registry import get_model_config
+
+
+def validate_checkpoint(model_name, override_length=None):
+    """Run decode.py on a registered checkpoint and write a JSON report.
+
+    Args:
+        model_name: Key of the model in the registry.
+        override_length: Optional max_target_length that replaces the
+            registry default; None means "use the default".
+
+    Returns:
+        True if decode.py exited with status 0, otherwise False.
+
+    Raises:
+        ValueError: If model_name is not in the registry.
+        EnvironmentError: If MAXTEXT_CHECKPOINT_DIR is not set.
+    """
+    config = get_model_config(model_name)
+    # Test against None so an explicit override is never silently swapped for
+    # the registry default just because it happens to be falsy.
+    if override_length is None:
+        target_length = config["max_target_length"]
+    else:
+        target_length = override_length
+    print(f"Validating {model_name} with parameters at: {config['load_parameters_path']}")
+
+    # Smoke test: run decode.py to check that the model can load and
+    # initialize its layers. sys.executable keeps the child process on the
+    # interpreter (and virtualenv) that is running this script.
+    command = [
+        sys.executable or "python3",
+        "src/maxtext/inference/decode.py",
+        "src/maxtext/configs/base.yml",
+        f"load_parameters_path={config['load_parameters_path']}",
+        f"model_name={config['maxtext_model_name']}",
+        f"tokenizer_path={config['tokenizer_path']}",
+        f"scan_layers={config['scan_layers']}",
+        f"max_target_length={target_length}",
+    ]
+
+    # Capture the output instead of raising, so a failed run still gets a report.
+    result = subprocess.run(command, capture_output=True, text=True, check=False)
+    succeeded = result.returncode == 0
+
+    report = {
+        "model": model_name,
+        "success": succeeded,
+        # Keep the error output only when the run failed.
+        "stderr": "Success" if succeeded else result.stderr,
+    }
+
+    report_dir = os.path.join(os.path.dirname(__file__), "reports")
+    os.makedirs(report_dir, exist_ok=True)
+    output_path = os.path.join(report_dir, f"report_{model_name}.json")
+    with open(output_path, "w", encoding="utf-8") as f:
+        json.dump(report, f, indent=4)
+    print(f"Report saved to {output_path}")
+    return succeeded
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Validate MaxText Checkpoints")
+    parser.add_argument("model_name", help="Model key in registry")
+    parser.add_argument("--max_target_length", type=int, help="Override target length")
+    args = parser.parse_args()
+
+    try:
+        passed = validate_checkpoint(args.model_name, args.max_target_length)
+    except Exception as e:
+        print(f"FAILED: {e}")
+        sys.exit(1)
+    # A non-zero exit status lets shells and CI detect a failed validation.
+    sys.exit(0 if passed else 1)
diff --git a/src/maxtext/experimental/agent/checkpoint_validation_agent/model_registry.py b/src/maxtext/experimental/agent/checkpoint_validation_agent/model_registry.py
new file mode 100644
index 0000000000..ff42c9f281
--- /dev/null
+++ b/src/maxtext/experimental/agent/checkpoint_validation_agent/model_registry.py
@@ -0,0 +1,46 @@
+"""Registry of the checkpoints the validation agent can test."""
+
+import os
+
+# Base path for checkpoints; it must be provided by the environment. It is
+# read here but checked in get_model_config(), so importing this module
+# (for example to show --help) does not fail when the variable is unset.
+BASE_PATH = os.getenv("MAXTEXT_CHECKPOINT_DIR")
+
+# Metadata for the supported models.
+MODEL_REGISTRY = {
+    "qwen3-4b-unscanned": {
+        "maxtext_model_name": "qwen3-4b",
+        "checkpoint_dir": "MaxText-Qwen3-4B-Unscanned",
+        "tokenizer_path": "Qwen/Qwen3-4B",
+        "max_target_length": 2048,
+        "scan_layers": False,
+    },
+    # More will be added as progress is made.
+}
+
+
+def get_model_config(model_name):
+    """Return a copy of a model's registry entry plus its checkpoint path.
+
+    Args:
+        model_name: Key of the model in MODEL_REGISTRY.
+
+    Returns:
+        A new dict with the registry fields and a "load_parameters_path"
+        entry built from BASE_PATH and the entry's "checkpoint_dir".
+
+    Raises:
+        EnvironmentError: If MAXTEXT_CHECKPOINT_DIR is not set.
+        ValueError: If model_name is not registered.
+    """
+    if not BASE_PATH:
+        raise EnvironmentError("MAXTEXT_CHECKPOINT_DIR not set.")
+    if model_name not in MODEL_REGISTRY:
+        raise ValueError(f"Model '{model_name}' not registered.")
+
+    # Copy the entry so callers cannot modify the registry by accident.
+    configuration = MODEL_REGISTRY[model_name].copy()
+    # Path to the Orbax weights.
+    configuration["load_parameters_path"] = os.path.join(BASE_PATH, configuration["checkpoint_dir"], "0/items")
+    return configuration
diff --git a/src/maxtext/experimental/agent/checkpoint_validation_agent/reports/.gitkeep b/src/maxtext/experimental/agent/checkpoint_validation_agent/reports/.gitkeep
new file mode 100644
index 0000000000..e69de29bb2