Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions _unittests/ut_ci_models/test_ci_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def test_main_qwen25_tiny_llm(self):
pretrained=False,
part="",
output_folder=self.get_dump_folder("test_main_qwen25_tiny_llm"),
opset=24,
)
self.clean_dump()

Expand Down
4 changes: 3 additions & 1 deletion _unittests/ut_tasks/test_tasks_image_text_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def test_image_text_to_text_tiny_gemma3(self):
def test_image_text_to_text_gemma3_4b_it(self):
make_hybrid_cache = get_make_hybrid_cache()
if make_hybrid_cache is None:
raise unittest.SkipTest("not implemented yet for transformers>=5")
raise unittest.SkipTest(
"not implemented yet for transformers>=5 (make_hybrid_cache is None)"
)
mid = "google/gemma-3-4b-it"
data = get_untrained_model_with_inputs(
mid,
Expand Down
16 changes: 11 additions & 5 deletions _unittests/ut_torch_export_patches/test_patch_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ def forward(
for exporter in ("custom", "onnx-dynamo"):
# onnx-dynamo needs OpOverload(op='aten.sym_storage_offset' (transformers>=5.0?)
if exporter == "onnx-dynamo" and not has_onnxscript("0.5.7"):
raise unittest.SkipTest("needs onnxscript>=0.5.7")
self.skipTest("needs onnxscript>=0.5.7")
filename = self.get_dump_file(
f"test_patched_qwen2_5_vl_vision_attention_forward.{exporter}.onnx"
)
Expand Down Expand Up @@ -640,7 +640,7 @@ def test_qwen2_5_vl_vision_attention_iteration(self):
)
for exporter in ("custom", "onnx-dynamo"):
if exporter == "onnx-dynamo" and aten_sym_storage_offset is None:
raise unittest.SkipTest("update onnxscript to make this test run")
self.skipTest("update onnxscript to make this test run")
# onnx-dynamo needs OpOverload(op='aten.sym_storage_offset' (transformers>=5.0?)
filename = self.get_dump_file(
f"test_qwen2_5_vl_vision_attention_iteration.{exporter}.onnx"
Expand Down Expand Up @@ -909,7 +909,7 @@ def test_cache_dependant_input_preparation_exporting(self):
torch.testing.assert_close(eager2, export2)

with self.subTest(case="case2"):
raise unittest.SkipTest("torch 2.10+ has probably a bug here.")
self.skipTest("torch 2.10+ has probably a bug here.")
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)
inputs_embeds = torch.rand((2, 8), dtype=torch.float32)
cache_position = torch.arange(0, 8, dtype=torch.int64)
Expand Down Expand Up @@ -995,13 +995,17 @@ def test_prepare_inputs_for_generation_decoder_llm(self):

with self.subTest(case="case5"):
if not has_transformers("4.57"):
raise unittest.SkipTest("transformers 4.57+.")
self.skipTest("This test only works with transformers>=4.57, <5.3.")
if has_transformers("5.2.99"):
self.skipTest("This test is no longer valid with transformers>=5.3.")
with self.assertRaises((AttributeError, TypeError)):
model_inputs = model.prepare_inputs_for_generation(
input_ids, past_key_values=dynamic_cache
)

with self.subTest(case="case6"):
if has_transformers("5.2.99"):
self.skipTest("This test is no longer valid with transformers>=5.3.")
cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long).to(
torch_device
)
Expand All @@ -1023,6 +1027,8 @@ def test_prepare_inputs_for_generation_decoder_llm(self):
) # we still need the full attention mask!

with self.subTest(case="case6.2"):
if has_transformers("5.2.99"):
self.skipTest("This test is no longer valid with transformers>=5.3.")
max_cache_len = 10
batch_size = 2
query_length = input_ids.shape[-1] - init_input_ids.shape[-1]
Expand All @@ -1046,7 +1052,7 @@ def test_prepare_inputs_for_generation_decoder_llm(self):

