Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.9.2
+++++

* :pr:`415`: improves function make_model_with_local_functions to support ill-defined partitions
* :pr:`413`: fix InputObserver in the generic case
* :pr:`412`: patches for ViTModel (through rewriting)

Expand Down
10 changes: 9 additions & 1 deletion _unittests/ut_helpers/test_args_helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import unittest
from onnx_diagnostic.ext_test_case import ExtTestCase
from onnx_diagnostic.helpers.args_helper import get_parsed_args, check_cuda_availability
from onnx_diagnostic.helpers.args_helper import (
get_parsed_args,
check_cuda_availability,
process_outputname,
)


class TestHelpers(ExtTestCase):
Expand Down Expand Up @@ -52,6 +56,10 @@ def test_args_expose(self):
self.assertEqual(args.repeat, 10)
self.assertEqual(args.warmup, 5)

def test_process_outputname(self):
    # Checks the two behaviours of process_outputname(name, reference):
    # - an explicit name ("ggg.g") is returned unchanged,
    # - a name starting with "+" is expanded against the second argument:
    #   "+.ggg" with "hhh.h" -> "hhh.ggg.h".
    # NOTE(review): the expansion rule is inferred from these two cases only —
    # confirm against the implementation in onnx_diagnostic.helpers.args_helper.
    self.assertEqual("ggg.g", process_outputname("ggg.g", "hhh.h"))
    self.assertEqual("hhh.ggg.h", process_outputname("+.ggg", "hhh.h"))


# Allow running this test file directly (outside a test runner).
if __name__ == "__main__":
    unittest.main(verbosity=2)
71 changes: 67 additions & 4 deletions _unittests/ut_helpers/test_onnx_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,11 +794,8 @@ def test_make_model_with_local_functions_bug(self):
meta.key = "namespace"
meta.value = "LLL"
self.assertRaise(
lambda: make_model_with_local_functions(model, "^LLL$"),
lambda: make_model_with_local_functions(model, "^LLL$", allow_extensions=False),
ValueError,
"Results {'xu1'} are needed for inputs ['X', 'Y', 'shape1', "
"'shape2', 'xu2', 'zero'] but also requires ['xm1', 'xm2', 'xu1'] "
"which is not allowed.",
)
check_model(model)

Expand Down Expand Up @@ -860,6 +857,72 @@ def test_make_model_with_local_functions_2(self):

check_model(new_model)

@hide_stdout()
def test_make_model_with_local_functions_3(self):
    # Builds a 7-node model, tags a non-contiguous subset of its nodes with
    # the same metadata value "LLL", then checks that
    # make_model_with_local_functions:
    # - raises ValueError when the ill-defined partition is not allowed
    #   (allow_extensions=False),
    # - otherwise extends the partition into one valid local function "LLL".
    model = oh.make_model(
        oh.make_graph(
            [
                oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
                oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
                oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
                oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
                oh.make_node("Cast", ["xm2c"], ["xm2"], to=1),
                oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]),
                oh.make_node("Reshape", ["xm", "shape3"], ["Z"]),
            ],
            "dummy",
            [oh.make_tensor_value_info("X", TFLOAT, [320, 1280])],
            [oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 320, 640])],
            [
                onh.from_array(
                    np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y"
                ),
                onh.from_array(np.array([0], dtype=np.int64), name="zero"),
                onh.from_array(np.array([1], dtype=np.int64), name="un"),
                onh.from_array(np.array([1, 320, 1280], dtype=np.int64), name="shape1"),
                onh.from_array(np.array([15, 1280, 640], dtype=np.int64), name="shape2"),
                onh.from_array(np.array([3, 5, 320, 640], dtype=np.int64), name="shape3"),
            ],
        ),
        opset_imports=[oh.make_opsetid("", 18)],
        ir_version=9,
    )
    check_model(model)
    # Tag every node except node 2 (the first Reshape) and the last node
    # (the final Reshape producing Z) with metadata matching "^LLL$".
    # Skipping node 2 makes the partition ill-defined (a hole in the chain).
    for i_node in range(len(model.graph.node) - 1):
        if i_node == 2:
            continue
        node = model.graph.node[i_node]
        meta = node.metadata_props.add()
        # Keys are distinct ("source[0]", "source[1]", ...) but all selected
        # through metadata_key_prefix="source[" below.
        meta.key = f"source[{i_node}]"
        meta.value = "LLL"
    # Ill-defined partition + allow_extensions=False must fail.
    self.assertRaise(
        lambda: make_model_with_local_functions(
            model,
            "^LLL$",
            metadata_key_prefix="source[",
            verbose=1,
            allow_extensions=False,
        ),
        ValueError,
    )
    # With extensions allowed (default), the partition is completed into one
    # local function.
    new_model = make_model_with_local_functions(
        model, "^LLL$", metadata_key_prefix="source[", verbose=1
    )
    check_model(new_model)
    self.assertEqual(len(new_model.functions), 1)
    p = pretty_onnx(new_model)
    self.assertIn("LLL[local_function]", p)

    # The extracted function absorbs 6 of the 7 nodes (all but the final
    # Reshape) and therefore needs the original graph inputs/initializers.
    self.assertEqual(
        ["X", "Y", "shape1", "shape2", "un", "zero"], new_model.functions[0].input
    )
    self.assertEqual(["xm"], new_model.functions[0].output)
    self.assertEqual("LLL", new_model.functions[0].name)
    self.assertEqual("local_function", new_model.functions[0].domain)
    self.assertEqual(len(new_model.functions[0].node), 6)

    check_model(new_model)


