diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 37506809..154b4fd1 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,9 @@ Change Logs 0.9.3 +++++ +* :pr:`422`: add remove_inputs to InputObserver +* :pr:`421`: fix a few patches for MoE + 0.9.2 +++++ diff --git a/_doc/conf.py b/_doc/conf.py index 86aa5461..4ff33d80 100644 --- a/_doc/conf.py +++ b/_doc/conf.py @@ -212,7 +212,7 @@ def linkcode_resolve(domain, info): if int(os.environ.get("UNITTEST_GOING", "0")): sphinx_gallery_conf["ignore_pattern"] = ( - ".*((tiny_llm)|(dort)|(draft_mode)|(hub_codellama.py)|(whisper)|(optimind)).*" + ".*((tiny_llm)|(dort)|(draft_mode)|(hub_codellama.py)|(whisper)|(optimind)|(export_with_modelbuilder)).*" ) elif pv.Version(torch.__version__) < pv.Version("2.8"): sphinx_gallery_conf["ignore_pattern"] = ".*((_oe_)|(dort)|(draft_mode)).*" diff --git a/_doc/examples/plot_export_with_modelbuilder.py b/_doc/examples/plot_export_with_modelbuilder.py new file mode 100644 index 00000000..0ff4ea39 --- /dev/null +++ b/_doc/examples/plot_export_with_modelbuilder.py @@ -0,0 +1,128 @@ +""" +.. 
_l-plot-export-model-builder: + +Export with ModelBuilder +======================== + +""" + +import sys +import os +import pandas +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig +from onnx_diagnostic import doc +from onnx_diagnostic.investigate.input_observer import InputObserver +from onnx_diagnostic.helpers.rt_helper import onnx_generate +from onnx_diagnostic.torch_export_patches import ( + register_additional_serialization_functions, + torch_export_patches, +) +from onnx_diagnostic.export.api import to_onnx + + +def generate_text( + prompt, + model, + tokenizer, + max_length=50, + temperature=0.01, + top_k=50, + top_p=0.95, + do_sample=True, + device="cpu", +): + inputs = tokenizer(prompt, return_tensors="pt") + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) + + outputs = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + temperature=temperature, + top_k=top_k, + top_p=top_p, + do_sample=do_sample, + ) + + generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + return generated_text + + +# %% +# filename for the model +MODEL_NAME = sys.argv[1] if sys.argv and len(sys.argv) > 1 else "arnir0/Tiny-LLM" +cache_dir = "dump_modelbuilder" +os.makedirs(cache_dir, exist_ok=True) +name = MODEL_NAME.replace("/", "_") +filename = os.path.join(cache_dir, f"plot_export_with_modelbuilder_{name}.onnx") + + +# %% +# Creating the model +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +if not os.path.exists(filename): + print(f"-- creating... 
on {device} into {filename!r}")
+    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16)
+    model = model.to(device)
+    config = model.config
+else:
+    config = AutoConfig.from_pretrained(MODEL_NAME)
+
+
+# %%
+# Capturing inputs/outputs to infer dynamic shapes and arguments
+print("-- capturing...")
+prompt = "Continue: it rains, what should I do?"
+if not os.path.exists(filename):
+    observer = InputObserver()
+    with register_additional_serialization_functions(patch_transformers=True), observer(model):
+        generate_text(prompt, model, tokenizer, device=device)
+
+
+# %%
+# Exporting.
+if not os.path.exists(filename):
+    print("-- exporting...")
+    observer.remove_inputs(["cache_position", "logits_to_keep", "position_ids"])
+    ds = observer.infer_dynamic_shapes(set_batch_dimension_for=True)
+    kwargs = observer.infer_arguments()
+
+    with torch_export_patches(patch_transformers=True):
+        to_onnx(
+            model,
+            filename=filename,
+            kwargs=kwargs,
+            dynamic_shapes=ds,
+            exporter="modelbuilder",
+        )
+
+    data = observer.check_discrepancies(filename, progress_bar=True)
+    print(pandas.DataFrame(data))
+
+# %%
+# ONNX Prompt
+# +++++++++++
+print("-- ONNX prompts...")
+inputs = tokenizer(prompt, return_tensors="pt")
+input_ids = inputs["input_ids"].to(device)
+attention_mask = inputs["attention_mask"].to(device)
+
+onnx_tokens = onnx_generate(
+    filename,
+    input_ids=input_ids,
+    attention_mask=attention_mask,
+    eos_token_id=config.eos_token_id,
+    max_new_tokens=50,
+)
+onnx_generated_text = tokenizer.decode(onnx_tokens, skip_special_tokens=True)
+
+print("-----------------")
+print(onnx_generated_text)
+print("-----------------")
+
+# %%
+if os.stat(filename).st_size < 2**14:
+    doc.save_fig(doc.plot_dot(filename), f"plot_export_with_modelbuilder_{name}.png", dpi=400)
diff --git a/_doc/technical/plot_generate.py b/_doc/technical/plot_generate.py
index c01547d7..c40512ac 100644
--- a/_doc/technical/plot_generate.py
+++ b/_doc/technical/plot_generate.py
@@ -47,7 +47,7 @@
tokenizer = AutoTokenizer.from_pretrained(model_id) else: model_id = "microsoft/phi-1_5" - model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto") + model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto") tokenizer = AutoTokenizer.from_pretrained(model_id) config = get_pretrained_config(model_id) task = task = task_from_id(model_id) diff --git a/_unittests/ut_investigate/test_input_observer.py b/_unittests/ut_investigate/test_input_observer.py index 056c4765..cbd90623 100644 --- a/_unittests/ut_investigate/test_input_observer.py +++ b/_unittests/ut_investigate/test_input_observer.py @@ -1196,6 +1196,112 @@ def forward(self, a, *args, **kwargs): ) torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=ds) + def test_remove_inputs_kwargs(self): + """Test that remove_inputs removes a kwarg from the observer info.""" + + class Model(torch.nn.Module): + def forward(self, x, y, z=None): + r = x + y + if z is not None: + r += z + return r + + inputs = [ + dict(x=torch.randn((5, 6)), y=torch.randn((1, 6)), z=torch.randn((5, 6))), + dict(x=torch.randn((7, 7)), y=torch.randn((1, 7)), z=torch.randn((7, 7))), + dict(x=torch.randn((7, 8)), y=torch.randn((1, 8)), z=torch.randn((7, 8))), + ] + + model = Model() + observer = InputObserver() + with observer(model): + for kwargs in inputs: + model(**kwargs) + self.assertEqual(len(observer.info), 3) + + cst = torch.export.Dim.DYNAMIC + ds = observer.infer_dynamic_shapes() + self.assertIn("z", ds) + self.assertIn("x", ds) + self.assertIn("y", ds) + + # Remove z input + observer.remove_inputs(["z"]) + + ds_after = observer.infer_dynamic_shapes() + self.assertNotIn("z", ds_after) + self.assertIn("x", ds_after) + self.assertIn("y", ds_after) + self.assertEqual(dict(x={0: cst, 1: cst}, y={1: cst}), ds_after) + + args_after = observer.infer_arguments() + self.assertIsInstance(args_after, dict) + self.assertNotIn("z", args_after) + self.assertIn("x", args_after) + self.assertIn("y", args_after) + + 
def test_remove_inputs_multiple_kwargs(self): + """Test that remove_inputs removes multiple kwargs at once.""" + + class Model(torch.nn.Module): + def forward(self, x, y, z=None, w=None): + r = x + y + if z is not None: + r += z + if w is not None: + r += w + return r + + inputs = [ + dict( + x=torch.randn((5, 6)), + y=torch.randn((1, 6)), + z=torch.randn((5, 6)), + w=torch.randn((1, 6)), + ), + dict( + x=torch.randn((6, 7)), + y=torch.randn((1, 7)), + z=torch.randn((6, 7)), + w=torch.randn((1, 7)), + ), + dict( + x=torch.randn((7, 8)), + y=torch.randn((1, 8)), + z=torch.randn((7, 8)), + w=torch.randn((1, 8)), + ), + ] + + model = Model() + observer = InputObserver() + with observer(model): + for kwargs in inputs: + model(**kwargs) + self.assertEqual(len(observer.info), 3) + + cst = torch.export.Dim.DYNAMIC + ds = observer.infer_dynamic_shapes() + self.assertIn("z", ds) + self.assertIn("w", ds) + + # Remove z and w inputs + observer.remove_inputs(["z", "w"]) + + ds_after = observer.infer_dynamic_shapes() + self.assertNotIn("z", ds_after) + self.assertNotIn("w", ds_after) + self.assertIn("x", ds_after) + self.assertIn("y", ds_after) + self.assertEqual(dict(x={0: cst, 1: cst}, y={1: cst}), ds_after) + + args_after = observer.infer_arguments() + self.assertIsInstance(args_after, dict) + self.assertNotIn("z", args_after) + self.assertNotIn("w", args_after) + self.assertIn("x", args_after) + self.assertIn("y", args_after) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_tasks/try_tasks.py b/_unittests/ut_tasks/try_tasks.py index 68cd9d77..e47aeed0 100644 --- a/_unittests/ut_tasks/try_tasks.py +++ b/_unittests/ut_tasks/try_tasks.py @@ -263,7 +263,7 @@ def test_text_generation_phi4_moe(self): model = AutoModelForCausalLM.from_pretrained( model_path, device_map="cuda", - torch_dtype="auto", + dtype="auto", trust_remote_code=True, # if you do not use Ampere or later GPUs, change attention to "eager" # 
_attn_implementation='flash_attention_2', @@ -352,7 +352,7 @@ def test_imagetext2text_generation_idefics(self): mid = "HuggingFaceM4/tiny-random-idefics" processor = AutoProcessor.from_pretrained(mid) model = IdeficsForVisionText2Text.from_pretrained( - mid, torch_dtype=torch.bfloat16, device_map="auto" + mid, dtype=torch.bfloat16, device_map="auto" ) prompt = [ @@ -699,7 +699,7 @@ def test_falcon_mamba_dev(self): "text-generation", model=model, tokenizer=tokenizer, - torch_dtype=torch.bfloat16, + dtype=torch.bfloat16, trust_remote_code=True, device_map="auto", ) @@ -736,7 +736,7 @@ def test_falcon_mamba_7b(self): "text-generation", model=model, tokenizer=tokenizer, - torch_dtype=torch.bfloat16, + dtype=torch.bfloat16, trust_remote_code=True, device_map="auto", ) @@ -802,7 +802,7 @@ def test_text_to_image(self): from diffusers import StableDiffusionPipeline model_id = "diffusers/tiny-torch-full-checker" # "stabilityai/stable-diffusion-2" - pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to( + pipe = StableDiffusionPipeline.from_pretrained(model_id, dtype=torch.float16).to( "cuda" ) diff --git a/_unittests/ut_xrun_doc/test_documentation_examples.py b/_unittests/ut_xrun_doc/test_documentation_examples.py index 7817472a..c5f0f1df 100644 --- a/_unittests/ut_xrun_doc/test_documentation_examples.py +++ b/_unittests/ut_xrun_doc/test_documentation_examples.py @@ -85,6 +85,9 @@ def add_test_methods(cls): # transformers + if not reason and name in {"plot_export_with_modelbuilder.py"}: + reason = "downloading" + if ( not reason and name in {"plot_export_tiny_llm.py", "plot_export_tiny_llm_patched.py"} diff --git a/onnx_diagnostic/ci_models/export_phi4_mm.py b/onnx_diagnostic/ci_models/export_phi4_mm.py index 4063e8af..462fe4b0 100644 --- a/onnx_diagnostic/ci_models/export_phi4_mm.py +++ b/onnx_diagnostic/ci_models/export_phi4_mm.py @@ -794,7 +794,7 @@ def main( model_id, config=config, trust_remote_code=True, - torch_dtype=torch_dtype, + 
dtype=torch_dtype,
             device_map=device,
             attn_implementation="sdpa",
         ).eval()
