diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 9c1150b1..3f5a4a22 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.9.2 +++++ +* :pr:`415`: improves function make_model_with_local_functions to support ill-defined partitions * :pr:`413`: fix InputObserver in the generic case * :pr:`412`: patches for ViTModel (through rewriting) diff --git a/_unittests/ut_helpers/test_args_helper.py b/_unittests/ut_helpers/test_args_helper.py index aa3fbbaa..298c81a8 100644 --- a/_unittests/ut_helpers/test_args_helper.py +++ b/_unittests/ut_helpers/test_args_helper.py @@ -1,6 +1,10 @@ import unittest from onnx_diagnostic.ext_test_case import ExtTestCase -from onnx_diagnostic.helpers.args_helper import get_parsed_args, check_cuda_availability +from onnx_diagnostic.helpers.args_helper import ( + get_parsed_args, + check_cuda_availability, + process_outputname, +) class TestHelpers(ExtTestCase): @@ -52,6 +56,10 @@ def test_args_expose(self): self.assertEqual(args.repeat, 10) self.assertEqual(args.warmup, 5) + def test_process_outputname(self): + self.assertEqual("ggg.g", process_outputname("ggg.g", "hhh.h")) + self.assertEqual("hhh.ggg.h", process_outputname("+.ggg", "hhh.h")) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_helpers/test_onnx_helper.py b/_unittests/ut_helpers/test_onnx_helper.py index 5d83a711..2889c553 100644 --- a/_unittests/ut_helpers/test_onnx_helper.py +++ b/_unittests/ut_helpers/test_onnx_helper.py @@ -794,11 +794,8 @@ def test_make_model_with_local_functions_bug(self): meta.key = "namespace" meta.value = "LLL" self.assertRaise( - lambda: make_model_with_local_functions(model, "^LLL$"), + lambda: make_model_with_local_functions(model, "^LLL$", allow_extensions=False), ValueError, - "Results {'xu1'} are needed for inputs ['X', 'Y', 'shape1', " - "'shape2', 'xu2', 'zero'] but also requires ['xm1', 'xm2', 'xu1'] " - "which is not allowed.", ) check_model(model) @@ -860,6 +857,72 @@ def 
test_make_model_with_local_functions_2(self): check_model(new_model) + @hide_stdout() + def test_make_model_with_local_functions_3(self): + model = oh.make_model( + oh.make_graph( + [ + oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]), + oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]), + oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]), + oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]), + oh.make_node("Cast", ["xm2c"], ["xm2"], to=1), + oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]), + oh.make_node("Reshape", ["xm", "shape3"], ["Z"]), + ], + "dummy", + [oh.make_tensor_value_info("X", TFLOAT, [320, 1280])], + [oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 320, 640])], + [ + onh.from_array( + np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y" + ), + onh.from_array(np.array([0], dtype=np.int64), name="zero"), + onh.from_array(np.array([1], dtype=np.int64), name="un"), + onh.from_array(np.array([1, 320, 1280], dtype=np.int64), name="shape1"), + onh.from_array(np.array([15, 1280, 640], dtype=np.int64), name="shape2"), + onh.from_array(np.array([3, 5, 320, 640], dtype=np.int64), name="shape3"), + ], + ), + opset_imports=[oh.make_opsetid("", 18)], + ir_version=9, + ) + check_model(model) + for i_node in range(len(model.graph.node) - 1): + if i_node == 2: + continue + node = model.graph.node[i_node] + meta = node.metadata_props.add() + meta.key = f"source[{i_node}]" + meta.value = "LLL" + self.assertRaise( + lambda: make_model_with_local_functions( + model, + "^LLL$", + metadata_key_prefix="source[", + verbose=1, + allow_extensions=False, + ), + ValueError, + ) + new_model = make_model_with_local_functions( + model, "^LLL$", metadata_key_prefix="source[", verbose=1 + ) + check_model(new_model) + self.assertEqual(len(new_model.functions), 1) + p = pretty_onnx(new_model) + self.assertIn("LLL[local_function]", p) + + self.assertEqual( + ["X", "Y", "shape1", "shape2", "un", "zero"], new_model.functions[0].input + ) + self.assertEqual(["xm"], 
new_model.functions[0].output) + self.assertEqual("LLL", new_model.functions[0].name) + self.assertEqual("local_function", new_model.functions[0].domain) + self.assertEqual(len(new_model.functions[0].node), 6) + + check_model(new_model) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_torch_export_patches/test_patch_transformers.py b/_unittests/ut_torch_export_patches/test_patch_transformers.py index e909e23a..aa654a69 100644 --- a/_unittests/ut_torch_export_patches/test_patch_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_transformers.py @@ -12,6 +12,7 @@ requires_torch, ignore_warnings, has_onnxscript, + has_transformers, requires_onnxscript, ) from onnx_diagnostic.helpers.torch_helper import torch_deepcopy, fake_torchdynamo_exporting @@ -19,6 +20,7 @@ from onnx_diagnostic.torch_models.hghub.hub_api import get_cached_configuration from onnx_diagnostic.torch_export_patches import torch_export_patches from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str +from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs from onnx_diagnostic.torch_export_patches.patches.patch_transformers import ( patch_qwen2_5, patch_funnel, @@ -392,6 +394,20 @@ def forward(self, q, k, cos, sin): rtol=1, ) + @requires_transformers("4.55") + @requires_onnxscript("0.6.2") + @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers") + def test_qwen_function_proto(self): + from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import ( + LoopAttention23, + LoopMHAAttention, + PackedAttention, + ) + + LoopMHAAttention.to_function_proto() + LoopAttention23.to_function_proto() + PackedAttention.to_function_proto() + @requires_transformers("4.55") @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers") def test_patched_qwen2_5_vl_rot_pos_emb(self): @@ -874,6 +890,173 @@ def test_model_funnel(self): got = 
patched.relative_positional_attention(**inputs) self.assertEqualArray(expected, got) + def test_cache_dependant_input_preparation_exporting(self): + from onnx_diagnostic.torch_export_patches.patches._patch_transformers_generation_mixin import ( # noqa: E501 + patched_GenerationMixin as GenerationMixin, + ) + + with self.subTest(case="case1"): + input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)[:, :0] + inputs_embeds = torch.rand((2, 8), dtype=torch.float32) + cache_position = torch.arange(0, 8, dtype=torch.int64) + eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation( + input_ids, inputs_embeds, cache_position + ) + export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting( + input_ids, inputs_embeds, cache_position + ) + torch.testing.assert_close(eager1, export1) + torch.testing.assert_close(eager2, export2) + + with self.subTest(case="case2"): + raise unittest.SkipTest("torch 2.10+ has probably a bug here.") + input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64) + inputs_embeds = torch.rand((2, 8), dtype=torch.float32) + cache_position = torch.arange(0, 8, dtype=torch.int64) + eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation( + input_ids, inputs_embeds, cache_position + ) + export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting( + input_ids, inputs_embeds, cache_position + ) + torch.testing.assert_close(eager1, export1) + torch.testing.assert_close(eager2, export2) + + with self.subTest(case="case3"): + input_ids = torch.randint(0, 16, (2, 12), dtype=torch.int64) + inputs_embeds = None + cache_position = torch.arange(0, 8, dtype=torch.int64) + eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation( + input_ids, inputs_embeds, cache_position + ) + export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting( + input_ids, inputs_embeds, cache_position + ) + torch.testing.assert_close(eager1, export1) + 
torch.testing.assert_close(eager2, export2) + + with self.subTest(case="case4"): + input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64) + inputs_embeds = None + cache_position = torch.arange(0, 8, dtype=torch.int64) + eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation( + input_ids, inputs_embeds, cache_position + ) + export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting( + input_ids, inputs_embeds, cache_position + ) + torch.testing.assert_close(eager1, export1) + torch.testing.assert_close(eager2, export2) + + @requires_transformers("4.57") + def test_prepare_inputs_for_generation_decoder_llm(self): + data = get_untrained_model_with_inputs( + "hf-internal-testing/tiny-random-LlamaForCausalLM" + ) + model = data["model"] + config = model.config + torch_device = "cpu" + + with torch_export_patches(patch_transformers=True): + with self.subTest(case="case1"): + self.assertTrue("GenerationMixin" in str(model.prepare_inputs_for_generation)) + + input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]).to(torch_device) + cache_position = torch.arange(input_ids.shape[1], device=input_ids.device) + + with self.subTest(case="case2"): + input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]).to(torch_device) + model_inputs = model.prepare_inputs_for_generation( + input_ids, cache_position=cache_position + ) + self.assertTrue(torch.all(model_inputs["input_ids"] == input_ids)) + + with self.subTest(case="case3"): + attention_mask = torch.tensor([[1, 1, 1], [1, 1, 1]]).to(torch_device) + model_inputs = model.prepare_inputs_for_generation( + input_ids, attention_mask=attention_mask, cache_position=cache_position + ) + self.assertTrue(torch.all(model_inputs["attention_mask"] == attention_mask)) + self.assertTrue(model_inputs["position_ids"].shape == input_ids.shape) + + with self.subTest(case="case4"): + self.assertFalse("use_cache" in model_inputs) + model_inputs = model.prepare_inputs_for_generation( + input_ids, use_cache=True, foo="bar", 
cache_position=cache_position + ) + self.assertTrue(model_inputs["use_cache"] is True) + self.assertTrue(model_inputs["foo"] == "bar") + + init_input_ids = input_ids[:, :2] + dynamic_cache = transformers.cache_utils.DynamicCache(config=config) + dynamic_cache = model( + init_input_ids, past_key_values=dynamic_cache + ).past_key_values + + with self.subTest(case="case5"): + if not has_transformers("4.57"): + raise unittest.SkipTest("transformers 4.57+.") + with self.assertRaises((AttributeError, TypeError)): + model_inputs = model.prepare_inputs_for_generation( + input_ids, past_key_values=dynamic_cache + ) + + with self.subTest(case="case6"): + cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long).to( + torch_device + ) + cache_position = cache_position[dynamic_cache.get_seq_length() :] + model_inputs = model.prepare_inputs_for_generation( + input_ids, + past_key_values=dynamic_cache, + cache_position=cache_position, + attention_mask=attention_mask, + ) + self.assertTrue("past_key_values" in model_inputs) + self.assertTrue(torch.all(model_inputs["cache_position"] == cache_position)) + self.assertTrue( + model_inputs["input_ids"].shape[-1] == 1 + ) # 1 = 3 fed tokens - 2 tokens in the cache + self.assertTrue(model_inputs["position_ids"].shape[-1] == 1) + self.assertTrue( + model_inputs["attention_mask"].shape[-1] == 3 + ) # we still need the full attention mask! 
+ + with self.subTest(case="case6.2"): + max_cache_len = 10 + batch_size = 2 + query_length = input_ids.shape[-1] - init_input_ids.shape[-1] + static_cache = transformers.cache_utils.StaticCache( + config=config, max_cache_len=max_cache_len + ) + static_cache = model( + init_input_ids, past_key_values=static_cache + ).past_key_values + model_inputs = model.prepare_inputs_for_generation( + input_ids, + past_key_values=static_cache, + cache_position=cache_position, + attention_mask=attention_mask, + ) + self.assertTrue("past_key_values" in model_inputs) + self.assertTrue( + list(model_inputs["attention_mask"].shape) + == [batch_size, 1, query_length, max_cache_len] + ) + + with self.subTest(case="case7"): + if not has_transformers("4.57"): + raise unittest.SkipTest("transformers 4.57+.") + init_inputs_embeds = model.get_input_embeddings()(init_input_ids) + model_inputs = model.prepare_inputs_for_generation( + input_ids, + past_key_values=dynamic_cache, + inputs_embeds=init_inputs_embeds, + cache_position=cache_position, + ) + self.assertTrue(model_inputs["input_ids"] is not None) + self.assertTrue(model_inputs["inputs_embeds"] is None) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index 25897600..d4a908b3 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -47,6 +47,7 @@ def get_parser_dot() -> ArgumentParser: def _cmd_dot(argv: List[Any]): import subprocess + from .helpers.args_helper import process_outputname from .helpers.dot_helper import to_dot parser = get_parser_dot() @@ -58,15 +59,17 @@ def _cmd_dot(argv: List[Any]): print("-- converts into dot") dot = to_dot(onx) if args.output: + outname = process_outputname(args.output, args.input) if args.verbose: - print(f"-- saves into {args.output}") - with open(args.output, "w") as f: + print(f"-- saves into {outname!r}") + with open(outname, "w") as f: 
f.write(dot) else: print(dot) if args.run: assert args.output, "Cannot run dot without an output file." - cmds = ["dot", f"-T{args.run}", args.output, "-o", f"{args.output}.{args.run}"] + outname = process_outputname(outname, args.input) + cmds = ["dot", f"-T{args.run}", outname, "-o", f"{args.output}.{args.run}"] if args.verbose: print(f"-- run {' '.join(cmds)}") p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE) @@ -1553,10 +1556,11 @@ def _cmd_optimize(argv: List[Any]): parser = get_parser_optimize() args = parser.parse_args(argv[1:]) + from .helpers.args_helper import process_outputname from .helpers.optim_helper import optimize_model output = ( - args.output + process_outputname(args.output, args.input) if args.output else f"{os.path.splitext(args.input)[0]}.o-{args.algorithm}.onnx" ) @@ -1586,10 +1590,21 @@ def get_parser_partition() -> ArgumentParser: The regular may match the following values, 'model.layers.0.forward', 'model.layers.1.forward', ... A local function will be created for each distinct layer. 
+ + Example: + + python -m onnx_diagnostic partition \\ + model.onnx +.part -v 1 -r "model.layers.0.s.*" """), ) parser.add_argument("input", help="input model") - parser.add_argument("output", help="output model") + parser.add_argument( + "output", + help=textwrap.dedent(""" + output model, an expression like '+.part' + inserts '.part' just before the extension" + """).strip("\n"), + ) parser.add_argument( "-r", "--regex", @@ -1619,6 +1634,7 @@ def get_parser_partition() -> ArgumentParser: def _cmd_partition(argv: List[Any]): + from .helpers.args_helper import process_outputname from .helpers.onnx_helper import make_model_with_local_functions parser = get_parser_partition() @@ -1635,9 +1651,10 @@ def _cmd_partition(argv: List[Any]): metadata_key_prefix=tuple(args.meta_prefix.split(",")), verbose=args.verbose, ) + outname = process_outputname(args.output, args.input) if args.verbose: - print(f"-- save into {args.output!r}") - onnx.save(onx2, args.output) + print(f"-- save into {outname!r}") + onnx.save(onx2, outname) if args.verbose: print("-- done") diff --git a/onnx_diagnostic/helpers/args_helper.py b/onnx_diagnostic/helpers/args_helper.py index df840b77..4c238597 100644 --- a/onnx_diagnostic/helpers/args_helper.py +++ b/onnx_diagnostic/helpers/args_helper.py @@ -1,3 +1,4 @@ +import os import subprocess from argparse import ArgumentParser, Namespace from typing import Dict, List, Optional, Tuple, Union @@ -131,3 +132,14 @@ def get_parsed_args( if update: res.__dict__.update(update) return res + + +def process_outputname(output_name: str, input_name: str) -> str: + """ + If 'output_name' starts with '+', then it is modified into + ``.extension``. 
+ """ + if not output_name.startswith("+"): + return output_name + name, ext = os.path.splitext(input_name) + return f"{name}{output_name[1:]}{ext}" diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py index 605a6f31..ae0b61fa 100644 --- a/onnx_diagnostic/helpers/onnx_helper.py +++ b/onnx_diagnostic/helpers/onnx_helper.py @@ -1355,6 +1355,20 @@ def _mkv_(name, itype, irank): ) +def unknown_names_within_nodes(nodes: List[NodeProto]) -> Set[str]: + """Returns the list of unknown results from a list of nodes.""" + not_known: Set[str] = set() + for node in nodes[::-1]: + not_known -= {o for o in node.output if o} + not_known |= {i for i in node.input if i} + if node.op_type in {"Scan", "If", "Loop"}: + # there are hidden inputs + for att in node.attribute: + if att.type == onnx.AttributeProto.GRAPH: + not_known |= get_hidden_inputs(att.g) + return not_known + + def make_subfunction( name: str, nodes: List[NodeProto], @@ -1374,21 +1388,11 @@ def make_subfunction( :param domain: function domain :return: model proto """ - not_known: Set[str] = set() - for node in nodes[::-1]: - not_known -= {o for o in node.output if o} - not_known |= {i for i in node.input if i} - if node.op_type in {"Scan", "If", "Loop"}: - # there are hidden inputs - for att in node.attribute: - if att.type == onnx.AttributeProto.GRAPH: - not_known |= get_hidden_inputs(att.g) - return oh.make_function( domain, name, nodes=nodes, - inputs=sorted(not_known), + inputs=sorted(unknown_names_within_nodes(nodes)), outputs=output_names, opset_imports=opset_imports, ) @@ -1767,38 +1771,96 @@ def _find_used_names(node_list, node_indices): def check_for_non_recursivity( - node_list: List[Optional[NodeProto]], inputs: Sequence[str], outputs: Sequence[str] -): + node_indices: List[int], + node_list: List[Optional[NodeProto]], + inputs: Union[Set[str], Sequence[str]], + outputs: Union[Set[str], Sequence[str]], + exc: bool = True, +) -> List[int]: """ - We finally need to check 
that any of this output is not required + We need to check that any of this output is not required by one input from the function itself, that would mean one node needs an output of the function and is also required by the function: it is probably missing from the initial set. - - + :param node_indices: node_indices part of the subset :param node_list: list of nodes :param inputs: input names to consider :param outputs: output names which cannot be involved in input names + :param exc: raise an exception as soon as it becomes impossible + :return: list of nodes to add to make the list of nodes consistent + with the list of inputs and outputs (they should be recomputed) """ - set_inputs = set(inputs) - set_outputs = set(outputs) - for node in node_list[::-1]: + orginal_set_inputs = inputs if isinstance(inputs, set) else set(inputs) + set_inputs = orginal_set_inputs.copy() + original_set_outputs = outputs if isinstance(outputs, set) else set(outputs) + subset = set(node_indices) + before_inputs = set() + indexed_node = list(enumerate(node_list)) + additional_nodes: List[int] = [] + for ind, node in indexed_node[::-1]: if not node: continue - si = set(node.output) - if si & set_inputs: - set_inputs |= set(node.input) - if node.op_type in {"Scan", "If", "Loop"}: - # there are hidden inputs - for att in node.attribute: - if att.type == onnx.AttributeProto.GRAPH: - set_inputs |= get_hidden_inputs(att.g) - if set_outputs & set_inputs: - raise ValueError( - f"Results {set_outputs & set_inputs} are needed for inputs {inputs} " - f"but also requires {outputs} which is not allowed." - ) + s_outputs = set(o for o in node.output if o) + if ind in subset: + # The node is part of the subset. 
+ if s_outputs & set_inputs: + set_inputs |= set(i for i in node.input if i) + if node.op_type in {"Scan", "If", "Loop"}: + # there are hidden inputs + for att in node.attribute: + if att.type == onnx.AttributeProto.GRAPH: + set_inputs |= get_hidden_inputs(att.g) + if original_set_outputs & set_inputs: + raise ValueError( + f"Results {original_set_outputs & set_inputs} " + f"are needed for inputs {inputs} " + f"but also requires {outputs} which is not allowed." + ) + else: + # Not part of the function. Let's check + if s_outputs & orginal_set_inputs: + before_inputs |= set(i for i in node.input if i) + if node.op_type in {"Scan", "If", "Loop"}: + # there are hidden inputs + for att in node.attribute: + if att.type == onnx.AttributeProto.GRAPH: + before_inputs |= get_hidden_inputs(att.g) + if original_set_outputs & before_inputs: + if exc: + raise ValueError( + f"Results {original_set_outputs & before_inputs} " + f"are needed for inputs {inputs} " + f"but also requires {outputs} which is not allowed." 
+ ) + additional_nodes.append(ind) + return additional_nodes + + +def _select_nodes_from_metadata_with_regex( + model: ModelProto, prefix: Union[str, Tuple[str, ...]], regex: str +) -> Tuple[Dict[str, List[int]], Set[str]]: + reg = re.compile(regex) + unique_values = set() + unique: Dict[str, List[int]] = {} + for i, node in enumerate(model.graph.node): + selected = False + for data in node.metadata_props: + if data.key.startswith(prefix): + values = re.split("[,:]", data.value) + for v in values: + if not v: + continue + if reg.match(v): + if v not in unique: + unique[v] = [] + unique[v].append(i) + selected = True + break + unique_values.add(v) + if selected: + break + return unique, unique_values def make_model_with_local_functions( @@ -1806,6 +1868,7 @@ def make_model_with_local_functions( regex: str = ".*[.]layers[.][0-9]+[.]forward$", domain: str = "local_function", metadata_key_prefix: Union[str, Tuple[str, ...]] = ("namespace", "source["), + allow_extensions: bool = True, verbose: int = 0, ) -> ModelProto: """ @@ -1820,34 +1883,90 @@ def make_model_with_local_functions( :param domain: function domain :param metadata_keys: list of metadata keys to consider, every value is split into multiple ones. + :param allow_extensions: allows the function to take nodes outside + a partition if there are not already inside another partition :param verbose: verbosity :return: model proto + + Example: + + .. 
runpython:: + :showcode: + + import numpy as np + import onnx + import onnx.helper as oh + import onnx.numpy_helper as onh + from onnx_diagnostic.helpers.onnx_helper import ( + make_model_with_local_functions, + pretty_onnx, + ) + + model = oh.make_model( + oh.make_graph( + [ + oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]), + oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]), + oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]), + oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]), + oh.make_node("Cast", ["xm2c"], ["xm2"], to=1), + oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]), + oh.make_node("Reshape", ["xm", "shape3"], ["Z"]), + ], + "dummy", + [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [320, 1280])], + [oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, [3, 5, 320, 640])], + [ + onh.from_array( + np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y" + ), + onh.from_array(np.array([0], dtype=np.int64), name="zero"), + onh.from_array(np.array([1], dtype=np.int64), name="un"), + onh.from_array(np.array([1, 320, 1280], dtype=np.int64), name="shape1"), + onh.from_array(np.array([15, 1280, 640], dtype=np.int64), name="shape2"), + onh.from_array(np.array([3, 5, 320, 640], dtype=np.int64), name="shape3"), + ], + ), + opset_imports=[oh.make_opsetid("", 18)], + ir_version=9, + ) + for i_node in [0, 1, 2, 3]: + node = model.graph.node[i_node] + meta = node.metadata_props.add() + meta.key = f"source[{i_node}]" + meta.value = f"LLL{i_node//3}" + + print("-- model before --") + print(pretty_onnx(model)) + print() + print("-- metadata --") + for node in model.graph.node: + text = ( + f" -- [{node.metadata_props[0].key}: {node.metadata_props[0].value}]" + if node.metadata_props + else "" + ) + print( + f"-- {node.op_type}({', '.join(node.input)}) -> " + f"{', '.join(node.output)}{text}" + ) + print() + + new_model = make_model_with_local_functions( + model, "^LLL[01]$", metadata_key_prefix="source[", verbose=1 + ) + + print() + print("-- model 
after --") + print(pretty_onnx(new_model)) """ prefix = ( metadata_key_prefix if isinstance(metadata_key_prefix, tuple) else (metadata_key_prefix,) ) - reg = re.compile(regex) - unique_values = set() - unique: Dict[str, List[int]] = {} - for i, node in enumerate(model.graph.node): - selected = False - for data in node.metadata_props: - if data.key.startswith(prefix): - values = re.split("[,:]", data.value) - for v in values: - if not v: - continue - if reg.match(v): - if v not in unique: - unique[v] = [] - unique[v].append(i) - selected = True - break - unique_values.add(v) - if selected: - break + unique, unique_values = _select_nodes_from_metadata_with_regex(model, prefix, regex) + # sets of nodes. if not unique: if verbose: @@ -1856,30 +1975,90 @@ def make_model_with_local_functions( if verbose: print(f"[make_model_with_local_functions] matched {len(unique)} partitions") + for un, nid in unique.items(): + print(f"[make_model_with_local_functions] {un!r}: {len(nid)} nodes") + for ind in nid[:5]: + print(f" {pretty_onnx(model.graph.node[ind])}") + if len(nid) > 5: + print(" ...") functions = [] new_nodes: List[Optional[NodeProto]] = list(model.graph.node) - for key, node_indices in unique.items(): - function_name = key.strip().replace(".", "_") - if verbose: - print( - f"[make_model_with_local_functions] move {len(node_indices)} " - f"nodes in partition {function_name!r}" + processed: Dict[str, FunctionProto] = {} + unique_as_set = {k: set(v) for k, v in unique.items()} + while len(processed) < len(unique): + for key, node_indices in unique.items(): + if key in processed: + # already processed + continue + function_name = key.strip().replace(".", "_") + if verbose: + print( + f"[make_model_with_local_functions] move {len(node_indices)} " + f"nodes in partition {key!r} (function={function_name!r})" + ) + outputs = _find_used_names(new_nodes, node_indices) + # pyrefly: ignore[bad-assignment] + function_nodes: List[NodeProto] = [ + new_nodes[i] for i in node_indices 
if new_nodes[i] + ] + + function_inputs = unknown_names_within_nodes(function_nodes) + additional_nodes = check_for_non_recursivity( + node_indices, new_nodes, function_inputs, outputs, exc=not allow_extensions ) - outputs = _find_used_names(new_nodes, node_indices) - function_nodes = [new_nodes[i] for i in node_indices] - lf = make_subfunction( - function_name, - [n for n in function_nodes if n], - model.opset_import, - outputs, - domain=domain, - ) - check_for_non_recursivity(new_nodes, lf.input, lf.output) - functions.append(lf) - maxi = max(node_indices) - for i in node_indices: - new_nodes[i] = None - new_nodes[maxi] = oh.make_node(lf.name, lf.input, lf.output, domain=lf.domain) + if additional_nodes: + if not allow_extensions: + raise ValueError( + f"Function for key={key!r} cannot be added because " + f"it must steal a node outside the partition, node ids " + f"{additional_nodes} are needed for inputs {function_inputs} " + f"but also requires {outputs} which is not allowed." + ) + # Additional nodes are needed to make the function consistence. + # We check they are not in conflict with other partitions not + # yet processed. + set_add = set(additional_nodes) + for k, v in unique_as_set.items(): + if v & set_add: + raise ValueError( + f"Function for key={key!r} cannot be added because " + f"it is conflict with other key {k!r} with node ids " + f"{set_add & v} are needed for inputs {function_inputs} " + f"but also requires {outputs} which is not allowed." + ) + # If no exception, everything is fine, let's add the nodes. + node_indices.extend(additional_nodes) + node_indices[:] = sorted(node_indices) + # Inputs and outputs needed to be recomputed. Let's do that in another + # iteration. 
+ if verbose: + print( + f"[make_model_with_local_functions] add {len(additional_nodes)} " + f"nodes in partition {key!r}" + ) + continue + + lf = make_subfunction( + function_name, + function_nodes, + model.opset_import, + outputs, + domain=domain, + ) + + check_for_non_recursivity(node_indices, new_nodes, lf.input, lf.output) + + if verbose: + print( + f"[make_model_with_local_functions] add function {function_name}" + f"({', '.join(lf.input)}) -> {', '.join(lf.output)}" + ) + functions.append(lf) + maxi = max(node_indices) + for i in node_indices: + new_nodes[i] = None + new_nodes[maxi] = oh.make_node(lf.name, lf.input, lf.output, domain=lf.domain) + processed[key] = lf return oh.make_model( oh.make_graph( diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py index 2eaeb35c..722e30b2 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py @@ -1,6 +1,5 @@ import inspect -import os -from typing import Optional, Tuple, Union +from typing import Optional, Tuple import packaging.version as pv import torch import transformers @@ -22,11 +21,11 @@ class patched_GenerationMixin: if pv.Version(transformers.__version__) >= pv.Version("4.56") else "prepare_inputs_for_generation" ), - ( - "_sample" - if pv.Version(transformers.__version__) == pv.Version("4.57.0.dev0") - else None - ), + # ( + # "_sample" + # if pv.Version(transformers.__version__) == pv.Version("4.57.0.dev0") + # else None + # ), ] _PATCHED_CLASS_ = transformers.generation.utils.GenerationMixin @@ -299,6 +298,8 @@ def prepare_inputs_for_generation( model_inputs.pop("labels", None) return model_inputs + ''' + # drops a patch since it is for a very specific version. 
def _sample( self, input_ids: torch.LongTensor, @@ -484,3 +485,4 @@ def _sample( ) else: return input_ids + ''' diff --git a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py index 4183c1b9..daee5225 100644 --- a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +++ b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py @@ -4972,3 +4972,36 @@ def _ccached_qwen_qwen2_5_vl_7b_instruct(): "vocab_size": 152064, }, ) + + +def _ccached_hf_internal_testing_tiny_random_LlamaForCausalLM(): + "hf-internal-testing/tiny-random-LlamaForCausalLM" + return transformers.LlamaConfig( + **{ + "architectures": ["LlamaForCausalLM"], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 0, + "dtype": "float32", + "eos_token_id": 1, + "head_dim": 4, + "hidden_act": "silu", + "hidden_size": 16, + "initializer_range": 0.02, + "intermediate_size": 64, + "max_position_embeddings": 2048, + "mlp_bias": false, + "model_type": "llama", + "num_attention_heads": 4, + "num_hidden_layers": 2, + "num_key_value_heads": 4, + "pad_token_id": -1, + "pretraining_tp": 1, + "rms_norm_eps": 1e-06, + "rope_parameters": {"rope_theta": 10000.0, "rope_type": "default"}, + "tie_word_embeddings": false, + "transformers_version": "5.2.0.dev0", + "use_cache": true, + "vocab_size": 32000, + } + )