with self.subTest(case="case7"):
if not has_transformers("4.57"):
raise unittest.SkipTest("transformers 4.57+.")
self.skipTest("This test only works with transformers>=4.57.")
init_inputs_embeds = model.get_input_embeddings()(init_input_ids)
model_inputs = model.prepare_inputs_for_generation(
input_ids,
Expand Down
6 changes: 6 additions & 0 deletions onnx_diagnostic/ci_models/ci_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ def get_parser(name: str, epilog: str = "") -> ArgumentParser:
help="Profiles the exporter and outputs an html document from pyinstrument",
action=BooleanOptionalAction,
)
parser.add_argument(
"--opset",
type=int,
default=0,
help="default opsets, 0 to let the exporter choose",
)
return parser


Expand Down
5 changes: 4 additions & 1 deletion onnx_diagnostic/ci_models/export_phi4_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,7 @@ def main(
atol: float = 2,
mismatch01: float = 0.01,
profile_exporter: bool = False,
opset: Optional[int] = None,
):
"""
Exports model Qwen/Qwen2.5-VL-7B-Instruct or pieces of it.
Expand All @@ -733,6 +734,7 @@ def main(
:param atol: raises an exception if tolerance is above that threshold
:param mismatch01: raises an exception if the ratio of mismatches
is above that threshold
:param opset: opset to choose
:param profile_exporter: profiles the exporter
"""
prefix = simplify_model_id_for_a_filename(model_id)
Expand Down Expand Up @@ -947,7 +949,7 @@ def forward(

begin = time.perf_counter()

target_opset = 22
target_opset = opset or 22

details = PatchDetails()
with torch_export_patches(
Expand Down Expand Up @@ -1062,4 +1064,5 @@ def forward(
atol=args.atol,
mismatch01=args.mismatch01,
profile_exporter=args.profile_exporter,
opset=args.opset if args.opset > 0 else None,
)
11 changes: 8 additions & 3 deletions onnx_diagnostic/ci_models/export_qwen25_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
import sys
import time
import warnings
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple
from .ci_helpers import (
check_for_discrepancies_and_log_everything_into_a_json_file,
compute_expected_outputs,
Expand Down Expand Up @@ -199,6 +199,7 @@ def main(
atol: float = 0.01,
mismatch01: float = 0.1,
profile_exporter: bool = False,
opset: Optional[int] = None,
):
"""
Exports model Qwen/Qwen2.5-VL-7B-Instruct or pieces of it.
Expand All @@ -221,6 +222,8 @@ def main(
:param atol: raises an exception if tolerance is above that threshold
:param mismatch01: raises an exception if the ratio of mismatches
is above that threshold
:param opset: opset, if not specified, a value is chosen based on the
proposed rewriting
:param profile_exporter: profiles the exporter
"""
prefix = simplify_model_id_for_a_filename(model_id)
Expand All @@ -243,6 +246,7 @@ def main(
print(f"-- make_zip={make_zip}")
print(f"-- output_folder={output_folder}")
print(f"-- atol={atol}")
print(f"-- opset={opset}")
print(f"-- mismatch01={mismatch01}")
print(f"-- profile_exporter={profile_exporter}")
print("------------------------------------------------------------------")
Expand Down Expand Up @@ -473,15 +477,15 @@ def process_image(inputs_embeds, image_features):

begin = time.perf_counter()

target_opset = 22
target_opset = opset or 22
if (
exporter == "onnx-dynamo"
and device == "cuda"
and "QWEN25ATTENTION" not in os.environ
):
os.environ["QWEN25ATTENTION"] = "PACKED"
elif "QWEN25ATTENTION" in os.environ and os.environ["QWEN25ATTENTION"] == "LOOPA23":
target_opset = 23
target_opset = opset or 23

with torch_export_patches(
patch_torch=False,
Expand Down Expand Up @@ -565,4 +569,5 @@ def process_image(inputs_embeds, image_features):
atol=args.atol,
mismatch01=args.mismatch01,
profile_exporter=args.profile_exporter,
opset=args.opset if args.opset > 0 else None,
)
33 changes: 25 additions & 8 deletions onnx_diagnostic/helpers/cache_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,8 +539,10 @@ def make_encoder_decoder_cache(

def make_mamba_cache(
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
cls_layers: Optional[Union[str, List[type]]] = None,
cls_kwargs: Optional[Union[Dict[str, int], List[Dict[str, int]]]] = None,
) -> "MambaCache": # noqa: F821
"Creates a ``MambaCache``."
"""Creates a ``MambaCache``. `cls_layers`, `cls_kwargs` are unused."""
# import is moved here because this part is slow.
try:
from transformers.models.mamba.modeling_mamba import MambaCache
Expand Down Expand Up @@ -591,8 +593,13 @@ def get_text_config(self, *args, **kwargs):

def make_sliding_window_cache(
key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
cls_layers: Optional[Union[str, List[type]]] = None,
cls_kwargs: Optional[Union[Dict[str, int], List[Dict[str, int]]]] = None,
) -> transformers.cache_utils.SlidingWindowCache:
"Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
"""
Creates a :class:`transformers.cache_utils.SlidingWindowCache`.
`cls_layers`, `cls_kwargs` are unused.
"""
key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)

class _config:
Expand Down Expand Up @@ -654,6 +661,8 @@ def make_hybrid_cache(
max_cache_len: Optional[int] = None,
max_batch_size: Optional[int] = None,
sliding_window: Optional[int] = None,
cls_layers: Optional[Union[str, List[type]]] = None,
cls_kwargs: Optional[Union[Dict[str, int], List[Dict[str, int]]]] = None,
) -> transformers.cache_utils.HybridCache:
"""
Creates an instance of :class:`transformers.cache_utils.HybridCache`.
Expand All @@ -662,6 +671,8 @@ def make_hybrid_cache(
:param key_value_pairs: list of pairs of (key, values)
:return: :class:`transformers.cache_utils.HybridCache`

`cls_layers`, `cls_kwargs` are unused.

Example:

.. runpython::
Expand Down Expand Up @@ -742,16 +753,22 @@ def make_hybrid_cache(
not max_batch_size and not max_cache_len
), "key_value_pairs is not empty, do not specify max_cache_len and max_batch_size"
max_batch_size = key_value_pairs[0][0].shape[0]
assert max_cache_len is not None or all(
isinstance(kv[0].shape[2], int) for kv in key_value_pairs
), (
f"Cannot determine max_cache_len with "
f"shapes={[kv[0].shape for kv in key_value_pairs]}"
)
sets_of_dim = set(kv[0].shape[2] for kv in key_value_pairs)
if len(sets_of_dim) == 1:
max_cache_len = sets_of_dim.pop()
sliding_window = max_cache_len
if max_cache_len is None:
max_cache_len = sets_of_dim.pop()
else:
assert (
len(sets_of_dim) == 2
), f"Not implemented for more than 2 dimensions {sets_of_dim}"
max_cache_len = max(sets_of_dim)
sliding_window = min(sets_of_dim)
if max_cache_len is None:
max_cache_len = max(sets_of_dim)
layer_types = [
"full_attention" if i == max_cache_len else "sliding_attention"
for i in [kv[0].shape[2] for kv in key_value_pairs]
Expand All @@ -760,8 +777,8 @@ def make_hybrid_cache(
assert (
max_batch_size and max_cache_len
), "key_value_pairs is empty, max_batch_size and max_cache_len are required"
if sliding_window is None:
sliding_window = max_cache_len
if sliding_window is None:
sliding_window = max_cache_len
_max_cache_len = max_cache_len
_sliding_window = sliding_window

Expand Down
2 changes: 1 addition & 1 deletion onnx_diagnostic/tasks/image_text_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def get_inputs_default(
"past_key_values": list(
itertools.chain.from_iterable(
zip(
[{0: batch} for _ in range(num_hidden_layers)],
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,18 @@ def patched_sdpa_attention_forward(
if is_causal is None and attention_mask is not None:
is_causal = False
if is_causal is not None:
torch._check(query.shape[0] > 0)
torch._check(query.shape[1] > 0)
torch._check(query.shape[2] > 0)
torch._check(query.shape[3] > 0)
torch._check(key.shape[0] > 0)
torch._check(key.shape[1] > 0)
torch._check(key.shape[2] > 0)
torch._check(key.shape[3] > 0)
torch._check(value.shape[0] > 0)
torch._check(value.shape[1] > 0)
torch._check(value.shape[2] > 0)
torch._check(value.shape[3] > 0)
return (
torch.nn.functional.scaled_dot_product_attention(
query,
Expand Down
Loading
Loading