diff --git a/onnx_diagnostic/export/api.py b/onnx_diagnostic/export/api.py
index 03898e78..a6b5816b 100644
--- a/onnx_diagnostic/export/api.py
+++ b/onnx_diagnostic/export/api.py
@@ -4,6 +4,7 @@
 import time
 from collections.abc import Mapping, Iterable
 from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
+import onnx
 import torch
 from .dynamic_shapes import ModelInputs
 from .onnx_plug import EagerDirectReplacementWithOnnx
@@ -312,10 +313,14 @@ def to_onnx(
             mod,
             precision=str(first_float[0].dtype).split(".")[-1],
             execution_provider="cuda" if first.is_cuda else "cpu",
-            cache_dir=os.path.dirname(filename),
+            cache_dir=os.path.dirname(filename) or ".",
             **(exporter_kwargs or {}),
         )
-        save_model_builder(onx, os.path.dirname(filename))
+        save_model_builder(onx, os.path.dirname(filename) or ".")
+        temp_filename = os.path.join(os.path.dirname(filename) or ".", "model.onnx")
+        # renaming
+        onx = onnx.load(temp_filename, load_external_data=True)
+        onnx.save(onx, filename, save_as_external_data=True)
         return onx
     raise ValueError(f"Unknown exporter={exporter!r}")
