From be06dcce86a6f0a3195db8e292ccf1760067880b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 23 Feb 2026 16:17:18 +0100 Subject: [PATCH 01/10] add remove_inputs to InputObserver --- .../examples/plot_export_with_modelbuilder.py | 111 ++++++++++++++++++ onnx_diagnostic/export/api.py | 12 +- onnx_diagnostic/investigate/input_observer.py | 55 +++++++++ 3 files changed, 177 insertions(+), 1 deletion(-) create mode 100644 _doc/examples/plot_export_with_modelbuilder.py diff --git a/_doc/examples/plot_export_with_modelbuilder.py b/_doc/examples/plot_export_with_modelbuilder.py new file mode 100644 index 00000000..5d877628 --- /dev/null +++ b/_doc/examples/plot_export_with_modelbuilder.py @@ -0,0 +1,111 @@ +""" +.. _l-plot-export-model-builder: + +Export with ModelBuilder +======================== + +""" + +import os +import pandas +from transformers import AutoModelForCausalLM, AutoTokenizer +from onnx_diagnostic import doc +from onnx_diagnostic.investigate.input_observer import InputObserver +from onnx_diagnostic.helpers.rt_helper import onnx_generate +from onnx_diagnostic.torch_export_patches import ( + register_additional_serialization_functions, + torch_export_patches, +) +from onnx_diagnostic.export.api import to_onnx + + +def generate_text( + prompt, + model, + tokenizer, + max_length=50, + temperature=0.01, + top_k=50, + top_p=0.95, + do_sample=True, +): + inputs = tokenizer(prompt, return_tensors="pt") + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + + outputs = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + temperature=temperature, + top_k=top_k, + top_p=top_p, + do_sample=do_sample, + ) + + generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + return generated_text + + +# %% +# Creating the model +print("-- creating...") +MODEL_NAME = "arnir0/Tiny-LLM" +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +model = AutoModelForCausalLM.from_pretrained(MODEL_NAME) + + +# %% +# Capturing inputs/outputs to infer dynamic shapes and arguments +print("-- capturing...") +prompt = "Continue: it rains, what should I do?" +observer = InputObserver() +with register_additional_serialization_functions(patch_transformers=True), observer(model): + generate_text(prompt, model, tokenizer) + + +# %% +# Exporting. +print("-- exporting...") +observer.remove_inputs(["cache_position", "logits_to_keep", "position_ids"]) +ds = observer.infer_dynamic_shapes(set_batch_dimension_for=True) +kwargs = observer.infer_arguments() + +cache_dir = "dump_modelbuilder" +os.makedirs(cache_dir, exist_ok=True) +filename = os.path.join(cache_dir, "plot_export_with_modelbuilder.onnx") +with torch_export_patches(patch_transformers=True): + to_onnx( + model, + filename=filename, + kwargs=kwargs, + dynamic_shapes=ds, + exporter="modelbuilder", + ) + +data = observer.check_discrepancies(filename, progress_bar=True) +print(pandas.DataFrame(data)) + +# %% +# ONNX Prompt +# +++++++++++ +print("-- ONNX prompts...") +inputs = tokenizer(prompt, return_tensors="pt") +input_ids = inputs["input_ids"] +attention_mask = inputs["attention_mask"] + +onnx_tokens = onnx_generate( + filename, + input_ids=input_ids, + attention_mask=attention_mask, + eos_token_id=model.config.eos_token_id, + max_new_tokens=50, +) +onnx_generated_text = tokenizer.decode(onnx_tokens, skip_special_tokens=True) + +print("-----------------") +print("\n".join(onnx_generated_text)) +print("-----------------") + +# %% +doc.save_fig(doc.plot_dot(filename), f"{filename}.png", dpi=400) diff --git a/onnx_diagnostic/export/api.py b/onnx_diagnostic/export/api.py index 03898e78..58906979 100644 --- a/onnx_diagnostic/export/api.py +++ b/onnx_diagnostic/export/api.py @@ -4,6 +4,7 @@ import time from collections.abc import Mapping, Iterable from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union +import onnx import torch from .dynamic_shapes import ModelInputs from .onnx_plug import EagerDirectReplacementWithOnnx @@ -312,10 +313,19 @@ def to_onnx( mod, precision=str(first_float[0].dtype).split(".")[-1], execution_provider="cuda" if first.is_cuda else "cpu", - cache_dir=os.path.dirname(filename), + cache_dir=os.path.dirname(filename) or ".", **(exporter_kwargs or {}), ) save_model_builder(onx, os.path.dirname(filename)) + temp_filename = os.path.join(os.path.dirname(filename), "model.onnx") + # renaming + onx = onnx.load(temp_filename, load_external_data=True) + onnx.save( + onx, + filename, + save_as_external_data=True, + location=f"{os.path.splitext(filename[0])}.data", + ) return onx raise ValueError(f"Unknown exporter={exporter!r}") diff --git a/onnx_diagnostic/investigate/input_observer.py b/onnx_diagnostic/investigate/input_observer.py index 5004f466..ddec3fba 100644 --- a/onnx_diagnostic/investigate/input_observer.py +++ b/onnx_diagnostic/investigate/input_observer.py @@ -145,6 +145,25 @@ def __init__( self.aligned_spec: torch.utils._pytree.PyTreeSpec | None = None self.aligned_flat_list: list[torch.Tensor | None] | None = None + def remove_inputs(self, input_names: Sequence[str]): + """Removes inputs.""" + for name_or_pos in sorted(input_names, reverse=True): + if isinstance(name_or_pos, int): + if name_or_pos in self.args: + del self.args[name_or_pos] + else: + if name_or_pos in self.kwargs: + del self.kwargs[name_or_pos] + elif name_or_pos in self.cst_kwargs: + del self.cst_kwargs[name_or_pos] + + # remove any temporary structures + self.flat_list, self.spec = torch.utils._pytree.tree_flatten((self.args, self.kwargs)) + self._position_to_args_kwargs: list[int | str] | None = None + self._n_tensors_for_args_kwargs: dict[int | str, int] | None = None + self.aligned_spec: torch.utils._pytree.PyTreeSpec | None = None + self.aligned_flat_list: list[torch.Tensor | None] | None = None + def __str__(self) -> str: return ( f"{self.__class__.__name__}({len(self.args)} args, " @@ -811,6 +830,38 @@ def _post_process_for_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]: ) return {**new_kwargs, self.kwargs_name: keywords} + def remove_inputs(self, input_names: Sequence[str | int]): + """Lets the users drops inputs.""" + if ( + self.args_name_and_position is not None + and self.args_name_and_position in input_names + ): + raise ValueError(f"Cannot remove variadic {self.args_name_and_position}") + if self.kwargs_name is not None and self.kwargs_name in input_names: + raise ValueError(f"Cannot remove variadic {self.kwargs_name}") + for candidate in self.inputs: + candidate.remove_inputs(input_names) + if self._best_candidate: + self._best_candidate.remove_inputs(input_names) + + for name_or_pos in sorted(input_names, reverse=True): + if self.default_values and name_or_pos in self.default_values: + del self.default_values[name_or_pos] + if self.value_if_missing and name_or_pos in self.value_if_missing: + del self.value_if_missing[name_or_pos] + if self._captured_inputs and name_or_pos in self._captured_inputs: + del self._captured_inputs[name_or_pos] + + assert ( + not self.args_name_and_position + ), f"Not implemented when {self.args_name_and_position=}" + input_names_str = { + (self.signature_name[i] if isinstance(i, int) else i) for i in input_names + } + self.signature_names = [ + name for name in self.signature_names if name not in input_names_str + ] + class InputObserver: """Steals forward method to collect inputs and outputs. @@ -1209,3 +1260,7 @@ def check_discrepancies( diff["outputs_ort"] = string_type(ort_outputs, with_shape=True) data.append(diff) return data + + def remove_inputs(self, input_names: Sequence[str | int]): + """Lets the users drops inputs.""" + self.info.remove_inputs(input_names) From a5fedd6dab1ffcc0987a9ba2aa6ae8b4fc982608 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 23 Feb 2026 16:18:50 +0100 Subject: [PATCH 02/10] changes --- CHANGELOGS.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 37506809..154b4fd1 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,9 @@ Change Logs 0.9.3 +++++ +* :pr:`422`: add remove_inputs to InputObserver +* :pr:`421`: fix a few patches for MoE + 0.9.2 +++++ From 48aec1585806b9375b97b4e29a4e89cc54707ab1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 23 Feb 2026 16:19:41 +0100 Subject: [PATCH 03/10] mypy --- onnx_diagnostic/investigate/input_observer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx_diagnostic/investigate/input_observer.py b/onnx_diagnostic/investigate/input_observer.py index ddec3fba..2cc4d9c1 100644 --- a/onnx_diagnostic/investigate/input_observer.py +++ b/onnx_diagnostic/investigate/input_observer.py @@ -145,7 +145,7 @@ def __init__( self.aligned_spec: torch.utils._pytree.PyTreeSpec | None = None self.aligned_flat_list: list[torch.Tensor | None] | None = None - def remove_inputs(self, input_names: Sequence[str]): + def remove_inputs(self, input_names: Sequence[str | int]): """Removes inputs.""" for name_or_pos in sorted(input_names, reverse=True): if isinstance(name_or_pos, int): From 795fc3b4a0a1eaf5ea744dda6cca34603572cbe5 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Mon, 23 Feb 2026 17:08:58 +0100 Subject: [PATCH 04/10] Add unit tests for `remove_inputs` in `InputObserver` (#423) * Initial plan * Add unit tests for remove_inputs in InputObserver Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com> --- .../ut_investigate/test_input_observer.py | 107 ++++++++++++++++++ 1 file changed, 107 insertions(+) diff --git a/_unittests/ut_investigate/test_input_observer.py b/_unittests/ut_investigate/test_input_observer.py index 056c4765..2f415d2e 100644 --- a/_unittests/ut_investigate/test_input_observer.py +++ b/_unittests/ut_investigate/test_input_observer.py @@ -1197,5 +1197,112 @@ def forward(self, a, *args, **kwargs): torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=ds) + def test_remove_inputs_kwargs(self): + """Test that remove_inputs removes a kwarg from the observer info.""" + + class Model(torch.nn.Module): + def forward(self, x, y, z=None): + r = x + y + if z is not None: + r += z + return r + + inputs = [ + dict(x=torch.randn((5, 6)), y=torch.randn((1, 6)), z=torch.randn((5, 6))), + dict(x=torch.randn((7, 7)), y=torch.randn((1, 7)), z=torch.randn((7, 7))), + dict(x=torch.randn((7, 8)), y=torch.randn((1, 8)), z=torch.randn((7, 8))), + ] + + model = Model() + observer = InputObserver() + with observer(model): + for kwargs in inputs: + model(**kwargs) + self.assertEqual(len(observer.info), 3) + + cst = torch.export.Dim.DYNAMIC + ds = observer.infer_dynamic_shapes() + self.assertIn("z", ds) + self.assertIn("x", ds) + self.assertIn("y", ds) + + # Remove z input + observer.remove_inputs(["z"]) + + ds_after = observer.infer_dynamic_shapes() + self.assertNotIn("z", ds_after) + self.assertIn("x", ds_after) + self.assertIn("y", ds_after) + self.assertEqual(dict(x={0: cst, 1: cst}, y={1: cst}), ds_after) + + args_after = observer.infer_arguments() + self.assertIsInstance(args_after, dict) + self.assertNotIn("z", args_after) + self.assertIn("x", args_after) + self.assertIn("y", args_after) + + def test_remove_inputs_multiple_kwargs(self): + """Test that remove_inputs removes multiple kwargs at once.""" + + class Model(torch.nn.Module): + def forward(self, x, y, z=None, w=None): + r = x + y + if z is not None: + r += z + if w is not None: + r += w + return r + + inputs = [ + dict( + x=torch.randn((5, 6)), + y=torch.randn((1, 6)), + z=torch.randn((5, 6)), + w=torch.randn((1, 6)), + ), + dict( + x=torch.randn((6, 7)), + y=torch.randn((1, 7)), + z=torch.randn((6, 7)), + w=torch.randn((1, 7)), + ), + dict( + x=torch.randn((7, 8)), + y=torch.randn((1, 8)), + z=torch.randn((7, 8)), + w=torch.randn((1, 8)), + ), + ] + + model = Model() + observer = InputObserver() + with observer(model): + for kwargs in inputs: + model(**kwargs) + self.assertEqual(len(observer.info), 3) + + cst = torch.export.Dim.DYNAMIC + ds = observer.infer_dynamic_shapes() + self.assertIn("z", ds) + self.assertIn("w", ds) + + # Remove z and w inputs + observer.remove_inputs(["z", "w"]) + + ds_after = observer.infer_dynamic_shapes() + self.assertNotIn("z", ds_after) + self.assertNotIn("w", ds_after) + self.assertIn("x", ds_after) + self.assertIn("y", ds_after) + self.assertEqual(dict(x={0: cst, 1: cst}, y={1: cst}), ds_after) + + args_after = observer.infer_arguments() + self.assertIsInstance(args_after, dict) + self.assertNotIn("z", args_after) + self.assertNotIn("w", args_after) + self.assertIn("x", args_after) + self.assertIn("y", args_after) + + if __name__ == "__main__": unittest.main(verbosity=2) From 9ea2e5f2198695a32fc15ad8feec087117f5f306 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 23 Feb 2026 16:18:43 +0000 Subject: [PATCH 05/10] a few fixes --- .../examples/plot_export_with_modelbuilder.py | 81 +++++++++++-------- .../ut_investigate/test_input_observer.py | 1 - onnx_diagnostic/export/api.py | 7 +- 3 files changed, 50 insertions(+), 39 deletions(-) diff --git a/_doc/examples/plot_export_with_modelbuilder.py b/_doc/examples/plot_export_with_modelbuilder.py index 5d877628..be96bd66 100644 --- a/_doc/examples/plot_export_with_modelbuilder.py +++ b/_doc/examples/plot_export_with_modelbuilder.py @@ -6,9 +6,11 @@ """ +import sys import os import pandas -from transformers import AutoModelForCausalLM, AutoTokenizer +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig from onnx_diagnostic import doc from onnx_diagnostic.investigate.input_observer import InputObserver from onnx_diagnostic.helpers.rt_helper import onnx_generate @@ -28,10 +30,11 @@ def generate_text( top_k=50, top_p=0.95, do_sample=True, + device="cpu", ): inputs = tokenizer(prompt, return_tensors="pt") - input_ids = inputs["input_ids"] - attention_mask = inputs["attention_mask"] + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) outputs = model.generate( input_ids=input_ids, @@ -47,58 +50,71 @@ def generate_text( return generated_text +# %% +# filename for the model +MODEL_NAME = sys.argv[1] if sys.argv and len(sys.argv) > 1 else "arnir0/Tiny-LLM" +cache_dir = "dump_modelbuilder" +os.makedirs(cache_dir, exist_ok=True) +name = MODEL_NAME.replace("/", "_") +filename = os.path.join(cache_dir, f"plot_export_with_modelbuilder_{name}.onnx") + + # %% # Creating the model -print("-- creating...") -MODEL_NAME = "arnir0/Tiny-LLM" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) -model = AutoModelForCausalLM.from_pretrained(MODEL_NAME) +if not os.path.exists(filename): + print(f"-- creating... on {device} into {filename!r}") + model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16) + model = model.to(device) + config = model.config +else: + config = AutoConfig.from_pretrained(MODEL_NAME) # %% # Capturing inputs/outputs to infer dynamic shapes and arguments print("-- capturing...") prompt = "Continue: it rains, what should I do?" -observer = InputObserver() -with register_additional_serialization_functions(patch_transformers=True), observer(model): - generate_text(prompt, model, tokenizer) +if not os.path.exists(filename): + observer = InputObserver() + with register_additional_serialization_functions(patch_transformers=True), observer(model): + generate_text(prompt, model, tokenizer, device=device) # %% # Exporting. -print("-- exporting...") -observer.remove_inputs(["cache_position", "logits_to_keep", "position_ids"]) -ds = observer.infer_dynamic_shapes(set_batch_dimension_for=True) -kwargs = observer.infer_arguments() - -cache_dir = "dump_modelbuilder" -os.makedirs(cache_dir, exist_ok=True) -filename = os.path.join(cache_dir, "plot_export_with_modelbuilder.onnx") -with torch_export_patches(patch_transformers=True): - to_onnx( - model, - filename=filename, - kwargs=kwargs, - dynamic_shapes=ds, - exporter="modelbuilder", - ) - -data = observer.check_discrepancies(filename, progress_bar=True) -print(pandas.DataFrame(data)) +if not os.path.exists(filename): + print("-- exporting...") + observer.remove_inputs(["cache_position", "logits_to_keep", "position_ids"]) + ds = observer.infer_dynamic_shapes(set_batch_dimension_for=True) + kwargs = observer.infer_arguments() + + with torch_export_patches(patch_transformers=True): + to_onnx( + model, + filename=filename, + kwargs=kwargs, + dynamic_shapes=ds, + exporter="modelbuilder", + ) + + data = observer.check_discrepancies(filename, progress_bar=True) + print(pandas.DataFrame(data)) # %% # ONNX Prompt # +++++++++++ print("-- ONNX prompts...") inputs = tokenizer(prompt, return_tensors="pt") -input_ids = inputs["input_ids"] -attention_mask = inputs["attention_mask"] +input_ids = inputs["input_ids"].to(device) +attention_mask = inputs["attention_mask"].to(device) onnx_tokens = onnx_generate( filename, input_ids=input_ids, attention_mask=attention_mask, - eos_token_id=model.config.eos_token_id, + eos_token_id=config.eos_token_id, max_new_tokens=50, ) onnx_generated_text = tokenizer.decode(onnx_tokens, skip_special_tokens=True) @@ -108,4 +124,5 @@ def generate_text( print("-----------------") # %% -doc.save_fig(doc.plot_dot(filename), f"{filename}.png", dpi=400) +if os.stat(filename).st_size < 2**14: + doc.save_fig(doc.plot_dot(filename), f"{filename}.png", dpi=400) diff --git a/_unittests/ut_investigate/test_input_observer.py b/_unittests/ut_investigate/test_input_observer.py index 2f415d2e..cbd90623 100644 --- a/_unittests/ut_investigate/test_input_observer.py +++ b/_unittests/ut_investigate/test_input_observer.py @@ -1196,7 +1196,6 @@ def forward(self, a, *args, **kwargs): ) torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=ds) - def test_remove_inputs_kwargs(self): """Test that remove_inputs removes a kwarg from the observer info.""" diff --git a/onnx_diagnostic/export/api.py b/onnx_diagnostic/export/api.py index 58906979..a6b5816b 100644 --- a/onnx_diagnostic/export/api.py +++ b/onnx_diagnostic/export/api.py @@ -320,12 +320,7 @@ def to_onnx( temp_filename = os.path.join(os.path.dirname(filename), "model.onnx") # renaming onx = onnx.load(temp_filename, load_external_data=True) - onnx.save( - onx, - filename, - save_as_external_data=True, - location=f"{os.path.splitext(filename[0])}.data", - ) + onnx.save(onx, filename, save_as_external_data=True) return onx raise ValueError(f"Unknown exporter={exporter!r}") From e7a83608f565d8fdb08f958364f7d2145b067aca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 23 Feb 2026 22:34:23 +0100 Subject: [PATCH 06/10] Update onnx_diagnostic/investigate/input_observer.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnx_diagnostic/investigate/input_observer.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/onnx_diagnostic/investigate/input_observer.py b/onnx_diagnostic/investigate/input_observer.py index 2cc4d9c1..5262ca42 100644 --- a/onnx_diagnostic/investigate/input_observer.py +++ b/onnx_diagnostic/investigate/input_observer.py @@ -147,16 +147,22 @@ def __init__( def remove_inputs(self, input_names: Sequence[str | int]): """Removes inputs.""" + # Work on a mutable copy of positional arguments. + args_list = list(self.args) + for name_or_pos in sorted(input_names, reverse=True): if isinstance(name_or_pos, int): - if name_or_pos in self.args: - del self.args[name_or_pos] + idx = name_or_pos + if 0 <= idx < len(args_list): + del args_list[idx] else: if name_or_pos in self.kwargs: del self.kwargs[name_or_pos] elif name_or_pos in self.cst_kwargs: del self.cst_kwargs[name_or_pos] + # Update stored positional arguments. + self.args = tuple(args_list) # remove any temporary structures self.flat_list, self.spec = torch.utils._pytree.tree_flatten((self.args, self.kwargs)) self._position_to_args_kwargs: list[int | str] | None = None From 25ccb7cf4fd1388a3da923568f071b0c178b7180 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 23 Feb 2026 23:31:25 +0100 Subject: [PATCH 07/10] fix --- _doc/examples/plot_export_with_modelbuilder.py | 2 +- _doc/technical/plot_generate.py | 2 +- _unittests/ut_tasks/try_tasks.py | 10 +++++----- _unittests/ut_xrun_doc/test_documentation_examples.py | 3 +++ onnx_diagnostic/ci_models/export_phi4_mm.py | 2 +- onnx_diagnostic/investigate/input_observer.py | 10 ++++++++-- 6 files changed, 19 insertions(+), 10 deletions(-) diff --git a/_doc/examples/plot_export_with_modelbuilder.py b/_doc/examples/plot_export_with_modelbuilder.py index be96bd66..0ff4ea39 100644 --- a/_doc/examples/plot_export_with_modelbuilder.py +++ b/_doc/examples/plot_export_with_modelbuilder.py @@ -65,7 +65,7 @@ def generate_text( tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) if not os.path.exists(filename): print(f"-- creating... on {device} into {filename!r}") - model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16) + model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16) model = model.to(device) config = model.config else: diff --git a/_doc/technical/plot_generate.py b/_doc/technical/plot_generate.py index c01547d7..c40512ac 100644 --- a/_doc/technical/plot_generate.py +++ b/_doc/technical/plot_generate.py @@ -47,7 +47,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id) else: model_id = "microsoft/phi-1_5" - model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto") + model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto") tokenizer = AutoTokenizer.from_pretrained(model_id) config = get_pretrained_config(model_id) task = task = task_from_id(model_id) diff --git a/_unittests/ut_tasks/try_tasks.py b/_unittests/ut_tasks/try_tasks.py index 68cd9d77..e47aeed0 100644 --- a/_unittests/ut_tasks/try_tasks.py +++ b/_unittests/ut_tasks/try_tasks.py @@ -263,7 +263,7 @@ def test_text_generation_phi4_moe(self): model = AutoModelForCausalLM.from_pretrained( model_path, device_map="cuda", - torch_dtype="auto", + dtype="auto", trust_remote_code=True, # if you do not use Ampere or later GPUs, change attention to "eager" # _attn_implementation='flash_attention_2', @@ -352,7 +352,7 @@ def test_imagetext2text_generation_idefics(self): mid = "HuggingFaceM4/tiny-random-idefics" processor = AutoProcessor.from_pretrained(mid) model = IdeficsForVisionText2Text.from_pretrained( - mid, torch_dtype=torch.bfloat16, device_map="auto" + mid, dtype=torch.bfloat16, device_map="auto" ) prompt = [ @@ -699,7 +699,7 @@ def test_falcon_mamba_dev(self): "text-generation", model=model, tokenizer=tokenizer, - torch_dtype=torch.bfloat16, + dtype=torch.bfloat16, trust_remote_code=True, device_map="auto", ) @@ -736,7 +736,7 @@ def test_falcon_mamba_7b(self): "text-generation", model=model, tokenizer=tokenizer, - torch_dtype=torch.bfloat16, + dtype=torch.bfloat16, trust_remote_code=True, device_map="auto", ) @@ -802,7 +802,7 @@ def test_text_to_image(self): from diffusers import StableDiffusionPipeline model_id = "diffusers/tiny-torch-full-checker" # "stabilityai/stable-diffusion-2" - pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to( + pipe = StableDiffusionPipeline.from_pretrained(model_id, dtype=torch.float16).to( "cuda" ) diff --git a/_unittests/ut_xrun_doc/test_documentation_examples.py b/_unittests/ut_xrun_doc/test_documentation_examples.py index 7817472a..033c1705 100644 --- a/_unittests/ut_xrun_doc/test_documentation_examples.py +++ b/_unittests/ut_xrun_doc/test_documentation_examples.py @@ -85,6 +85,9 @@ def add_test_methods(cls): # transformers + if not reason and not has_dot and name in {"plot_export_with_modelbuilder.py"}: + reason = "downloading" + if ( not reason and name in {"plot_export_tiny_llm.py", "plot_export_tiny_llm_patched.py"} diff --git a/onnx_diagnostic/ci_models/export_phi4_mm.py b/onnx_diagnostic/ci_models/export_phi4_mm.py index 4063e8af..462fe4b0 100644 --- a/onnx_diagnostic/ci_models/export_phi4_mm.py +++ b/onnx_diagnostic/ci_models/export_phi4_mm.py @@ -794,7 +794,7 @@ def main( model_id, config=config, trust_remote_code=True, - torch_dtype=torch_dtype, + dtype=torch_dtype, device_map=device, attn_implementation="sdpa", ).eval() diff --git a/onnx_diagnostic/investigate/input_observer.py b/onnx_diagnostic/investigate/input_observer.py index 5262ca42..3f17aa5d 100644 --- a/onnx_diagnostic/investigate/input_observer.py +++ b/onnx_diagnostic/investigate/input_observer.py @@ -851,7 +851,11 @@ def remove_inputs(self, input_names: Sequence[str | int]): self._best_candidate.remove_inputs(input_names) for name_or_pos in sorted(input_names, reverse=True): - if self.default_values and name_or_pos in self.default_values: + if ( + isinstance(name_or_pos, str) + and self.default_values + and name_or_pos in self.default_values + ): del self.default_values[name_or_pos] if self.value_if_missing and name_or_pos in self.value_if_missing: del self.value_if_missing[name_or_pos] @@ -862,7 +866,7 @@ def remove_inputs(self, input_names: Sequence[str | int]): not self.args_name_and_position ), f"Not implemented when {self.args_name_and_position=}" input_names_str = { - (self.signature_name[i] if isinstance(i, int) else i) for i in input_names + (self.signature_names[i] if isinstance(i, int) else i) for i in input_names } self.signature_names = [ name for name in self.signature_names if name not in input_names_str @@ -1269,4 +1273,6 @@ def check_discrepancies( def remove_inputs(self, input_names: Sequence[str | int]): """Lets the users drops inputs.""" + if self.info is None: + raise RuntimeError("No input was captured.") self.info.remove_inputs(input_names) From e05311d301d8ad51206442e440cb8363acf660df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 23 Feb 2026 23:35:36 +0100 Subject: [PATCH 08/10] style --- onnx_diagnostic/investigate/input_observer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnx_diagnostic/investigate/input_observer.py b/onnx_diagnostic/investigate/input_observer.py index 3f17aa5d..94796d32 100644 --- a/onnx_diagnostic/investigate/input_observer.py +++ b/onnx_diagnostic/investigate/input_observer.py @@ -165,10 +165,10 @@ def remove_inputs(self, input_names: Sequence[str | int]): self.args = tuple(args_list) # remove any temporary structures self.flat_list, self.spec = torch.utils._pytree.tree_flatten((self.args, self.kwargs)) - self._position_to_args_kwargs: list[int | str] | None = None - self._n_tensors_for_args_kwargs: dict[int | str, int] | None = None - self.aligned_spec: torch.utils._pytree.PyTreeSpec | None = None - self.aligned_flat_list: list[torch.Tensor | None] | None = None + self._position_to_args_kwargs = None + self._n_tensors_for_args_kwargs = None + self.aligned_spec = None + self.aligned_flat_list = None def __str__(self) -> str: return ( From 30278d155a2dfa016035295eb089106a3638a6e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 24 Feb 2026 00:53:38 +0100 Subject: [PATCH 09/10] update doc --- _doc/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_doc/conf.py b/_doc/conf.py index 86aa5461..4ff33d80 100644 --- a/_doc/conf.py +++ b/_doc/conf.py @@ -212,7 +212,7 @@ def linkcode_resolve(domain, info): if int(os.environ.get("UNITTEST_GOING", "0")): sphinx_gallery_conf["ignore_pattern"] = ( - ".*((tiny_llm)|(dort)|(draft_mode)|(hub_codellama.py)|(whisper)|(optimind)).*" + ".*((tiny_llm)|(dort)|(draft_mode)|(hub_codellama.py)|(whisper)|(optimind)|(export_with_modelbuilder)).*" ) elif pv.Version(torch.__version__) < pv.Version("2.8"): sphinx_gallery_conf["ignore_pattern"] = ".*((_oe_)|(dort)|(draft_mode)).*" From 885ea7f414e5d10c5108ba423cf624d538b7b749 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 24 Feb 2026 11:22:43 +0100 Subject: [PATCH 10/10] fix documentation --- _unittests/ut_xrun_doc/test_documentation_examples.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_unittests/ut_xrun_doc/test_documentation_examples.py b/_unittests/ut_xrun_doc/test_documentation_examples.py index 033c1705..c5f0f1df 100644 --- a/_unittests/ut_xrun_doc/test_documentation_examples.py +++ b/_unittests/ut_xrun_doc/test_documentation_examples.py @@ -85,7 +85,7 @@ def add_test_methods(cls): # transformers - if not reason and not has_dot and name in {"plot_export_with_modelbuilder.py"}: + if not reason and name in {"plot_export_with_modelbuilder.py"}: reason = "downloading" if (