# Allow running this test file directly (outside a test runner).
if __name__ == "__main__":
    unittest.main(verbosity=2)
183 changes: 183 additions & 0 deletions _unittests/ut_torch_export_patches/test_patch_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
requires_torch,
ignore_warnings,
has_onnxscript,
has_transformers,
requires_onnxscript,
)
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy, fake_torchdynamo_exporting
from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
from onnx_diagnostic.torch_models.hghub.hub_api import get_cached_configuration
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
patch_qwen2_5,
patch_funnel,
Expand Down Expand Up @@ -392,6 +394,20 @@ def forward(self, q, k, cos, sin):
rtol=1,
)

@requires_transformers("4.55")
@requires_onnxscript("0.6.2")
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
def test_qwen_function_proto(self):
    """Checks that each Qwen2.5 attention rewriting can be converted
    into an ONNX FunctionProto without raising."""
    from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
        LoopAttention23,
        LoopMHAAttention,
        PackedAttention,
    )

    # Same conversion order as before: MHA loop, loop 2/3, then packed.
    for attention_cls in (LoopMHAAttention, LoopAttention23, PackedAttention):
        attention_cls.to_function_proto()

@requires_transformers("4.55")
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
def test_patched_qwen2_5_vl_rot_pos_emb(self):
Expand Down Expand Up @@ -874,6 +890,173 @@ def test_model_funnel(self):
got = patched.relative_positional_attention(**inputs)
self.assertEqualArray(expected, got)

def test_cache_dependant_input_preparation_exporting(self):
    # Verifies that the export-friendly variant
    # _cache_dependant_input_preparation_exporting returns exactly the same
    # tensors as the eager _cache_dependant_input_preparation for several
    # (input_ids, inputs_embeds, cache_position) configurations.
    from onnx_diagnostic.torch_export_patches.patches._patch_transformers_generation_mixin import ( # noqa: E501
        patched_GenerationMixin as GenerationMixin,
    )

    with self.subTest(case="case1"):
        # Zero-length input_ids (sliced with [:, :0]) while embeddings are
        # provided — the embeddings-only branch.
        input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)[:, :0]
        inputs_embeds = torch.rand((2, 8), dtype=torch.float32)
        cache_position = torch.arange(0, 8, dtype=torch.int64)
        eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(
            input_ids, inputs_embeds, cache_position
        )
        export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
            input_ids, inputs_embeds, cache_position
        )
        torch.testing.assert_close(eager1, export1)
        torch.testing.assert_close(eager2, export2)

    with self.subTest(case="case2"):
        # Deliberately skipped: raising SkipTest inside subTest records this
        # case as skipped. The code below is unreachable but kept verbatim so
        # the case can be re-enabled once the suspected torch 2.10+ bug is
        # resolved.
        raise unittest.SkipTest("torch 2.10+ has probably a bug here.")
        input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)
        inputs_embeds = torch.rand((2, 8), dtype=torch.float32)
        cache_position = torch.arange(0, 8, dtype=torch.int64)
        eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(
            input_ids, inputs_embeds, cache_position
        )
        export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
            input_ids, inputs_embeds, cache_position
        )
        torch.testing.assert_close(eager1, export1)
        torch.testing.assert_close(eager2, export2)

    with self.subTest(case="case3"):
        # No embeddings; input_ids longer (12) than cache_position (8) —
        # the branch that slices input_ids by the cache position.
        input_ids = torch.randint(0, 16, (2, 12), dtype=torch.int64)
        inputs_embeds = None
        cache_position = torch.arange(0, 8, dtype=torch.int64)
        eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(
            input_ids, inputs_embeds, cache_position
        )
        export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
            input_ids, inputs_embeds, cache_position
        )
        torch.testing.assert_close(eager1, export1)
        torch.testing.assert_close(eager2, export2)

    with self.subTest(case="case4"):
        # No embeddings; input_ids length matches cache_position length.
        input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)
        inputs_embeds = None
        cache_position = torch.arange(0, 8, dtype=torch.int64)
        eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(
            input_ids, inputs_embeds, cache_position
        )
        export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
            input_ids, inputs_embeds, cache_position
        )
        torch.testing.assert_close(eager1, export1)
        torch.testing.assert_close(eager2, export2)