diff --git a/onnx_diagnostic/investigate/input_observer.py b/onnx_diagnostic/investigate/input_observer.py
index 5004f466..94796d32 100644
--- a/onnx_diagnostic/investigate/input_observer.py
+++ b/onnx_diagnostic/investigate/input_observer.py
@@ -145,6 +145,31 @@ def __init__(
         self.aligned_spec: torch.utils._pytree.PyTreeSpec | None = None
         self.aligned_flat_list: list[torch.Tensor | None] | None = None
 
+    def remove_inputs(self, input_names: Sequence[str | int]):
+        """Removes inputs."""
+        # Work on a mutable copy of positional arguments.
+        args_list = list(self.args)
+
+        # Sort with int positions first in descending numeric order so earlier
+        # deletions do not shift later indices; a plain sorted() would raise
+        # TypeError on a mix of str and int names.
+        for name_or_pos in sorted(input_names, key=lambda i: i if isinstance(i, int) else -1, reverse=True):
+            if isinstance(name_or_pos, int):
+                idx = name_or_pos
+                if 0 <= idx < len(args_list):
+                    del args_list[idx]
+            else:
+                if name_or_pos in self.kwargs:
+                    del self.kwargs[name_or_pos]
+                elif name_or_pos in self.cst_kwargs:
+                    del self.cst_kwargs[name_or_pos]
+
+        # Update stored positional arguments.
+        self.args = tuple(args_list)
+        # remove any temporary structures
+        self.flat_list, self.spec = torch.utils._pytree.tree_flatten((self.args, self.kwargs))
+        self._position_to_args_kwargs = None
+        self._n_tensors_for_args_kwargs = None
+        self.aligned_spec = None
+        self.aligned_flat_list = None
+
     def __str__(self) -> str:
         return (
             f"{self.__class__.__name__}({len(self.args)} args, "
@@ -811,6 +836,42 @@ def _post_process_for_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
             )
             return {**new_kwargs, self.kwargs_name: keywords}
 
+    def remove_inputs(self, input_names: Sequence[str | int]):
+        """Lets the user drop inputs."""
+        if (
+            self.args_name_and_position is not None
+            and self.args_name_and_position in input_names
+        ):
+            raise ValueError(f"Cannot remove variadic {self.args_name_and_position}")
+        if self.kwargs_name is not None and self.kwargs_name in input_names:
+            raise ValueError(f"Cannot remove variadic {self.kwargs_name}")
+        for candidate in self.inputs:
+            candidate.remove_inputs(input_names)
+        if self._best_candidate:
+            self._best_candidate.remove_inputs(input_names)
+
+        # Same ordering key as CapturedInput.remove_inputs: avoids TypeError on
+        # mixed str/int names and keeps int positions in descending order.
+        for name_or_pos in sorted(input_names, key=lambda i: i if isinstance(i, int) else -1, reverse=True):
+            if (
+                isinstance(name_or_pos, str)
+                and self.default_values
+                and name_or_pos in self.default_values
+            ):
+                del self.default_values[name_or_pos]
+            if self.value_if_missing and name_or_pos in self.value_if_missing:
+                del self.value_if_missing[name_or_pos]
+            if self._captured_inputs and name_or_pos in self._captured_inputs:
+                del self._captured_inputs[name_or_pos]
+
+        assert (
+            not self.args_name_and_position
+        ), f"Not implemented when 
{self.args_name_and_position=}"
+        input_names_str = {
+            (self.signature_names[i] if isinstance(i, int) and 0 <= i < len(self.signature_names) else i) for i in input_names
+        }
+        self.signature_names = [
+            name for name in self.signature_names if name not in input_names_str
+        ]
+
 
 class InputObserver:
     """Steals forward method to collect inputs and outputs.
@@ -1209,3 +1270,9 @@ def check_discrepancies(
         diff["outputs_ort"] = string_type(ort_outputs, with_shape=True)
         data.append(diff)
         return data
+
+    def remove_inputs(self, input_names: Sequence[str | int]):
+        """Lets the user drop inputs."""
+        if self.info is None:
+            raise RuntimeError("No input was captured.")
+        self.info.remove_inputs(input_names)