diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 8a5b2f4805a..89c5cd9b0df 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -709,6 +709,33 @@ jobs: # Test test_arm_baremetal.sh with test backends/arm/test/test_arm_baremetal.sh "${ARM_TEST}" + test-arm-backend-public-api-backward-compatibility: + name: test-arm-backend-public-api-backward-compatibility + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + permissions: + id-token: write + contents: read + with: + runner: linux.2xlarge.memory + docker-image: ci-image:executorch-ubuntu-22.04-arm-sdk + submodules: 'recursive' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 120 + script: | + # The generic Linux job chooses to use base env, not the one setup by the image + CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") + conda activate "${CONDA_ENV}" + + source .ci/scripts/utils.sh + install_executorch "--use-pt-pinned-commit" + + .ci/scripts/setup-arm-baremetal-tools.sh --enable-mlsdk-deps --install-mlsdk-deps-with-pip + source examples/arm/arm-scratch/setup_path.sh + + backends/arm/scripts/public_api_manifest/validate_all_public_api_manifests.sh + + python backends/arm/test/public_api_bc/run_public_api_bc_scenarios.py + test-llama-runner-qnn-linux: name: test-llama-runner-qnn-linux uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main diff --git a/backends/arm/public_api_manifests/api_manifest_running.toml b/backends/arm/public_api_manifests/api_manifest_running.toml index 44de795799e..cd1deddfee7 100644 --- a/backends/arm/public_api_manifests/api_manifest_running.toml +++ b/backends/arm/public_api_manifests/api_manifest_running.toml @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. # # This file is generated by -# backends/arm/scripts/generate_public_api_manifest.py +# backends/arm/scripts/public_api_manifest/generate_public_api_manifest.py [python] diff --git a/backends/arm/scripts/pre-push b/backends/arm/scripts/pre-push index 8e26463cd94..acfa5f81c7a 100755 --- a/backends/arm/scripts/pre-push +++ b/backends/arm/scripts/pre-push @@ -70,6 +70,13 @@ run_docgen_check() { fi } +run_public_api_validator() { + if ! backends/arm/scripts/public_api_manifest/validate_all_public_api_manifests.sh; then + echo -e "${ERROR} Arm public API manifest validation failed" + FAILED=1 + fi +} + # This list of imperative verbs was compiled from the entire list of Executorch # commits. It should be fairly exhaustive, but add more verbs if you find one # that's missing. @@ -149,7 +156,6 @@ for COMMIT in ${COMMITS}; do fi done fi - # Check license headers # We do a simple check of if all committed headers contain # "$current_year Arm". This does not guarantee OK in ci but should be ok @@ -177,7 +183,7 @@ for COMMIT in ${COMMITS}; do for committed_file in "${license_files[@]}"; do # Skip files with certain extensions case "$committed_file" in - *.md|*.md.in|*.json|*.yml|*.yaml|*.cmake|*.patch|.gitignore|*.bzl) + *.md|*.md.in|*.json|*.yml|*.yaml|*.cmake|*.patch|*.bzl|.gitignore) echo -e "${INFO} Skipping license check for ${committed_file} (excluded extension)" continue ;; @@ -311,6 +317,8 @@ else echo -e "${INFO} Skipping Arm docgen (no public API inputs changed)" fi +run_public_api_validator + if [[ $FAILED ]]; then echo -e "${INFO} Fix your commit message errors with"\ "'git commit --amend' or 'git commit --fixup='" diff --git a/backends/arm/scripts/public_api_manifest/README.md b/backends/arm/scripts/public_api_manifest/README.md new file mode 100644 index 00000000000..c8fbef402d5 --- /dev/null +++ b/backends/arm/scripts/public_api_manifest/README.md @@ -0,0 +1,105 @@ +# Manifests + +Manifests are used to track the current public API of the Arm backend. They are +generated with +`python backends/arm/scripts/public_api_manifest/generate_public_api_manifest.py`. + +## Running manifest + +There is always one running manifest which has the main purpose of tracking the +API surface inbetween releases. + +## Static manifests + +At any given time there may be up to two static manifests. These are generated +in conjunction with a release and are used to track the API surface of that +release. The main purpose of these is to make sure backwards compatibility. + +A static manifest may never be changed. It belongs to a release and must be kept +as is. + +A static manifest should not live longer than 2 releases. It may then be +removed. + +# On release + +With each release, check that the running manifest is up to date and reflects +the API surface of the release. Then, copy the running manifest to a new static +manifest for the release. This can be done by running +`cp `. The new static manifest should be +named according to the release, e.g. `api_manifest_1_3.toml` for release 1.3 and +so on. If there are now more than two static manifests, remove the oldest one in +the same commit. + +# API changes + +When introducing an API change, the running manifest must be updated to reflect +the change. This is done by running the manifest generation script, +`python backends/arm/scripts/public_api_manifest/generate_public_api_manifest.py`. +This updates the running manifest. + +To validate the running manifest directly, run +`python backends/arm/scripts/public_api_manifest/validate_public_api_manifest.py`. + +To validate all manifests, use `backends/arm/scripts/pre-push`. This is the +check that must pass before the change is ready to merge. + +Manifest validation only checks the API surface and signatures. Workflow-level +backward compatibility is covered separately by the scenario runner described +below. + +Running-manifest validation uses exact signature matching. Any intentional API +change must update `api_manifest_running.toml`. + +Static-manifest validation uses backward-compatibility matching. The old +release signature must still be callable against the current API. For example, +adding a trailing optional parameter is accepted for static manifests, while +removing a parameter, reordering parameters, or adding a new required +parameter still fails validation. + +## Backward-compatibility scenarios + +Workflow-level backward compatibility is checked by +`python backends/arm/test/public_api_bc/run_public_api_bc_scenarios.py`. + +The runner hardcodes the current canonical public API workflow scripts: + +- `backends/arm/test/public_api_bc/ethosu_flow.py` +- `backends/arm/test/public_api_bc/vgf_fp_flow.py` +- `backends/arm/test/public_api_bc/vgf_int_flow.py` + +These scripts should be updated continuously to reflect the current public API. +The runner materializes those same paths into a temporary harness and executes +them there with pytest so they import the latest installed +`executorch.backends.arm` package instead of the repository source tree. + +The rolling support window is controlled by the `OLDEST_SUPPORTED_REF` constant +in `backends/arm/test/public_api_bc/run_public_api_bc_scenarios.py`: + +- If `OLDEST_SUPPORTED_REF` is empty, the runner uses the current workspace. + This is the bootstrap mode until a release contains the scenario scripts. +- Once a release contains the scripts, the release epic should update + `OLDEST_SUPPORTED_REF` to the oldest still-supported release ref. +- At that point the runner uses `git show :` to fetch the old + release's scripts and run them against the latest code. + +When an old release falls out of the support window, update +`OLDEST_SUPPORTED_REF` to the next newer supported release. That is how the +backward-compatibility window rolls forward. +Reasons for passing validation may include: +- Adding a new API symbol and adding it to the running manifest. +- Removing an API that was marked as deprecated and no longer exists in any + manifest. +- Deprecated symbols do not break backward compatibility with static + manifests. +- Deprecating a symbol removes it from the running manifest, but it can only be + removed fully once it no longer appears in any static manifest. +- Extending a static-manifest signature in a backward-compatible way, such as + adding a trailing optional parameter. + +Reasons for failing validation may include: +- Removing an API symbol without deprecation. +- Changing a running-manifest signature without regenerating the running + manifest. +- Changing a static-manifest signature in a non-backward-compatible way. +- New API symbol added but not added to the running manifest. diff --git a/backends/arm/scripts/public_api_manifest/__init__.py b/backends/arm/scripts/public_api_manifest/__init__.py new file mode 100644 index 00000000000..19ebb35e5f2 --- /dev/null +++ b/backends/arm/scripts/public_api_manifest/__init__.py @@ -0,0 +1,4 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/arm/scripts/generate_public_api_manifest.py b/backends/arm/scripts/public_api_manifest/generate_public_api_manifest.py similarity index 81% rename from backends/arm/scripts/generate_public_api_manifest.py rename to backends/arm/scripts/public_api_manifest/generate_public_api_manifest.py index bc57bdd9e7f..60dcfed7c37 100644 --- a/backends/arm/scripts/generate_public_api_manifest.py +++ b/backends/arm/scripts/public_api_manifest/generate_public_api_manifest.py @@ -23,10 +23,10 @@ # LICENSE file in the root directory of this source tree. # # This file is generated by -# backends/arm/scripts/generate_public_api_manifest.py +# backends/arm/scripts/public_api_manifest/generate_public_api_manifest.py """ MANIFEST_PATH = ( - Path(__file__).resolve().parents[1] + Path(__file__).resolve().parents[2] / "public_api_manifests" / "api_manifest_running.toml" ) @@ -81,8 +81,10 @@ def _collect_entry( path: str, obj: object, entries: dict[str, dict[str, str]], + *, + include_deprecated: bool = False, ) -> None: - if _is_unstable_api(obj): + if _is_unstable_api(obj) and not include_deprecated: return entries[path] = {"kind": _api_kind(obj), "signature": _api_signature(path, obj)} if not inspect.isclass(obj): @@ -96,13 +98,25 @@ def _collect_entry( continue member = getattr(obj, name) if inspect.isclass(member) or callable(member): - _collect_entry(f"{path}.{name}", member, entries) + _collect_entry( + f"{path}.{name}", + member, + entries, + include_deprecated=include_deprecated, + ) -def _collect_public_api() -> dict[str, dict[str, str]]: +def _collect_public_api( + *, include_deprecated: bool = False +) -> dict[str, dict[str, str]]: entries: dict[str, dict[str, str]] = {} for name in sorted(LAZY_IMPORTS): - _collect_entry(name, getattr(arm, name), entries) + _collect_entry( + name, + getattr(arm, name), + entries, + include_deprecated=include_deprecated, + ) return entries @@ -124,9 +138,10 @@ def _render_manifest(entries: dict[str, dict[str, str]]) -> str: def generate_manifest_from_init( *, repo_path: Path | None = None, + include_deprecated: bool = False, ) -> str: del repo_path - return _render_manifest(_collect_public_api()) + return _render_manifest(_collect_public_api(include_deprecated=include_deprecated)) def main() -> None: diff --git a/backends/arm/scripts/public_api_manifest/validate_all_public_api_manifests.sh b/backends/arm/scripts/public_api_manifest/validate_all_public_api_manifests.sh new file mode 100755 index 00000000000..a921bdf9679 --- /dev/null +++ b/backends/arm/scripts/public_api_manifest/validate_all_public_api_manifests.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -u + +echo "Validating Arm public API manifests" + +manifest_failures=0 +for manifest_path in backends/arm/public_api_manifests/api_manifest_*.toml; do + if [[ ! -f "${manifest_path}" ]]; then + continue + fi + manifest_name="${manifest_path##*/}" + echo + echo "=== ${manifest_name} ===" + validator_output=$( + python backends/arm/scripts/public_api_manifest/validate_public_api_manifest.py \ + --manifest "${manifest_path}" 2>&1 + ) + validator_status=$? + printf '%s\n' "${validator_output}" + if [[ ${validator_status} -ne 0 ]]; then + manifest_failures=$((manifest_failures + 1)) + fi +done + +echo +if [[ ${manifest_failures} -eq 0 ]]; then + echo "Arm public API manifests OK" +else + echo "${manifest_failures} manifest(s) failed validation" + exit 1 +fi diff --git a/backends/arm/scripts/public_api_manifest/validate_public_api_manifest.py b/backends/arm/scripts/public_api_manifest/validate_public_api_manifest.py new file mode 100644 index 00000000000..e48538809e3 --- /dev/null +++ b/backends/arm/scripts/public_api_manifest/validate_public_api_manifest.py @@ -0,0 +1,358 @@ +#!/usr/bin/env python3 + +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +"""Validate one Arm public API manifest against the current API.""" + +from __future__ import annotations + +import argparse +import ast +import importlib.util +import inspect +from pathlib import Path + +try: + import tomllib +except ModuleNotFoundError: + import tomli as tomllib # type: ignore[import-not-found,no-redef] + +try: + import executorch.backends.arm.scripts.public_api_manifest.generate_public_api_manifest as gpam +except ModuleNotFoundError: + generator_path = Path(__file__).resolve().parent / "generate_public_api_manifest.py" + spec = importlib.util.spec_from_file_location( + "generate_public_api_manifest", + generator_path, + ) + if spec is None or spec.loader is None: + raise RuntimeError(f"Unable to load generator script at {generator_path}") + gpam = importlib.util.module_from_spec(spec) + spec.loader.exec_module(gpam) + +REPO_PATH = Path(__file__).resolve().parents[4] +MANIFEST_PATH = ( + REPO_PATH + / "backends" + / "arm" + / "public_api_manifests" + / "api_manifest_running.toml" +) +Issue = tuple[str, str, str | None, str | None] +NEW_API_SYMBOL_REASON = ( + "entry is present in the current API but missing from the manifest" +) +INCOMPATIBLE_SIGNATURE_REASON = "signature is not backward compatible" + + +ParameterDescriptor = tuple[str, str, str | None] +ParsedSignature = tuple[list[ParameterDescriptor], str | None] + + +def read_manifest(manifest_path: Path) -> dict: + with open(manifest_path, "rb") as manifest_file: + return tomllib.load(manifest_file) + + +def _collect_python_symbols( + table: dict[str, object], + *, + prefix: str = "", +) -> dict[str, dict[str, str]]: + symbols: dict[str, dict[str, str]] = {} + kind = table.get("kind") + signature = table.get("signature") + if kind is not None or signature is not None: + if not isinstance(kind, str) or not isinstance(signature, str): + raise ValueError( + f"Entry [python.{prefix}] must define `kind` and `signature`" + ) + symbols[prefix] = {"kind": kind, "signature": signature} + + for name, entry in table.items(): + if name in {"kind", "signature"}: + continue + if not isinstance(entry, dict): + raise ValueError(f"Entry [python.{prefix}] contains invalid child {name}") + child_prefix = f"{prefix}.{name}" if prefix else name + symbols.update(_collect_python_symbols(entry, prefix=child_prefix)) + return symbols + + +def get_manifest_python_symbols(manifest: dict) -> dict[str, dict[str, str]]: + python_manifest = manifest.get("python") + if not isinstance(python_manifest, dict): + raise ValueError("Manifest is missing [python] section") + return _collect_python_symbols(python_manifest) + + +def get_current_python_symbols( + *, + include_deprecated: bool = False, +) -> dict[str, dict[str, str]]: + generated_manifest = tomllib.loads( + gpam.generate_manifest_from_init( + repo_path=REPO_PATH, + include_deprecated=include_deprecated, + ) + ) + return get_manifest_python_symbols(generated_manifest) + + +def _parse_signature(signature: str) -> ParsedSignature: + suffix_start = signature.find("(") + if suffix_start == -1: + raise ValueError(f"Malformed signature: {signature}") + signature_suffix = signature[suffix_start:] + function_definition = ast.parse( + f"def _manifest_stub{signature_suffix}:\n pass\n" + ) + function_node = function_definition.body[0] + if not isinstance(function_node, ast.FunctionDef): + raise ValueError(f"Unable to parse signature: {signature}") + + parameters: list[ParameterDescriptor] = [] + positional_args = list(function_node.args.posonlyargs) + list( + function_node.args.args + ) + positional_defaults = [None] * ( + len(positional_args) - len(function_node.args.defaults) + ) + list(function_node.args.defaults) + + for argument, default in zip( + function_node.args.posonlyargs, + positional_defaults[: len(function_node.args.posonlyargs)], + ): + parameters.append( + ( + argument.arg, + inspect.Parameter.POSITIONAL_ONLY.name, + None if default is None else ast.unparse(default), + ) + ) + + for argument, default in zip( + function_node.args.args, + positional_defaults[len(function_node.args.posonlyargs) :], + ): + parameters.append( + ( + argument.arg, + inspect.Parameter.POSITIONAL_OR_KEYWORD.name, + None if default is None else ast.unparse(default), + ) + ) + + if function_node.args.vararg is not None: + parameters.append( + ( + function_node.args.vararg.arg, + inspect.Parameter.VAR_POSITIONAL.name, + None, + ) + ) + + for argument, default in zip( + function_node.args.kwonlyargs, + function_node.args.kw_defaults, + ): + parameters.append( + ( + argument.arg, + inspect.Parameter.KEYWORD_ONLY.name, + None if default is None else ast.unparse(default), + ) + ) + + if function_node.args.kwarg is not None: + parameters.append( + ( + function_node.args.kwarg.arg, + inspect.Parameter.VAR_KEYWORD.name, + None, + ) + ) + + return_annotation = ( + None if function_node.returns is None else ast.unparse(function_node.returns) + ) + return parameters, return_annotation + + +def is_signature_backward_compatible( + manifest_signature: str, + current_signature: str, +) -> bool: + try: + manifest_parameters, manifest_return = _parse_signature(manifest_signature) + current_parameters, current_return = _parse_signature(current_signature) + except (SyntaxError, ValueError): + return False + + if manifest_return != current_return: + return False + + if len(current_parameters) < len(manifest_parameters): + return False + + for expected, actual in zip(manifest_parameters, current_parameters): + if actual != expected: + return False + + for _, kind, default in current_parameters[len(manifest_parameters) :]: + if ( + kind + not in ( + inspect.Parameter.VAR_POSITIONAL.name, + inspect.Parameter.VAR_KEYWORD.name, + ) + and default is None + ): + return False + + return True + + +def validate_symbols( + manifest_symbols: dict[str, dict[str, str]], + current_symbols: dict[str, dict[str, str]], + *, + ignore_new_api_symbols: bool = False, + allow_backward_compatible_signature_changes: bool = False, +) -> list[Issue]: + issues: list[Issue] = [] + manifest_keys = set(manifest_symbols) + current_keys = set(current_symbols) + + for name in sorted(manifest_keys - current_keys): + issues.append( + ( + name, + "entry is present in the manifest but missing from the current API", + manifest_symbols[name]["signature"], + None, + ) + ) + + if not ignore_new_api_symbols: + for name in sorted(current_keys - manifest_keys): + issues.append( + ( + name, + NEW_API_SYMBOL_REASON, + None, + current_symbols[name]["signature"], + ) + ) + + for name in sorted(manifest_keys & current_keys): + expected = manifest_symbols[name] + actual = current_symbols[name] + if actual["kind"] != expected["kind"]: + issues.append( + ( + name, + f"kind changed from '{expected['kind']}' to '{actual['kind']}'", + expected["signature"], + actual["signature"], + ) + ) + elif actual["signature"] != expected["signature"] and ( + not allow_backward_compatible_signature_changes + or not is_signature_backward_compatible( + expected["signature"], + actual["signature"], + ) + ): + issues.append( + ( + name, + ( + INCOMPATIBLE_SIGNATURE_REASON + if allow_backward_compatible_signature_changes + else "signature changed" + ), + expected["signature"], + actual["signature"], + ) + ) + return issues + + +def format_manifest_guidance(manifest_path: Path) -> str: + if manifest_path.name == "api_manifest_running.toml": + return ( + f"If this change is intentional, regenerate {manifest_path.name} and amend " + "it into your change." + ) + return ( + "If this change is intentional, deprecate the old symbol instead of " + "changing or removing it directly." + ) + + +def format_validation_report(manifest_path: Path, issues: list[Issue]) -> str: + if not issues: + return f"{manifest_path.name}: public API is up to date." + + lines = [f"{manifest_path.name}: public API validation failed."] + for name, reason, expected, actual in issues: + lines.append(f"- {name}: {reason}") + if expected is not None: + lines.append(f" manifest: {expected}") + if actual is not None: + lines.append(f" current: {actual}") + if manifest_path.name == MANIFEST_PATH.name and any( + reason == NEW_API_SYMBOL_REASON for _, reason, _, _ in issues + ): + lines.append( + "If you intentionally added a new API symbol, update the running " + "manifest with:" + ) + lines.append("") + lines.append( + "python backends/arm/scripts/public_api_manifest/generate_public_api_manifest.py" + ) + lines.append("") + lines.append("and amend the manifest into your change.") + else: + lines.append(format_manifest_guidance(manifest_path)) + return "\n".join(lines) + + +def validate_manifest(manifest_path: Path) -> list[Issue]: + return validate_symbols( + get_manifest_python_symbols(read_manifest(manifest_path)), + get_current_python_symbols( + include_deprecated=manifest_path.name != MANIFEST_PATH.name, + ), + ignore_new_api_symbols=manifest_path.name != MANIFEST_PATH.name, + allow_backward_compatible_signature_changes=( + manifest_path.name != MANIFEST_PATH.name + ), + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "--manifest", + type=Path, + default=MANIFEST_PATH, + help="Path to the public API manifest TOML file.", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + issues = validate_manifest(args.manifest) + print(format_validation_report(args.manifest, issues)) + if issues: + raise SystemExit(1) + + +if __name__ == "__main__": + main() diff --git a/backends/arm/test/misc/test_public_api_manifest.py b/backends/arm/test/misc/test_public_api_manifest.py index e891584364b..41584eb2a86 100644 --- a/backends/arm/test/misc/test_public_api_manifest.py +++ b/backends/arm/test/misc/test_public_api_manifest.py @@ -5,8 +5,9 @@ from pathlib import Path +import executorch.backends.arm as arm from executorch.backends.arm import LAZY_IMPORTS -from executorch.backends.arm.scripts.generate_public_api_manifest import ( +from executorch.backends.arm.scripts.public_api_manifest.generate_public_api_manifest import ( _collect_entry, _collect_public_api, _render_manifest, @@ -32,9 +33,14 @@ def _entry_block(path: str, entry: dict[str, str]) -> str: def test_public_api_manifest_entries_are_well_formed(): entries = _collect_public_api() + expected_roots = { + name + for name in LAZY_IMPORTS + if getattr(getattr(arm, name), "__deprecated__", None) is None + } assert entries - assert {path.split(".")[0] for path in entries} == set(LAZY_IMPORTS) + assert {path.split(".")[0] for path in entries} == expected_roots for path, entry in entries.items(): assert entry["kind"] in {"class", "enum", "function"} @@ -73,6 +79,20 @@ def old_foo(x: int) -> int: assert "old_foo" not in entries +def test_public_api_manifest_collection_can_include_deprecated_symbols(): + @deprecated("old foo") + def old_foo(x: int) -> int: + return x + + old_foo.__module__ = "executorch.backends.arm.synthetic" + entries: dict[str, dict[str, str]] = {} + + _collect_entry("old_foo", old_foo, entries, include_deprecated=True) + + assert entries["old_foo"]["kind"] == "function" + assert entries["old_foo"]["signature"] == "old_foo(x: int) -> int" + + def test_public_api_manifest_collection_excludes_init_for_equivalent_classes(): class ExplicitInit: def __init__(self, x: int = 0) -> None: diff --git a/backends/arm/test/misc/test_validate_public_api_manifest.py b/backends/arm/test/misc/test_validate_public_api_manifest.py new file mode 100644 index 00000000000..f1eb40c6bd4 --- /dev/null +++ b/backends/arm/test/misc/test_validate_public_api_manifest.py @@ -0,0 +1,262 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from pathlib import Path + +import executorch.backends.arm.scripts.public_api_manifest.validate_public_api_manifest as vpam + +try: + import tomllib +except ModuleNotFoundError: + import tomli as tomllib # type: ignore[import-not-found,no-redef] + +from executorch.backends.arm.scripts.public_api_manifest.validate_public_api_manifest import ( + format_validation_report, + get_current_python_symbols, + get_manifest_python_symbols, + validate_symbols, +) + +RUNNING_MANIFEST_PATH = ( + Path(__file__).resolve().parents[2] + / "public_api_manifests" + / "api_manifest_running.toml" +) +MOCK_STATIC_MANIFEST_PATH = Path("mock_api_manifest_static_VERSION.toml") + + +def test_public_api_manifest_exact_comparison_rejects_signature_expansion(): + manifest_entries = {"foo": {"kind": "function", "signature": "foo(x: int) -> int"}} + current_entries = { + "foo": { + "kind": "function", + "signature": "foo(x: int, y: int | None = None) -> int", + } + } + + issues = validate_symbols(manifest_entries, current_entries) + + assert len(issues) == 1 + assert issues[0][0] == "foo" + assert issues[0][1] == "signature changed" + + +def test_get_manifest_python_symbols_flattens_nested_tables(): + manifest = tomllib.loads( + """ + [python] + + [python.Foo] + kind = "class" + signature = "Foo()" + + [python.Foo.bar] + kind = "function" + signature = "Foo.bar() -> None" + """ + ) + + assert get_manifest_python_symbols(manifest) == { + "Foo": {"kind": "class", "signature": "Foo()"}, + "Foo.bar": {"kind": "function", "signature": "Foo.bar() -> None"}, + } + + +def test_nested_python_manifest_entries_are_validated(): + manifest_symbols = get_manifest_python_symbols( + tomllib.loads( + """ + [python] + + [python.Foo] + kind = "class" + signature = "Foo()" + + [python.Foo.bar] + kind = "function" + signature = "Foo.bar(x: int) -> int" + """ + ) + ) + + issues = validate_symbols( + manifest_symbols, + { + "Foo": {"kind": "class", "signature": "Foo()"}, + }, + ) + + assert issues == [ + ( + "Foo.bar", + "entry is present in the manifest but missing from the current API", + "Foo.bar(x: int) -> int", + None, + ) + ] + + +def test_public_api_manifest_static_accepts_backward_compatible_signature_expansion(): + manifest_entries = { + "foo": {"kind": "function", "signature": "foo(x: int, y: int = 0) -> int"} + } + current_entries = { + "foo": { + "kind": "function", + "signature": "foo(x: int, y: int = 0, z: int | None = None) -> int", + } + } + + issues = validate_symbols( + manifest_entries, + current_entries, + ignore_new_api_symbols=True, + allow_backward_compatible_signature_changes=True, + ) + + assert issues == [] + + +def test_public_api_manifest_static_rejects_new_required_parameter(): + manifest_entries = {"foo": {"kind": "function", "signature": "foo(x: int) -> int"}} + current_entries = { + "foo": { + "kind": "function", + "signature": "foo(x: int, y: int) -> int", + } + } + + issues = validate_symbols( + manifest_entries, + current_entries, + ignore_new_api_symbols=True, + allow_backward_compatible_signature_changes=True, + ) + + assert len(issues) == 1 + assert issues[0][0] == "foo" + assert issues[0][1] == vpam.INCOMPATIBLE_SIGNATURE_REASON + + +def test_public_api_manifest_exact_comparison_rejects_additions(): + manifest_entries = {"foo": {"kind": "function", "signature": "foo(x: int) -> int"}} + current_entries = { + "bar": {"kind": "function", "signature": "bar() -> int"}, + "foo": {"kind": "function", "signature": "foo(x: int) -> int"}, + } + + issues = validate_symbols(manifest_entries, current_entries) + + assert len(issues) == 1 + assert issues[0][0] == "bar" + assert "missing from the manifest" in issues[0][1] + + +def test_public_api_manifest_running_regeneration_reports_drift(): + manifest_entries = {"foo": {"kind": "function", "signature": "foo(x: int) -> int"}} + current_entries = { + "bar": {"kind": "function", "signature": "bar() -> int"}, + "foo": { + "kind": "function", + "signature": "foo(x: int, y: int | None = None) -> int", + }, + } + + issues = validate_symbols(manifest_entries, current_entries) + report = format_validation_report(RUNNING_MANIFEST_PATH, issues) + + assert len(issues) == 2 + assert {issue[0] for issue in issues} == {"foo", "bar"} + assert "public API validation failed" in report + assert "added a new API symbol" in report + assert "manifest with:" in report + assert ( + "manifest with:\n\n" + "python backends/arm/scripts/public_api_manifest/generate_public_api_manifest.py\n\n" + "and amend the manifest into your change." + ) in report + + +def test_public_api_manifest_static_deprecation_reports_drift(): + manifest_entries = {"foo": {"kind": "function", "signature": "foo(x: int) -> int"}} + current_entries = { + "bar": {"kind": "function", "signature": "bar() -> int"}, + "foo": {"kind": "function", "signature": "foo(x: int, y: int = 0) -> int"}, + } + + issues = validate_symbols( + manifest_entries, + current_entries, + ignore_new_api_symbols=True, + allow_backward_compatible_signature_changes=True, + ) + report = format_validation_report(MOCK_STATIC_MANIFEST_PATH, issues) + + assert issues == [] + assert "public API is up to date" in report + + +def test_public_api_manifest_static_reports_incompatible_signature_drift(): + manifest_entries = { + "foo": {"kind": "function", "signature": "foo(x: int, y: int = 0) -> int"} + } + current_entries = { + "foo": { + "kind": "function", + "signature": "foo(x: int, y: int, z: int | None = None) -> int", + } + } + + issues = validate_symbols( + manifest_entries, + current_entries, + ignore_new_api_symbols=True, + allow_backward_compatible_signature_changes=True, + ) + report = format_validation_report(MOCK_STATIC_MANIFEST_PATH, issues) + + assert len(issues) == 1 + assert issues[0][0] == "foo" + assert issues[0][1] == vpam.INCOMPATIBLE_SIGNATURE_REASON + assert "deprecate the old symbol" in report + + +def test_public_api_manifest_static_ignores_additions(): + manifest_entries = {"foo": {"kind": "function", "signature": "foo(x: int) -> int"}} + current_entries = { + "bar": {"kind": "function", "signature": "bar() -> int"}, + "foo": {"kind": "function", "signature": "foo(x: int) -> int"}, + } + + issues = validate_symbols( + manifest_entries, + current_entries, + ignore_new_api_symbols=True, + ) + report = format_validation_report(MOCK_STATIC_MANIFEST_PATH, issues) + + assert issues == [] + assert "public API is up to date" in report + + +def test_get_current_python_symbols_can_include_deprecated(monkeypatch): + def fake_generate_manifest_from_init( + *, + repo_path=None, + include_deprecated: bool = False, + ) -> str: + del repo_path + if include_deprecated: + return '[python]\n\n[python.foo]\nkind = "function"\nsignature = "foo()"\n' + return "[python]\n" + + monkeypatch.setattr( + vpam.gpam, "generate_manifest_from_init", fake_generate_manifest_from_init + ) + + assert get_current_python_symbols() == {} + assert get_current_python_symbols(include_deprecated=True) == { + "foo": {"kind": "function", "signature": "foo()"} + } diff --git a/backends/arm/test/public_api_bc/ethosu_flow.py b/backends/arm/test/public_api_bc/ethosu_flow.py new file mode 100644 index 00000000000..9e54807e8cf --- /dev/null +++ b/backends/arm/test/public_api_bc/ethosu_flow.py @@ -0,0 +1,129 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from pathlib import Path + +import torch +from executorch.backends.arm import ( + EthosUBackend, + EthosUCompileSpec, + EthosUPartitioner, + EthosUQuantizer, + get_symmetric_a16w8_quantization_config, + get_symmetric_quantization_config, +) +from executorch.exir import ExecutorchBackendConfig, to_edge_transform_and_lower +from executorch.extension.export_util.utils import save_pte_program +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e + + +class TinyConvRelu(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 4, kernel_size=3) + self.relu = torch.nn.ReLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.relu(self.conv(x)) + + +class TinyAdd(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + + +def _configured_compile_spec(tmp_path: Path) -> EthosUCompileSpec: + compile_spec = EthosUCompileSpec( + "ethos-u55-128", + system_config="Ethos_U55_High_End_Embedded", + memory_mode="Shared_Sram", + ) + + assert compile_spec == EthosUCompileSpec( + "ethos-u55-128", + system_config="Ethos_U55_High_End_Embedded", + memory_mode="Shared_Sram", + ) + assert "EthosUCompileSpec" in repr(compile_spec) + + compile_spec.dump_intermediate_artifacts_to(str(tmp_path / "ethosu_intermediates")) + returned = compile_spec.dump_debug_info(EthosUCompileSpec.DebugMode.TOSA) + assert returned is compile_spec + return compile_spec + + +def _exercise_quantizer_api(compile_spec: EthosUCompileSpec) -> None: + quantizer = EthosUQuantizer(compile_spec) + symmetric_config = get_symmetric_quantization_config(is_per_channel=False) + a16w8_config = get_symmetric_a16w8_quantization_config(is_per_channel=False) + + quantizer.set_global(symmetric_config) + quantizer.set_io(a16w8_config) + quantizer.set_module_name("conv", symmetric_config) + quantizer.set_module_type(torch.nn.ReLU, symmetric_config) + + example_inputs = (torch.randn(1, 3, 8, 8),) + graph_module = torch.export.export(TinyConvRelu().eval(), example_inputs).module( + check_guards=False + ) + transformed = quantizer.transform_for_annotation(graph_module) + annotated = quantizer.annotate(transformed) + quantizer.validate(annotated) + + +def _build_quantized_program(compile_spec: EthosUCompileSpec): + model = TinyAdd().eval() + example_inputs = ( + torch.ones(1, 1, 1, 1), + torch.ones(1, 1, 1, 1), + ) + exported_program = torch.export.export(model, example_inputs) + graph_module = exported_program.module(check_guards=False) + + quantizer = EthosUQuantizer(compile_spec) + quantizer.set_global(get_symmetric_quantization_config()) + + prepared = prepare_pt2e(graph_module, quantizer) + prepared(*example_inputs) + converted = convert_pt2e(prepared) + + return torch.export.export(converted, example_inputs) + + +def test_ethosu_public_api_scenario(tmp_path: Path) -> None: + backend = EthosUBackend() + assert isinstance(backend, EthosUBackend) + + compile_spec = _configured_compile_spec(tmp_path) + _exercise_quantizer_api(compile_spec) + + partitioner = EthosUPartitioner(compile_spec) + quantized_program_for_partition = _build_quantized_program(compile_spec) + ops_to_preserve, filter_fn = partitioner.ops_to_not_decompose( + quantized_program_for_partition + ) + partition_result = partitioner.partition(quantized_program_for_partition) + + assert isinstance(ops_to_preserve, list) + assert filter_fn is None or callable(filter_fn) + assert partition_result.tagged_exported_program is quantized_program_for_partition + + quantized_program = _build_quantized_program(compile_spec) + edge_manager = to_edge_transform_and_lower( + programs=quantized_program, + partitioner=[partitioner], + ) + executorch_program_manager = edge_manager.to_executorch( + config=ExecutorchBackendConfig(extract_delegate_segments=False) + ) + + pte_path = tmp_path / "ethosu_public_api_bc.pte" + save_pte_program(executorch_program_manager, str(pte_path)) + + assert pte_path.is_file() + assert pte_path.stat().st_size > 0 + assert any((tmp_path / "ethosu_intermediates").rglob("*")) diff --git a/backends/arm/test/public_api_bc/run_public_api_bc_scenarios.py b/backends/arm/test/public_api_bc/run_public_api_bc_scenarios.py new file mode 100644 index 00000000000..cc6fc05185e --- /dev/null +++ b/backends/arm/test/public_api_bc/run_public_api_bc_scenarios.py @@ -0,0 +1,104 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +"""Run Arm public API backward-compatibility scenarios.""" + +from __future__ import annotations + +import os +import shutil +import sys +import tempfile +from pathlib import Path +from subprocess import run # nosec B404 + + +REPO_ROOT = Path(__file__).resolve().parents[4] +PYTEST_CONFIG = Path("backends/arm/test/pytest.ini") +SCENARIO_FILES = ( + Path("backends/arm/test/public_api_bc/ethosu_flow.py"), + Path("backends/arm/test/public_api_bc/vgf_fp_flow.py"), + Path("backends/arm/test/public_api_bc/vgf_int_flow.py"), +) +# Leave empty until a release contains these scenario files. The release epic +# should update this to the oldest still-supported release ref. +OLDEST_SUPPORTED_REF = "" + + +def _resolve_git() -> str: + git = shutil.which("git") + if git is None: + raise RuntimeError("Could not find git in PATH") + return git + + +GIT = _resolve_git() + + +def _materialize_file(repo_relative_path: Path, output_root: Path) -> Path: + destination_path = output_root / repo_relative_path + destination_path.parent.mkdir(parents=True, exist_ok=True) + + if OLDEST_SUPPORTED_REF: + result = run( # nosec B603 + [ + GIT, + "show", + f"{OLDEST_SUPPORTED_REF}:{repo_relative_path.as_posix()}", + ], + cwd=REPO_ROOT, + check=True, + capture_output=True, + text=True, + ) + destination_path.write_text(result.stdout, encoding="utf-8") + return destination_path + + source_path = REPO_ROOT / repo_relative_path + if not source_path.is_file(): + raise FileNotFoundError(f"Missing scenario file: {source_path}") + + shutil.copy2(source_path, destination_path) + return destination_path + + +def _run_pytest(entrypoints: list[Path], output_root: Path) -> None: + env = os.environ.copy() + env.setdefault("PYTEST_DISABLE_PLUGIN_AUTOLOAD", "1") + run( # nosec B603 + [ + sys.executable, + "-m", + "pytest", + "--config-file", + str(REPO_ROOT / PYTEST_CONFIG), + *[str(path) for path in entrypoints], + ], + cwd=output_root, + env=env, + check=True, + ) + + +def main() -> None: + with tempfile.TemporaryDirectory( + prefix="arm-public-api-bc-", + ignore_cleanup_errors=True, + ) as temporary_dir: + materialized_root = Path(temporary_dir) + entrypoints = [ + _materialize_file(repo_relative_path, materialized_root) + for repo_relative_path in SCENARIO_FILES + ] + + source_name = OLDEST_SUPPORTED_REF or "workspace snapshot" + print("Materialized Arm public API BC scenarios:") + print(f" source: {source_name}") + print(f" root: {materialized_root}") + + _run_pytest(entrypoints, materialized_root) + + +if __name__ == "__main__": + main() diff --git a/backends/arm/test/public_api_bc/vgf_fp_flow.py b/backends/arm/test/public_api_bc/vgf_fp_flow.py new file mode 100644 index 00000000000..3f29fb248d3 --- /dev/null +++ b/backends/arm/test/public_api_bc/vgf_fp_flow.py @@ -0,0 +1,75 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from pathlib import Path + +import torch +from executorch.backends.arm import VgfBackend, VgfCompileSpec, VgfPartitioner +from executorch.exir import ExecutorchBackendConfig, to_edge_transform_and_lower +from executorch.extension.export_util.utils import save_pte_program + + +class TinyAddSigmoid(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return self.sigmoid(x + y) + + +def _configured_compile_spec(tmp_path: Path) -> VgfCompileSpec: + compile_spec = VgfCompileSpec("TOSA-1.0+FP") + + assert compile_spec == VgfCompileSpec("TOSA-1.0+FP") + assert "VgfCompileSpec" in repr(compile_spec) + + compile_spec.dump_intermediate_artifacts_to(str(tmp_path / "vgf_fp_intermediates")) + returned = compile_spec.dump_debug_info(VgfCompileSpec.DebugMode.TOSA) + assert returned is compile_spec + return compile_spec + + +def test_vgf_fp_public_api_scenario(tmp_path: Path) -> None: + backend = VgfBackend() + assert isinstance(backend, VgfBackend) + + compile_spec = _configured_compile_spec(tmp_path) + partitioner = VgfPartitioner(compile_spec) + + example_inputs = ( + torch.ones(1, 1, 1, 1), + torch.ones(1, 1, 1, 1), + ) + exported_program_for_partition = torch.export.export( + TinyAddSigmoid().eval(), + example_inputs, + ) + ops_to_preserve, filter_fn = partitioner.ops_to_not_decompose( + exported_program_for_partition + ) + partition_result = partitioner.partition(exported_program_for_partition) + + assert isinstance(ops_to_preserve, list) + assert filter_fn is None or callable(filter_fn) + assert partition_result.tagged_exported_program is exported_program_for_partition + + exported_program = torch.export.export(TinyAddSigmoid().eval(), example_inputs) + edge_manager = to_edge_transform_and_lower( + programs=exported_program, + partitioner=[partitioner], + ) + executorch_program_manager = edge_manager.to_executorch( + config=ExecutorchBackendConfig(extract_delegate_segments=False) + ) + + pte_path = tmp_path / "vgf_fp_public_api_bc.pte" + save_pte_program(executorch_program_manager, str(pte_path)) + + assert pte_path.is_file() + assert pte_path.stat().st_size > 0 + assert any((tmp_path / "vgf_fp_intermediates").rglob("*")) diff --git a/backends/arm/test/public_api_bc/vgf_int_flow.py b/backends/arm/test/public_api_bc/vgf_int_flow.py new file mode 100644 index 00000000000..154ec6b8bd7 --- /dev/null +++ b/backends/arm/test/public_api_bc/vgf_int_flow.py @@ -0,0 +1,113 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from pathlib import Path + +import torch +from executorch.backends.arm import ( + get_symmetric_quantization_config, + VgfCompileSpec, + VgfPartitioner, + VgfQuantizer, +) +from executorch.exir import ExecutorchBackendConfig, to_edge_transform_and_lower +from executorch.extension.export_util.utils import save_pte_program +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e + + +class TinyAddSigmoid(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return self.sigmoid(x + y) + + +def _configured_compile_spec(tmp_path: Path) -> VgfCompileSpec: + compile_spec = VgfCompileSpec("TOSA-1.0+INT") + + assert compile_spec == VgfCompileSpec("TOSA-1.0+INT") + assert "VgfCompileSpec" in repr(compile_spec) + + compile_spec.dump_intermediate_artifacts_to(str(tmp_path / "vgf_int_intermediates")) + returned = compile_spec.dump_debug_info(VgfCompileSpec.DebugMode.TOSA) + assert returned is compile_spec + return compile_spec + + +def _exercise_quantizer_api(compile_spec: VgfCompileSpec) -> None: + quantizer = VgfQuantizer(compile_spec) + symmetric_config = get_symmetric_quantization_config(is_per_channel=False) + + quantizer.set_global(symmetric_config) + quantizer.set_io(symmetric_config) + quantizer.set_module_name("sigmoid", symmetric_config) + quantizer.set_module_type(torch.nn.Sigmoid, symmetric_config) + + example_inputs = ( + torch.ones(1, 1, 1, 1), + torch.ones(1, 1, 1, 1), + ) + graph_module = torch.export.export( + TinyAddSigmoid().eval(), + example_inputs, + ).module(check_guards=False) + transformed = quantizer.transform_for_annotation(graph_module) + annotated = quantizer.annotate(transformed) + quantizer.validate(annotated) + + +def _build_quantized_program(compile_spec: VgfCompileSpec): + model = TinyAddSigmoid().eval() + example_inputs = ( + torch.ones(1, 1, 1, 1), + torch.ones(1, 1, 1, 1), + ) + exported_program = torch.export.export(model, example_inputs) + graph_module = exported_program.module(check_guards=False) + + quantizer = VgfQuantizer(compile_spec) + quantizer.set_global(get_symmetric_quantization_config()) + + prepared = prepare_pt2e(graph_module, quantizer) + prepared(*example_inputs) + converted = convert_pt2e(prepared) + + return torch.export.export(converted, example_inputs) + + +def test_vgf_int_public_api_scenario(tmp_path: Path) -> None: + compile_spec = _configured_compile_spec(tmp_path) + _exercise_quantizer_api(compile_spec) + + partitioner = VgfPartitioner(compile_spec) + quantized_program_for_partition = _build_quantized_program(compile_spec) + ops_to_preserve, filter_fn = partitioner.ops_to_not_decompose( + quantized_program_for_partition + ) + partition_result = partitioner.partition(quantized_program_for_partition) + + assert isinstance(ops_to_preserve, list) + assert filter_fn is None or callable(filter_fn) + assert partition_result.tagged_exported_program is quantized_program_for_partition + + quantized_program = _build_quantized_program(compile_spec) + edge_manager = to_edge_transform_and_lower( + programs=quantized_program, + partitioner=[partitioner], + ) + executorch_program_manager = edge_manager.to_executorch( + config=ExecutorchBackendConfig(extract_delegate_segments=False) + ) + + pte_path = tmp_path / "vgf_int_public_api_bc.pte" + save_pte_program(executorch_program_manager, str(pte_path)) + + assert pte_path.is_file() + assert pte_path.stat().st_size > 0 + assert any((tmp_path / "vgf_int_intermediates").rglob("*")) diff --git a/backends/arm/test/targets.bzl b/backends/arm/test/targets.bzl index 4d9e1f6c169..acb0b2020a9 100644 --- a/backends/arm/test/targets.bzl +++ b/backends/arm/test/targets.bzl @@ -45,6 +45,7 @@ def define_arm_tests(): test_files += [ "misc/test_compile_spec.py", "misc/test_pass_pipeline_config.py", + "misc/test_public_api_manifest.py", "misc/test_tosa_spec.py", "misc/test_bn_relu_folding_qat.py", "misc/test_custom_partition.py", @@ -53,6 +54,13 @@ def define_arm_tests(): # "misc/test_dim_order.py", (TODO - T238390249) ] + # Public API backward-compatibility scenarios + test_files += [ + "public_api_bc/ethosu_flow.py", + "public_api_bc/vgf_fp_flow.py", + "public_api_bc/vgf_int_flow.py", + ] + # Deprecation tests test_files += [ "deprecation/test_arm_compile_spec_deprecation.py",