@requires_transformers("4.57")
def test_prepare_inputs_for_generation_decoder_llm(self):
    # Exercises the patched prepare_inputs_for_generation on a tiny
    # decoder-only model while torch_export_patches(patch_transformers=True)
    # is active: no-cache calls, DynamicCache and StaticCache scenarios,
    # and the inputs_embeds interaction.
    data = get_untrained_model_with_inputs(
        "hf-internal-testing/tiny-random-LlamaForCausalLM"
    )
    model = data["model"]
    config = model.config
    torch_device = "cpu"

    with torch_export_patches(patch_transformers=True):
        with self.subTest(case="case1"):
            # The patched method must still resolve through GenerationMixin.
            self.assertTrue("GenerationMixin" in str(model.prepare_inputs_for_generation))

        input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]).to(torch_device)
        cache_position = torch.arange(input_ids.shape[1], device=input_ids.device)

        with self.subTest(case="case2"):
            # No cache: input_ids pass through untouched.
            input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]).to(torch_device)
            model_inputs = model.prepare_inputs_for_generation(
                input_ids, cache_position=cache_position
            )
            self.assertTrue(torch.all(model_inputs["input_ids"] == input_ids))

        with self.subTest(case="case3"):
            # attention_mask passes through; position_ids are derived from it
            # and match the input shape.
            attention_mask = torch.tensor([[1, 1, 1], [1, 1, 1]]).to(torch_device)
            model_inputs = model.prepare_inputs_for_generation(
                input_ids, attention_mask=attention_mask, cache_position=cache_position
            )
            self.assertTrue(torch.all(model_inputs["attention_mask"] == attention_mask))
            self.assertTrue(model_inputs["position_ids"].shape == input_ids.shape)

        with self.subTest(case="case4"):
            # model_inputs here is still case3's result: use_cache was not set
            # there.  Extra kwargs (use_cache, foo) must be forwarded verbatim.
            self.assertFalse("use_cache" in model_inputs)
            model_inputs = model.prepare_inputs_for_generation(
                input_ids, use_cache=True, foo="bar", cache_position=cache_position
            )
            self.assertTrue(model_inputs["use_cache"] is True)
            self.assertTrue(model_inputs["foo"] == "bar")

        # Pre-fill a DynamicCache with the first two tokens.
        init_input_ids = input_ids[:, :2]
        dynamic_cache = transformers.cache_utils.DynamicCache(config=config)
        dynamic_cache = model(
            init_input_ids, past_key_values=dynamic_cache
        ).past_key_values

        with self.subTest(case="case5"):
            if not has_transformers("4.57"):
                raise unittest.SkipTest("transformers 4.57+.")
            # Passing a non-empty cache without cache_position must fail.
            # NOTE(review): AttributeError vs TypeError depends on the
            # transformers version — both are accepted.
            with self.assertRaises((AttributeError, TypeError)):
                model_inputs = model.prepare_inputs_for_generation(
                    input_ids, past_key_values=dynamic_cache
                )

        with self.subTest(case="case6"):
            # With a cache holding 2 tokens, only the remaining positions are
            # kept in cache_position, input_ids and position_ids.
            cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long).to(
                torch_device
            )
            cache_position = cache_position[dynamic_cache.get_seq_length() :]
            model_inputs = model.prepare_inputs_for_generation(
                input_ids,
                past_key_values=dynamic_cache,
                cache_position=cache_position,
                attention_mask=attention_mask,
            )
            self.assertTrue("past_key_values" in model_inputs)
            self.assertTrue(torch.all(model_inputs["cache_position"] == cache_position))
            self.assertTrue(
                model_inputs["input_ids"].shape[-1] == 1
            ) # 1 = 3 fed tokens - 2 tokens in the cache
            self.assertTrue(model_inputs["position_ids"].shape[-1] == 1)
            self.assertTrue(
                model_inputs["attention_mask"].shape[-1] == 3
            ) # we still need the full attention mask!

        with self.subTest(case="case6.2"):
            # StaticCache: the attention mask is expanded to the 4D shape
            # [batch, 1, query_length, max_cache_len].
            max_cache_len = 10
            batch_size = 2
            query_length = input_ids.shape[-1] - init_input_ids.shape[-1]
            static_cache = transformers.cache_utils.StaticCache(
                config=config, max_cache_len=max_cache_len
            )
            static_cache = model(
                init_input_ids, past_key_values=static_cache
            ).past_key_values
            model_inputs = model.prepare_inputs_for_generation(
                input_ids,
                past_key_values=static_cache,
                cache_position=cache_position,
            attention_mask=attention_mask,
            )
            self.assertTrue("past_key_values" in model_inputs)
            self.assertTrue(
                list(model_inputs["attention_mask"].shape)
                == [batch_size, 1, query_length, max_cache_len]
            )

        with self.subTest(case="case7"):
            if not has_transformers("4.57"):
                raise unittest.SkipTest("transformers 4.57+.")
            # With a non-empty cache, inputs_embeds are dropped (set to None)
            # and generation continues from input_ids.
            init_inputs_embeds = model.get_input_embeddings()(init_input_ids)
            model_inputs = model.prepare_inputs_for_generation(
                input_ids,
                past_key_values=dynamic_cache,
                inputs_embeds=init_inputs_embeds,
                cache_position=cache_position,
            )
            self.assertTrue(model_inputs["input_ids"] is not None)
            self.assertTrue(model_inputs["inputs_embeds"] is None)


# Allow running this test file directly (outside a test runner).
if __name__ == "__main__":
    unittest.main(verbosity=2)
Loading
Loading