From f2805e3513b5794736c7d02a82c57a16670a08a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sat, 14 Feb 2026 12:51:53 +0100 Subject: [PATCH 1/9] More tests about patches --- _unittests/ut_helpers/test_args_helper.py | 10 +- _unittests/ut_helpers/test_onnx_helper.py | 60 ++++++ .../test_patch_transformers.py | 175 ++++++++++++++++++ onnx_diagnostic/_command_lines_parser.py | 31 +++- onnx_diagnostic/helpers/args_helper.py | 12 ++ onnx_diagnostic/helpers/onnx_helper.py | 6 + .../hghub/hub_data_cached_configs.py | 33 ++++ 7 files changed, 319 insertions(+), 8 deletions(-) diff --git a/_unittests/ut_helpers/test_args_helper.py b/_unittests/ut_helpers/test_args_helper.py index aa3fbbaa..298c81a8 100644 --- a/_unittests/ut_helpers/test_args_helper.py +++ b/_unittests/ut_helpers/test_args_helper.py @@ -1,6 +1,10 @@ import unittest from onnx_diagnostic.ext_test_case import ExtTestCase -from onnx_diagnostic.helpers.args_helper import get_parsed_args, check_cuda_availability +from onnx_diagnostic.helpers.args_helper import ( + get_parsed_args, + check_cuda_availability, + process_outputname, +) class TestHelpers(ExtTestCase): @@ -52,6 +56,10 @@ def test_args_expose(self): self.assertEqual(args.repeat, 10) self.assertEqual(args.warmup, 5) + def test_process_outputname(self): + self.assertEqual("ggg.g", process_outputname("ggg.g", "hhh.h")) + self.assertEqual("hhh.ggg.h", process_outputname("+.ggg", "hhh.h")) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_helpers/test_onnx_helper.py b/_unittests/ut_helpers/test_onnx_helper.py index 5d83a711..8549d4d6 100644 --- a/_unittests/ut_helpers/test_onnx_helper.py +++ b/_unittests/ut_helpers/test_onnx_helper.py @@ -860,6 +860,66 @@ def test_make_model_with_local_functions_2(self): check_model(new_model) + @hide_stdout() + def test_make_model_with_local_functions_3(self): + model = oh.make_model( + oh.make_graph( + [ + oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]), + oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]), + oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]), + oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]), + oh.make_node("Cast", ["xm2c"], ["xm2"], to=1), + oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]), + oh.make_node("Reshape", ["xm", "shape3"], ["Z"]), + ], + "dummy", + [oh.make_tensor_value_info("X", TFLOAT, [320, 1280])], + [oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 320, 640])], + [ + onh.from_array( + np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y" + ), + onh.from_array(np.array([0], dtype=np.int64), name="zero"), + onh.from_array(np.array([1], dtype=np.int64), name="un"), + onh.from_array(np.array([1, 320, 1280], dtype=np.int64), name="shape1"), + onh.from_array(np.array([15, 1280, 640], dtype=np.int64), name="shape2"), + onh.from_array(np.array([3, 5, 320, 640], dtype=np.int64), name="shape3"), + ], + ), + opset_imports=[oh.make_opsetid("", 18)], + ir_version=9, + ) + for i_node in range(len(model.graph.node) - 1): + if i_node == 2: + continue + node = model.graph.node[i_node] + meta = node.metadata_props.add() + meta.key = f"source[{i_node}]" + meta.value = "LLL" + new_model = make_model_with_local_functions( + model, "^LLL$", metadata_key_prefix="source[", verbose=1 + ) + check_model(model) + self.assertEqual(len(new_model.functions), 1) + p = pretty_onnx(new_model) + self.assertIn("LLL0[local_function]", p) + self.assertIn("LLL1[local_function]", p) + + self.assertEqual(["X", "shape1", "un", "zero"], new_model.functions[0].input) + self.assertEqual(["xm1"], new_model.functions[0].output) + self.assertEqual("LLL0", new_model.functions[0].name) + self.assertEqual("local_function", new_model.functions[0].domain) + self.assertEqual(len(new_model.functions[0].node), 3) + + self.assertEqual(["Y", "shape2"], new_model.functions[1].input) + self.assertEqual(["xm2c"], new_model.functions[1].output) + self.assertEqual("LLL1", new_model.functions[1].name) + self.assertEqual("local_function", new_model.functions[1].domain) + self.assertEqual(len(new_model.functions[1].node), 1) + + check_model(new_model) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_torch_export_patches/test_patch_transformers.py b/_unittests/ut_torch_export_patches/test_patch_transformers.py index e909e23a..2dd18f9b 100644 --- a/_unittests/ut_torch_export_patches/test_patch_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_transformers.py @@ -19,6 +19,7 @@ from onnx_diagnostic.torch_models.hghub.hub_api import get_cached_configuration from onnx_diagnostic.torch_export_patches import torch_export_patches from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str +from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs from onnx_diagnostic.torch_export_patches.patches.patch_transformers import ( patch_qwen2_5, patch_funnel, @@ -392,6 +393,20 @@ def forward(self, q, k, cos, sin): rtol=1, ) + @requires_transformers("4.55") + @requires_onnxscript("0.6.2") + @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers") + def test_qwen_function_proto(self): + from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import ( + LoopAttention23, + LoopMHAAttention, + PackedAttention, + ) + + LoopMHAAttention.to_function_proto() + LoopAttention23.to_function_proto() + PackedAttention.to_function_proto() + @requires_transformers("4.55") @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers") def test_patched_qwen2_5_vl_rot_pos_emb(self): @@ -874,6 +889,166 @@ def test_model_funnel(self): got = patched.relative_positional_attention(**inputs) self.assertEqualArray(expected, got) + def test_cache_dependant_input_preparation_exporting(self): + from onnx_diagnostic.torch_export_patches.patches._patch_transformers_generation_mixin import ( # noqa: E501 + patched_GenerationMixin as GenerationMixin, + ) + + with self.subTest(case="case1"): + input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)[:, :0] + inputs_embeds = torch.rand((2, 8), dtype=torch.float32) + cache_position = torch.arange(0, 8, dtype=torch.int64) + eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation( + input_ids, inputs_embeds, cache_position + ) + export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting( + input_ids, inputs_embeds, cache_position + ) + torch.testing.assert_close(eager1, export1) + torch.testing.assert_close(eager2, export2) + + with self.subTest(case="case2"): + input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64) + inputs_embeds = torch.rand((2, 8), dtype=torch.float32) + cache_position = torch.arange(0, 8, dtype=torch.int64) + eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation( + input_ids, inputs_embeds, cache_position + ) + export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting( + input_ids, inputs_embeds, cache_position + ) + torch.testing.assert_close(eager1, export1) + torch.testing.assert_close(eager2, export2) + + with self.subTest(case="case3"): + input_ids = torch.randint(0, 16, (2, 12), dtype=torch.int64) + inputs_embeds = None + cache_position = torch.arange(0, 8, dtype=torch.int64) + eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation( + input_ids, inputs_embeds, cache_position + ) + export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting( + input_ids, inputs_embeds, cache_position + ) + torch.testing.assert_close(eager1, export1) + torch.testing.assert_close(eager2, export2) + + with self.subTest(case="case4"): + input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64) + inputs_embeds = None + cache_position = torch.arange(0, 8, dtype=torch.int64) + eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation( + input_ids, inputs_embeds, cache_position + ) + export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting( + input_ids, inputs_embeds, cache_position + ) + torch.testing.assert_close(eager1, export1) + torch.testing.assert_close(eager2, export2) + + def test_prepare_inputs_for_generation_decoder_llm(self): + data = get_untrained_model_with_inputs( + "hf-internal-testing/tiny-random-LlamaForCausalLM" + ) + model = data["model"] + config = model.config + torch_device = "cpu" + + with torch_export_patches(patch_transformers=True): + with self.subTest(case="case1"): + self.assertTrue("GenerationMixin" in str(model.prepare_inputs_for_generation)) + + input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]).to(torch_device) + cache_position = torch.arange(input_ids.shape[1], device=input_ids.device) + + with self.subTest(case="case2"): + input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]).to(torch_device) + model_inputs = model.prepare_inputs_for_generation( + input_ids, cache_position=cache_position + ) + self.assertTrue(torch.all(model_inputs["input_ids"] == input_ids)) + + with self.subTest(case="case3"): + attention_mask = torch.tensor([[1, 1, 1], [1, 1, 1]]).to(torch_device) + model_inputs = model.prepare_inputs_for_generation( + input_ids, attention_mask=attention_mask, cache_position=cache_position + ) + self.assertTrue(torch.all(model_inputs["attention_mask"] == attention_mask)) + self.assertTrue(model_inputs["position_ids"].shape == input_ids.shape) + + with self.subTest(case="case4"): + self.assertFalse("use_cache" in model_inputs) + model_inputs = model.prepare_inputs_for_generation( + input_ids, use_cache=True, foo="bar", cache_position=cache_position + ) + self.assertTrue(model_inputs["use_cache"] is True) + self.assertTrue(model_inputs["foo"] == "bar") + + with self.subTest(case="case5"): + init_input_ids = input_ids[:, :2] + dynamic_cache = transformers.cache_utils.DynamicCache(config=config) + dynamic_cache = model( + init_input_ids, past_key_values=dynamic_cache + ).past_key_values + with self.assertRaises((AttributeError, TypeError)): + model_inputs = model.prepare_inputs_for_generation( + input_ids, past_key_values=dynamic_cache + ) + + with self.subTest(case="case6"): + cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long).to( + torch_device + ) + cache_position = cache_position[dynamic_cache.get_seq_length() :] + model_inputs = model.prepare_inputs_for_generation( + input_ids, + past_key_values=dynamic_cache, + cache_position=cache_position, + attention_mask=attention_mask, + ) + self.assertTrue("past_key_values" in model_inputs) + self.assertTrue(torch.all(model_inputs["cache_position"] == cache_position)) + self.assertTrue( + model_inputs["input_ids"].shape[-1] == 1 + ) # 1 = 3 fed tokens - 2 tokens in the cache + self.assertTrue(model_inputs["position_ids"].shape[-1] == 1) + self.assertTrue( + model_inputs["attention_mask"].shape[-1] == 3 + ) # we still need the full attention mask! + + with self.subTest(case="case6.2"): + max_cache_len = 10 + batch_size = 2 + query_length = input_ids.shape[-1] - init_input_ids.shape[-1] + static_cache = transformers.cache_utils.StaticCache( + config=config, max_cache_len=max_cache_len + ) + static_cache = model( + init_input_ids, past_key_values=static_cache + ).past_key_values + model_inputs = model.prepare_inputs_for_generation( + input_ids, + past_key_values=static_cache, + cache_position=cache_position, + attention_mask=attention_mask, + ) + self.assertTrue("past_key_values" in model_inputs) + self.assertTrue( + list(model_inputs["attention_mask"].shape) + == [batch_size, 1, query_length, max_cache_len] + ) + + with self.subTest(case="case7"): + init_inputs_embeds = model.get_input_embeddings()(init_input_ids) + model_inputs = model.prepare_inputs_for_generation( + input_ids, + past_key_values=dynamic_cache, + inputs_embeds=init_inputs_embeds, + cache_position=cache_position, + ) + self.assertTrue(model_inputs["input_ids"] is not None) + self.assertTrue(model_inputs["inputs_embeds"] is None) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index 25897600..d4a908b3 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -47,6 +47,7 @@ def get_parser_dot() -> ArgumentParser: def _cmd_dot(argv: List[Any]): import subprocess + from .helpers.args_helper import process_outputname from .helpers.dot_helper import to_dot parser = get_parser_dot() @@ -58,15 +59,17 @@ def _cmd_dot(argv: List[Any]): print("-- converts into dot") dot = to_dot(onx) if args.output: + outname = process_outputname(args.output, args.input) if args.verbose: - print(f"-- saves into {args.output}") - with open(args.output, "w") as f: + print(f"-- saves into {outname!r}") + with open(outname, "w") as f: f.write(dot) else: print(dot) if args.run: assert args.output, "Cannot run dot without an output file." - cmds = ["dot", f"-T{args.run}", args.output, "-o", f"{args.output}.{args.run}"] + outname = process_outputname(outname, args.input) + cmds = ["dot", f"-T{args.run}", outname, "-o", f"{args.output}.{args.run}"] if args.verbose: print(f"-- run {' '.join(cmds)}") p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE) @@ -1553,10 +1556,11 @@ def _cmd_optimize(argv: List[Any]): parser = get_parser_optimize() args = parser.parse_args(argv[1:]) + from .helpers.args_helper import process_outputname from .helpers.optim_helper import optimize_model output = ( - args.output + process_outputname(args.output, args.input) if args.output else f"{os.path.splitext(args.input)[0]}.o-{args.algorithm}.onnx" ) @@ -1586,10 +1590,21 @@ def get_parser_partition() -> ArgumentParser: The regular may match the following values, 'model.layers.0.forward', 'model.layers.1.forward', ... A local function will be created for each distinct layer. + + Example: + + python -m onnx_diagnostic partition \\ + model.onnx +.part -v 1 -r "model.layers.0.s.*" """), ) parser.add_argument("input", help="input model") - parser.add_argument("output", help="output model") + parser.add_argument( + "output", + help=textwrap.dedent(""" + output model, an expression like '+.part' + inserts '.part' just before the extension" + """).strip("\n"), + ) parser.add_argument( "-r", "--regex", @@ -1619,6 +1634,7 @@ def get_parser_partition() -> ArgumentParser: def _cmd_partition(argv: List[Any]): + from .helpers.args_helper import process_outputname from .helpers.onnx_helper import make_model_with_local_functions parser = get_parser_partition() @@ -1635,9 +1651,10 @@ def _cmd_partition(argv: List[Any]): metadata_key_prefix=tuple(args.meta_prefix.split(",")), verbose=args.verbose, ) + outname = process_outputname(args.output, args.input) if args.verbose: - print(f"-- save into {args.output!r}") - onnx.save(onx2, args.output) + print(f"-- save into {outname!r}") + onnx.save(onx2, outname) if args.verbose: print("-- done") diff --git a/onnx_diagnostic/helpers/args_helper.py b/onnx_diagnostic/helpers/args_helper.py index df840b77..4c238597 100644 --- a/onnx_diagnostic/helpers/args_helper.py +++ b/onnx_diagnostic/helpers/args_helper.py @@ -1,3 +1,4 @@ +import os import subprocess from argparse import ArgumentParser, Namespace from typing import Dict, List, Optional, Tuple, Union @@ -131,3 +132,14 @@ def get_parsed_args( if update: res.__dict__.update(update) return res + + +def process_outputname(output_name: str, input_name: str) -> str: + """ + If 'output_name' starts with '+', then it is modified into + ``.extension``. + """ + if not output_name.startswith("+"): + return output_name + name, ext = os.path.splitext(input_name) + return f"{name}{output_name[1:]}{ext}" diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py index 605a6f31..96e2f8c1 100644 --- a/onnx_diagnostic/helpers/onnx_helper.py +++ b/onnx_diagnostic/helpers/onnx_helper.py @@ -1856,6 +1856,12 @@ def make_model_with_local_functions( if verbose: print(f"[make_model_with_local_functions] matched {len(unique)} partitions") + for un, nid in unique.items(): + print(f"[make_model_with_local_functions] {un!r}: {len(nid)} nodes") + for ind in nid[:5]: + print(f" {pretty_onnx(model.graph.node[ind])}") + if len(nid) > 5: + print(" ...") functions = [] new_nodes: List[Optional[NodeProto]] = list(model.graph.node) for key, node_indices in unique.items(): diff --git a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py index 4183c1b9..daee5225 100644 --- a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +++ b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py @@ -4972,3 +4972,36 @@ def _ccached_qwen_qwen2_5_vl_7b_instruct(): "vocab_size": 152064, }, ) + + +def _ccached_hf_internal_testing_tiny_random_LlamaForCausalLM(): + "hf-internal-testing/tiny-random-LlamaForCausalLM" + return transformers.LlamaConfig( + **{ + "architectures": ["LlamaForCausalLM"], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 0, + "dtype": "float32", + "eos_token_id": 1, + "head_dim": 4, + "hidden_act": "silu", + "hidden_size": 16, + "initializer_range": 0.02, + "intermediate_size": 64, + "max_position_embeddings": 2048, + "mlp_bias": false, + "model_type": "llama", + "num_attention_heads": 4, + "num_hidden_layers": 2, + "num_key_value_heads": 4, + "pad_token_id": -1, + "pretraining_tp": 1, + "rms_norm_eps": 1e-06, + "rope_parameters": {"rope_theta": 10000.0, "rope_type": "default"}, + "tie_word_embeddings": false, + "transformers_version": "5.2.0.dev0", + "use_cache": true, + "vocab_size": 32000, + } + ) From 1560c57522591d3e08cb6e7baa9554315bf396db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sat, 14 Feb 2026 13:01:16 +0100 Subject: [PATCH 2/9] remove one patch --- .../test_patch_transformers.py | 10 +++++----- .../_patch_transformers_generation_mixin.py | 16 +++++++++------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/_unittests/ut_torch_export_patches/test_patch_transformers.py b/_unittests/ut_torch_export_patches/test_patch_transformers.py index 2dd18f9b..872cf0fd 100644 --- a/_unittests/ut_torch_export_patches/test_patch_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_transformers.py @@ -984,12 +984,12 @@ def test_prepare_inputs_for_generation_decoder_llm(self): self.assertTrue(model_inputs["use_cache"] is True) self.assertTrue(model_inputs["foo"] == "bar") + init_input_ids = input_ids[:, :2] + dynamic_cache = transformers.cache_utils.DynamicCache(config=config) + dynamic_cache = model( + init_input_ids, past_key_values=dynamic_cache + ).past_key_values with self.subTest(case="case5"): - init_input_ids = input_ids[:, :2] - dynamic_cache = transformers.cache_utils.DynamicCache(config=config) - dynamic_cache = model( - init_input_ids, past_key_values=dynamic_cache - ).past_key_values with self.assertRaises((AttributeError, TypeError)): model_inputs = model.prepare_inputs_for_generation( input_ids, past_key_values=dynamic_cache diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py index 2eaeb35c..722e30b2 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py @@ -1,6 +1,5 @@ import inspect -import os -from typing import Optional, Tuple, Union +from typing import Optional, Tuple import packaging.version as pv import torch import transformers @@ -22,11 +21,11 @@ class patched_GenerationMixin: if pv.Version(transformers.__version__) >= pv.Version("4.56") else "prepare_inputs_for_generation" ), - ( - "_sample" - if pv.Version(transformers.__version__) == pv.Version("4.57.0.dev0") - else None - ), + # ( + # "_sample" + # if pv.Version(transformers.__version__) == pv.Version("4.57.0.dev0") + # else None + # ), ] _PATCHED_CLASS_ = transformers.generation.utils.GenerationMixin @@ -299,6 +298,8 @@ def prepare_inputs_for_generation( model_inputs.pop("labels", None) return model_inputs + ''' + # drops a patch since it is for a very specific version. def _sample( self, input_ids: torch.LongTensor, @@ -484,3 +485,4 @@ def _sample( ) else: return input_ids + ''' From a7f0bb69c1186691652f06f3c24ed23cf18d3317 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sat, 14 Feb 2026 13:14:14 +0100 Subject: [PATCH 3/9] fix --- .../test_patch_transformers.py | 7 +++++ onnx_diagnostic/helpers/onnx_helper.py | 27 ++++++++++--------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/_unittests/ut_torch_export_patches/test_patch_transformers.py b/_unittests/ut_torch_export_patches/test_patch_transformers.py index 872cf0fd..2975dcc4 100644 --- a/_unittests/ut_torch_export_patches/test_patch_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_transformers.py @@ -12,6 +12,7 @@ requires_torch, ignore_warnings, has_onnxscript, + has_transformers, requires_onnxscript, ) from onnx_diagnostic.helpers.torch_helper import torch_deepcopy, fake_torchdynamo_exporting @@ -908,6 +909,7 @@ def test_cache_dependant_input_preparation_exporting(self): torch.testing.assert_close(eager2, export2) with self.subTest(case="case2"): + raise unittest.SkipTest("torch 2.10+ has probably a bug here.") input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64) inputs_embeds = torch.rand((2, 8), dtype=torch.float32) cache_position = torch.arange(0, 8, dtype=torch.int64) @@ -989,7 +991,10 @@ def test_prepare_inputs_for_generation_decoder_llm(self): dynamic_cache = model( init_input_ids, past_key_values=dynamic_cache ).past_key_values + with self.subTest(case="case5"): + if not has_transformers("4.57"): + raise unittest.SkipTest("transformers 4.57+.") with self.assertRaises((AttributeError, TypeError)): model_inputs = model.prepare_inputs_for_generation( input_ids, past_key_values=dynamic_cache @@ -1039,6 +1044,8 @@ def test_prepare_inputs_for_generation_decoder_llm(self): ) with self.subTest(case="case7"): + if not has_transformers("4.57"): + raise unittest.SkipTest("transformers 4.57+.") init_inputs_embeds = model.get_input_embeddings()(init_input_ids) model_inputs = model.prepare_inputs_for_generation( input_ids, diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py index 96e2f8c1..0c94f961 100644 --- a/onnx_diagnostic/helpers/onnx_helper.py +++ b/onnx_diagnostic/helpers/onnx_helper.py @@ -1355,6 +1355,20 @@ def _mkv_(name, itype, irank): ) +def unknown_names_within_nodes(nodes: List[NodeProto]) -> Set[str]: + """Returns the list of unkonwn results from a list of nodes.""" + not_known: Set[str] = set() + for node in nodes[::-1]: + not_known -= {o for o in node.output if o} + not_known |= {i for i in node.input if i} + if node.op_type in {"Scan", "If", "Loop"}: + # there are hidden inputs + for att in node.attribute: + if att.type == onnx.AttributeProto.GRAPH: + not_known |= get_hidden_inputs(att.g) + return not_known + + def make_subfunction( name: str, nodes: List[NodeProto], @@ -1374,21 +1388,12 @@ def make_subfunction( :param domain: function domain :return: model proto """ - not_known: Set[str] = set() - for node in nodes[::-1]: - not_known -= {o for o in node.output if o} - not_known |= {i for i in node.input if i} - if node.op_type in {"Scan", "If", "Loop"}: - # there are hidden inputs - for att in node.attribute: - if att.type == onnx.AttributeProto.GRAPH: - not_known |= get_hidden_inputs(att.g) return oh.make_function( domain, name, nodes=nodes, - inputs=sorted(not_known), + inputs=sorted(unknown_names_within_nodes(nodes)), outputs=output_names, opset_imports=opset_imports, ) @@ -1775,8 +1780,6 @@ def check_for_non_recursivity( needs an output of the function and is also required by the function: it is probably missing from the initial set. - - :param node_list: list of nodes :param inputs: input names to consider :param outputs: output names which cannot be involved in input names From b26324a2f532e8002ab8296cbadd019cc4376979 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sat, 14 Feb 2026 13:22:56 +0100 Subject: [PATCH 4/9] fix --- .../test_patch_transformers.py | 1 + onnx_diagnostic/helpers/onnx_helper.py | 15 +++++++++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/_unittests/ut_torch_export_patches/test_patch_transformers.py b/_unittests/ut_torch_export_patches/test_patch_transformers.py index 2975dcc4..aa654a69 100644 --- a/_unittests/ut_torch_export_patches/test_patch_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_transformers.py @@ -948,6 +948,7 @@ def test_cache_dependant_input_preparation_exporting(self): torch.testing.assert_close(eager1, export1) torch.testing.assert_close(eager2, export2) + @requires_transformers("4.57") def test_prepare_inputs_for_generation_decoder_llm(self): data = get_untrained_model_with_inputs( "hf-internal-testing/tiny-random-LlamaForCausalLM" diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py index 0c94f961..5f802fe5 100644 --- a/onnx_diagnostic/helpers/onnx_helper.py +++ b/onnx_diagnostic/helpers/onnx_helper.py @@ -1356,7 +1356,7 @@ def _mkv_(name, itype, irank): def unknown_names_within_nodes(nodes: List[NodeProto]) -> Set[str]: - """Returns the list of unkonwn results from a list of nodes.""" + """Returns the list of unknown results from a list of nodes.""" not_known: Set[str] = set() for node in nodes[::-1]: not_known -= {o for o in node.output if o} @@ -1875,15 +1875,22 @@ def make_model_with_local_functions( f"nodes in partition {function_name!r}" ) outputs = _find_used_names(new_nodes, node_indices) - function_nodes = [new_nodes[i] for i in node_indices] + function_nodes = [new_nodes[i] for i in node_indices if new_nodes[i]] + + check_for_non_recursivity( + function_nodes, unknown_names_within_nodes(function_nodes), outputs + ) + lf = make_subfunction( function_name, - [n for n in function_nodes if n], + function_nodes, model.opset_import, outputs, domain=domain, ) - check_for_non_recursivity(new_nodes, lf.input, lf.output) + + check_for_non_recursivity(function_nodes, lf.input, lf.output) + functions.append(lf) maxi = max(node_indices) for i in node_indices: From f5e36362d56c43c1e48853debde942ca396c09fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sat, 14 Feb 2026 13:34:18 +0100 Subject: [PATCH 5/9] verbose --- _unittests/ut_helpers/test_onnx_helper.py | 3 ++- onnx_diagnostic/helpers/onnx_helper.py | 6 +++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/_unittests/ut_helpers/test_onnx_helper.py b/_unittests/ut_helpers/test_onnx_helper.py index 8549d4d6..9844ebb7 100644 --- a/_unittests/ut_helpers/test_onnx_helper.py +++ b/_unittests/ut_helpers/test_onnx_helper.py @@ -890,6 +890,7 @@ def test_make_model_with_local_functions_3(self): opset_imports=[oh.make_opsetid("", 18)], ir_version=9, ) + check_model(model) for i_node in range(len(model.graph.node) - 1): if i_node == 2: continue @@ -900,7 +901,7 @@ def test_make_model_with_local_functions_3(self): new_model = make_model_with_local_functions( model, "^LLL$", metadata_key_prefix="source[", verbose=1 ) - check_model(model) + check_model(new_model) self.assertEqual(len(new_model.functions), 1) p = pretty_onnx(new_model) self.assertIn("LLL0[local_function]", p) diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py index 5f802fe5..4965850b 100644 --- a/onnx_diagnostic/helpers/onnx_helper.py +++ b/onnx_diagnostic/helpers/onnx_helper.py @@ -1388,7 +1388,6 @@ def make_subfunction( :param domain: function domain :return: model proto """ - return oh.make_function( domain, name, @@ -1891,6 +1890,11 @@ def make_model_with_local_functions( check_for_non_recursivity(function_nodes, lf.input, lf.output) + if verbose: + print( + f"[make_model_with_local_functions] add function {function_name}" + f"({', '.join(lf.input)}) -> {', '.join(lf.input)}" + ) functions.append(lf) maxi = max(node_indices) for i in node_indices: From 0c3e47e941ee5bcd35d1d0da91761679c9ec3955 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sat, 14 Feb 2026 17:53:33 +0100 Subject: [PATCH 6/9] improve algo --- _unittests/ut_helpers/test_onnx_helper.py | 8 +- onnx_diagnostic/helpers/onnx_helper.py | 112 ++++++++++++++-------- 2 files changed, 74 insertions(+), 46 deletions(-) diff --git a/_unittests/ut_helpers/test_onnx_helper.py b/_unittests/ut_helpers/test_onnx_helper.py index 9844ebb7..79b38722 100644 --- a/_unittests/ut_helpers/test_onnx_helper.py +++ b/_unittests/ut_helpers/test_onnx_helper.py @@ -793,13 +793,7 @@ def test_make_model_with_local_functions_bug(self): meta = node.metadata_props.add() meta.key = "namespace" meta.value = "LLL" - self.assertRaise( - lambda: make_model_with_local_functions(model, "^LLL$"), - ValueError, - "Results {'xu1'} are needed for inputs ['X', 'Y', 'shape1', " - "'shape2', 'xu2', 'zero'] but also requires ['xm1', 'xm2', 'xu1'] " - "which is not allowed.", - ) + self.assertRaise(lambda: make_model_with_local_functions(model, "^LLL$"), ValueError) check_model(model) @hide_stdout() diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py index 4965850b..619c329c 100644 --- a/onnx_diagnostic/helpers/onnx_helper.py +++ b/onnx_diagnostic/helpers/onnx_helper.py @@ -1771,36 +1771,88 @@ def _find_used_names(node_list, node_indices): def check_for_non_recursivity( - node_list: List[Optional[NodeProto]], inputs: Sequence[str], outputs: Sequence[str] + node_indices: List[int], + node_list: List[Optional[NodeProto]], + inputs: Sequence[str], + outputs: Sequence[str], ): """ - We finally need to check that any of this output is not required + We need to check that any of this output is not required by one input from the function itself, that would mean one node needs an output of the function and is also required by the function: it is probably missing from the initial set. + :param node_indices: node_indices part of the subset :param node_list: list of nodes :param inputs: input names to consider :param outputs: output names which cannot be involved in input names """ + orginal_set_inputs = set(inputs) set_inputs = set(inputs) - set_outputs = set(outputs) - for node in node_list[::-1]: + original_set_outputs = set(outputs) + subset = set(node_indices) + before_inputs = set() + indexed_node = list(enumerate(node_list)) + for ind, node in indexed_node[::-1]: if not node: continue - si = set(node.output) - if si & set_inputs: - set_inputs |= set(node.input) - if node.op_type in {"Scan", "If", "Loop"}: - # there are hidden inputs - for att in node.attribute: - if att.type == onnx.AttributeProto.GRAPH: - set_inputs |= get_hidden_inputs(att.g) - if set_outputs & set_inputs: - raise ValueError( - f"Results {set_outputs & set_inputs} are needed for inputs {inputs} " - f"but also requires {outputs} which is not allowed." - ) + s_outputs = set(o for o in node.output if o) + if ind in subset: + # The node is part of the subset. + if s_outputs & set_inputs: + set_inputs |= set(i for i in node.input if i) + if node.op_type in {"Scan", "If", "Loop"}: + # there are hidden inputs + for att in node.attribute: + if att.type == onnx.AttributeProto.GRAPH: + set_inputs |= get_hidden_inputs(att.g) + if original_set_outputs & set_inputs: + raise ValueError( + f"Results {original_set_outputs & set_inputs} " + f"are needed for inputs {inputs} " + f"but also requires {outputs} which is not allowed." + ) + else: + # Not part of the function. Let's check + if s_outputs & orginal_set_inputs: + before_inputs |= set(i for i in node.input if i) + if node.op_type in {"Scan", "If", "Loop"}: + # there are hidden inputs + for att in node.attribute: + if att.type == onnx.AttributeProto.GRAPH: + before_inputs |= get_hidden_inputs(att.g) + if original_set_outputs & before_inputs: + raise ValueError( + f"Results {original_set_outputs & before_inputs} " + f"are needed for inputs {inputs} " + f"but also requires {outputs} which is not allowed." + ) + + +def _select_nodes_from_metadata_with_regex( + model: ModelProto, prefix: str, regex: str +) -> Tuple[Dict[str, List[int]], Set[str]]: + reg = re.compile(regex) + unique_values = set() + unique: Dict[str, List[int]] = {} + for i, node in enumerate(model.graph.node): + selected = False + for data in node.metadata_props: + if data.key.startswith(prefix): + values = re.split("[,:]", data.value) + for v in values: + if not v: + continue + if reg.match(v): + if v not in unique: + unique[v] = [] + unique[v].append(i) + selected = True + break + unique_values.add(v) + if selected: + break + return unique, unique_values def make_model_with_local_functions( @@ -1830,26 +1882,8 @@ def make_model_with_local_functions( if isinstance(metadata_key_prefix, tuple) else (metadata_key_prefix,) ) - reg = re.compile(regex) - unique_values = set() - unique: Dict[str, List[int]] = {} - for i, node in enumerate(model.graph.node): - selected = False - for data in node.metadata_props: - if data.key.startswith(prefix): - values = re.split("[,:]", data.value) - for v in values: - if not v: - continue - if reg.match(v): - if v not in unique: - unique[v] = [] - unique[v].append(i) - selected = True - break - unique_values.add(v) - if selected: - break + unique, unique_values = _select_nodes_from_metadata_with_regex(model, prefix, regex) + # sets of nodes. if not unique: if verbose: @@ -1877,7 +1911,7 @@ def make_model_with_local_functions( function_nodes = [new_nodes[i] for i in node_indices if new_nodes[i]] check_for_non_recursivity( - function_nodes, unknown_names_within_nodes(function_nodes), outputs + node_indices, model.graph.node, unknown_names_within_nodes(function_nodes), outputs ) lf = make_subfunction( @@ -1888,7 +1922,7 @@ def make_model_with_local_functions( domain=domain, ) - check_for_non_recursivity(function_nodes, lf.input, lf.output) + check_for_non_recursivity(node_indices, model.graph.node, lf.input, lf.output) if verbose: print( From 821fa764ce9921495a361e784fd7fd80b18030f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sun, 15 Feb 2026 13:06:27 +0100 Subject: [PATCH 7/9] fix partition --- _unittests/ut_helpers/test_onnx_helper.py | 29 +++-- onnx_diagnostic/helpers/onnx_helper.py | 137 +++++++++++++++------- 2 files changed, 112 insertions(+), 54 deletions(-) diff --git a/_unittests/ut_helpers/test_onnx_helper.py b/_unittests/ut_helpers/test_onnx_helper.py index 79b38722..e0737a62 100644 --- a/_unittests/ut_helpers/test_onnx_helper.py +++ b/_unittests/ut_helpers/test_onnx_helper.py @@ -892,26 +892,31 @@ def test_make_model_with_local_functions_3(self): meta = node.metadata_props.add() meta.key = f"source[{i_node}]" meta.value = "LLL" + self.assertRaise( + lambda: make_model_with_local_functions( + model, + "^LLL$", + metadata_key_prefix="source[", + verbose=1, + allow_extensions=False, + ), + ValueError, + ) new_model = make_model_with_local_functions( model, "^LLL$", metadata_key_prefix="source[", verbose=1 ) check_model(new_model) self.assertEqual(len(new_model.functions), 1) p = pretty_onnx(new_model) - self.assertIn("LLL0[local_function]", p) - self.assertIn("LLL1[local_function]", p) + self.assertIn("LLL[local_function]", p) - self.assertEqual(["X", "shape1", "un", "zero"], new_model.functions[0].input) - self.assertEqual(["xm1"], new_model.functions[0].output) - self.assertEqual("LLL0", new_model.functions[0].name) + self.assertEqual( + ["X", "Y", "shape1", "shape2", "un", "zero"], new_model.functions[0].input + ) + self.assertEqual(["xm"], new_model.functions[0].output) + self.assertEqual("LLL", new_model.functions[0].name) self.assertEqual("local_function", new_model.functions[0].domain) - self.assertEqual(len(new_model.functions[0].node), 3) - - self.assertEqual(["Y", "shape2"], new_model.functions[1].input) - self.assertEqual(["xm2c"], new_model.functions[1].output) - self.assertEqual("LLL1", new_model.functions[1].name) - self.assertEqual("local_function", new_model.functions[1].domain) - self.assertEqual(len(new_model.functions[1].node), 1) + self.assertEqual(len(new_model.functions[0].node), 6) check_model(new_model) diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py index 619c329c..a011c48a 100644 --- a/onnx_diagnostic/helpers/onnx_helper.py +++ b/onnx_diagnostic/helpers/onnx_helper.py @@ -1773,9 +1773,10 @@ def _find_used_names(node_list, node_indices): def check_for_non_recursivity( node_indices: List[int], node_list: List[Optional[NodeProto]], - inputs: Sequence[str], - outputs: Sequence[str], -): + inputs: Union[Set[str], Sequence[str]], + outputs: Union[Set[str], Sequence[str]], + exc: bool = True, +) -> List[int]: """ We need to check that any of this output is not required by one input from the function itself, that would mean one node @@ -1786,13 +1787,17 @@ def check_for_non_recursivity( :param node_list: list of nodes :param inputs: input names to consider :param outputs: output names which cannot be involved in input names + :param exc: raise an exception as soon as possible it becomes impossible + :return: list of nodes to add to make the list of node consistence + with the list of inputs and outputs (they should be recomputed) """ - orginal_set_inputs = set(inputs) - set_inputs = set(inputs) - original_set_outputs = set(outputs) + orginal_set_inputs = inputs if isinstance(inputs, set) else set(inputs) + set_inputs = orginal_set_inputs.copy() + original_set_outputs = outputs if isinstance(outputs, set) else set(outputs) subset = set(node_indices) before_inputs = set() indexed_node = list(enumerate(node_list)) + additional_nodes: List[int] = [] for ind, node in indexed_node[::-1]: if not node: continue @@ -1822,15 +1827,18 @@ def check_for_non_recursivity( if att.type == onnx.AttributeProto.GRAPH: before_inputs |= get_hidden_inputs(att.g) if original_set_outputs & before_inputs: - raise ValueError( - f"Results {original_set_outputs & before_inputs} " - f"are needed for inputs {inputs} " - f"but also requires {outputs} which is not allowed." - ) + if exc: + raise ValueError( + f"Results {original_set_outputs & before_inputs} " + f"are needed for inputs {inputs} " + f"but also requires {outputs} which is not allowed." + ) + additional_nodes.append(ind) + return additional_nodes def _select_nodes_from_metadata_with_regex( - model: ModelProto, prefix: str, regex: str + model: ModelProto, prefix: Union[str, Tuple[str, ...]], regex: str ) -> Tuple[Dict[str, List[int]], Set[str]]: reg = re.compile(regex) unique_values = set() @@ -1860,6 +1868,7 @@ def make_model_with_local_functions( regex: str = ".*[.]layers[.][0-9]+[.]forward$", domain: str = "local_function", metadata_key_prefix: Union[str, Tuple[str, ...]] = ("namespace", "source["), + allow_extensions: bool = True, verbose: int = 0, ) -> ModelProto: """ @@ -1874,6 +1883,8 @@ def make_model_with_local_functions( :param domain: function domain :param metadata_keys: list of metadata keys to consider, every value is split into multiple ones. + :param allow_extensions: allows the function to take nodes outside + a partition if there are not already inside another partition :param verbose: verbosity :return: model proto """ @@ -1900,40 +1911,82 @@ def make_model_with_local_functions( print(" ...") functions = [] new_nodes: List[Optional[NodeProto]] = list(model.graph.node) - for key, node_indices in unique.items(): - function_name = key.strip().replace(".", "_") - if verbose: - print( - f"[make_model_with_local_functions] move {len(node_indices)} " - f"nodes in partition {function_name!r}" + processed = {} + unique_as_set = {k: set(v) for k, v in unique.items()} + while len(processed) < len(unique): + for key, node_indices in unique.items(): + if key in processed: + # already processed + continue + function_name = key.strip().replace(".", "_") + if verbose: + print( + f"[make_model_with_local_functions] move {len(node_indices)} " + f"nodes in partition {key!r} (function={function_name!r})" + ) + outputs = _find_used_names(new_nodes, node_indices) + # pyrefly: ignore[bad-assignment] + function_nodes: List[NodeProto] = [ + new_nodes[i] for i in node_indices if new_nodes[i] + ] + + function_inputs = unknown_names_within_nodes(function_nodes) + additional_nodes = check_for_non_recursivity( + node_indices, new_nodes, function_inputs, outputs, exc=False ) - outputs = _find_used_names(new_nodes, node_indices) - function_nodes = [new_nodes[i] for i in node_indices if new_nodes[i]] - - check_for_non_recursivity( - node_indices, model.graph.node, unknown_names_within_nodes(function_nodes), outputs - ) + if additional_nodes: + if not allow_extensions: + raise ValueError( + f"Function for key={key!r} cannot be added because " + f"it must steal a node outside the partition, node ids " + f"{additional_nodes} are needed for inputs {function_inputs} " + f"but also requires {outputs} which is not allowed." + ) + # Additional nodes are needed to make the function consistence. + # We check they are not in conflict with other partitions not + # yet processed. + set_add = set(additional_nodes) + for k, v in unique_as_set.items(): + if v & set_add: + raise ValueError( + f"Function for key={key!r} cannot be added because " + f"it is conflict with other key {k!r} with node ids " + f"{set_add & v} are needed for inputs {function_inputs} " + f"but also requires {outputs} which is not allowed." + ) + # If no exception, everything is fine, let's add the nodes. + node_indices.extend(additional_nodes) + node_indices[:] = sorted(node_indices) + # Inputs and outputs needed to be recomputed. Let's do that in another + # iteration. + if verbose: + print( + f"[make_model_with_local_functions] add {len(additional_nodes)} " + f"nodes in partition {key!r}" + ) + continue - lf = make_subfunction( - function_name, - function_nodes, - model.opset_import, - outputs, - domain=domain, - ) + lf = make_subfunction( + function_name, + function_nodes, + model.opset_import, + outputs, + domain=domain, + ) - check_for_non_recursivity(node_indices, model.graph.node, lf.input, lf.output) + check_for_non_recursivity(node_indices, new_nodes, lf.input, lf.output) - if verbose: - print( - f"[make_model_with_local_functions] add function {function_name}" - f"({', '.join(lf.input)}) -> {', '.join(lf.input)}" - ) - functions.append(lf) - maxi = max(node_indices) - for i in node_indices: - new_nodes[i] = None - new_nodes[maxi] = oh.make_node(lf.name, lf.input, lf.output, domain=lf.domain) + if verbose: + print( + f"[make_model_with_local_functions] add function {function_name}" + f"({', '.join(lf.input)}) -> {', '.join(lf.input)}" + ) + functions.append(lf) + maxi = max(node_indices) + for i in node_indices: + new_nodes[i] = None + new_nodes[maxi] = oh.make_node(lf.name, lf.input, lf.output, domain=lf.domain) + processed[key] = lf return oh.make_model( oh.make_graph( From a6532fd16c0597692af1c8130c356b47f768c3a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sun, 15 Feb 2026 13:16:41 +0100 Subject: [PATCH 8/9] fix exceptions --- _unittests/ut_helpers/test_onnx_helper.py | 5 ++++- onnx_diagnostic/helpers/onnx_helper.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/_unittests/ut_helpers/test_onnx_helper.py b/_unittests/ut_helpers/test_onnx_helper.py index e0737a62..2889c553 100644 --- a/_unittests/ut_helpers/test_onnx_helper.py +++ b/_unittests/ut_helpers/test_onnx_helper.py @@ -793,7 +793,10 @@ def test_make_model_with_local_functions_bug(self): meta = node.metadata_props.add() meta.key = "namespace" meta.value = "LLL" - self.assertRaise(lambda: make_model_with_local_functions(model, "^LLL$"), ValueError) + self.assertRaise( + lambda: make_model_with_local_functions(model, "^LLL$", allow_extensions=False), + ValueError, + ) check_model(model) @hide_stdout() diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py index a011c48a..9d55eab0 100644 --- a/onnx_diagnostic/helpers/onnx_helper.py +++ b/onnx_diagnostic/helpers/onnx_helper.py @@ -1911,7 +1911,7 @@ def make_model_with_local_functions( print(" ...") functions = [] new_nodes: List[Optional[NodeProto]] = list(model.graph.node) - processed = {} + processed: Dict[str, FunctionProto] = {} unique_as_set = {k: set(v) for k, v in unique.items()} while len(processed) < len(unique): for key, node_indices in unique.items(): @@ -1932,7 +1932,7 @@ def make_model_with_local_functions( function_inputs = unknown_names_within_nodes(function_nodes) additional_nodes = check_for_non_recursivity( - node_indices, new_nodes, function_inputs, outputs, exc=False + node_indices, new_nodes, function_inputs, outputs, exc=not allow_extensions ) if additional_nodes: if not allow_extensions: From 6afc650ae1abd72ee97660bbaf754ba7bed25356 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sun, 15 Feb 2026 13:38:39 +0100 Subject: [PATCH 9/9] documentation --- CHANGELOGS.rst | 1 + onnx_diagnostic/helpers/onnx_helper.py | 72 ++++++++++++++++++++++++++ 2 files changed, 73 insertions(+) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 9c1150b1..3f5a4a22 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.9.2 +++++ +* :pr:`415`: improves function make_model_with_local_functions to support ill-defined partitions * :pr:`413`: fix InputObserver in the generic case * :pr:`412`: patches for ViTModel (through rewriting) diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py index 9d55eab0..ae0b61fa 100644 --- a/onnx_diagnostic/helpers/onnx_helper.py +++ b/onnx_diagnostic/helpers/onnx_helper.py @@ -1887,6 +1887,78 @@ def make_model_with_local_functions( a partition if there are not already inside another partition :param verbose: verbosity :return: model proto + + Example: + + .. runpython:: + :showcode: + + import numpy as np + import onnx + import onnx.helper as oh + import onnx.numpy_helper as onh + from onnx_diagnostic.helpers.onnx_helper import ( + make_model_with_local_functions, + pretty_onnx, + ) + + model = oh.make_model( + oh.make_graph( + [ + oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]), + oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]), + oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]), + oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]), + oh.make_node("Cast", ["xm2c"], ["xm2"], to=1), + oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]), + oh.make_node("Reshape", ["xm", "shape3"], ["Z"]), + ], + "dummy", + [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [320, 1280])], + [oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, [3, 5, 320, 640])], + [ + onh.from_array( + np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y" + ), + onh.from_array(np.array([0], dtype=np.int64), name="zero"), + onh.from_array(np.array([1], dtype=np.int64), name="un"), + onh.from_array(np.array([1, 320, 1280], dtype=np.int64), name="shape1"), + onh.from_array(np.array([15, 1280, 640], dtype=np.int64), name="shape2"), + onh.from_array(np.array([3, 5, 320, 640], dtype=np.int64), name="shape3"), + ], + ), + opset_imports=[oh.make_opsetid("", 18)], + ir_version=9, + ) + for i_node in [0, 1, 2, 3]: + node = model.graph.node[i_node] + meta = node.metadata_props.add() + meta.key = f"source[{i_node}]" + meta.value = f"LLL{i_node//3}" + + print("-- model before --") + print(pretty_onnx(model)) + print() + print("-- metadata --") + for node in model.graph.node: + text = ( + f" -- [{node.metadata_props[0].key}: {node.metadata_props[0].value}]" + if node.metadata_props + else "" + ) + print( + f"-- {node.op_type}({', '.join(node.input)}) -> " + f"{', '.join(node.output)}{text}" + ) + print() + + new_model = make_model_with_local_functions( + model, "^LLL[01]$", metadata_key_prefix="source[", verbose=1 + ) + + print() + print("-- model after --") + print(pretty_onnx(new_model)) """ prefix = ( metadata_key_prefix