From b41c260f331f1c78a4c7728829fed6479dc6819d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 17 Feb 2026 15:22:27 +0100 Subject: [PATCH 1/4] improves unit tests --- README.rst | 7 +- _doc/index.rst | 4 +- _unittests/ut_reference/test_reference_ops.py | 38 ++++++++++- .../ut_xrun_doc/test_command_lines_exe.py | 23 +++++++ onnx_diagnostic/reference/evaluator.py | 34 ---------- .../reference/ops/op_add_add_mul_mul.py | 68 ------------------- .../reference/ops/op_average_pool_grad.py | 63 ----------------- .../reference/ops/op_gather_grad.py | 12 ---- .../reference/ops/op_mul_sigmoid.py | 23 ------- onnx_diagnostic/reference/ops/op_negxplus1.py | 8 --- .../reference/ops/op_scatternd_of_shape.py | 22 ------ .../reference/ops/op_transpose_cast.py | 16 ----- .../reference/ops/op_tri_matrix.py | 17 ----- onnx_diagnostic/torch_models/code_sample.py | 2 +- 14 files changed, 68 insertions(+), 269 deletions(-) delete mode 100644 onnx_diagnostic/reference/ops/op_add_add_mul_mul.py delete mode 100644 onnx_diagnostic/reference/ops/op_average_pool_grad.py delete mode 100644 onnx_diagnostic/reference/ops/op_gather_grad.py delete mode 100644 onnx_diagnostic/reference/ops/op_mul_sigmoid.py delete mode 100644 onnx_diagnostic/reference/ops/op_negxplus1.py delete mode 100644 onnx_diagnostic/reference/ops/op_scatternd_of_shape.py delete mode 100644 onnx_diagnostic/reference/ops/op_transpose_cast.py delete mode 100644 onnx_diagnostic/reference/ops/op_tri_matrix.py diff --git a/README.rst b/README.rst index 01a0cd2c..24510432 100644 --- a/README.rst +++ b/README.rst @@ -8,8 +8,8 @@ onnx-diagnostic: investigate onnx models .. image:: https://github.com/sdpython/onnx-diagnostic/actions/workflows/documentation.yml/badge.svg :target: https://github.com/sdpython/onnx-diagnostic/actions/workflows/documentation.yml -.. image:: https://badge.fury.io/py/onnx-diagnostic.svg - :target: http://badge.fury.io/py/onnx-diagnostic +.. image:: https://img.shields.io/pypi/v/onnx-diagnostic.svg + :target: https://pypi.org/project/onnx-diagnostic .. image:: https://img.shields.io/badge/license-MIT-blue.svg :alt: MIT License @@ -19,6 +19,9 @@ onnx-diagnostic: investigate onnx models :target: https://github.com/sdpython/onnx-diagnostic/ :alt: size +.. image:: https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json + :target: https://github.com/astral-sh/ruff + .. image:: https://img.shields.io/badge/code%20style-black-000000.svg :target: https://github.com/psf/black diff --git a/_doc/index.rst b/_doc/index.rst index 8caca441..5e46d514 100644 --- a/_doc/index.rst +++ b/_doc/index.rst @@ -5,8 +5,8 @@ onnx-diagnostic: investigate onnx models .. image:: https://github.com/sdpython/onnx-diagnostic/actions/workflows/documentation.yml/badge.svg :target: https://github.com/sdpython/onnx-diagnostic/actions/workflows/documentation.yml -.. image:: https://badge.fury.io/py/onnx-diagnostic.svg - :target: http://badge.fury.io/py/onnx-diagnostic +.. image:: https://img.shields.io/pypi/v/onnx-diagnostic.svg + :target: https://pypi.org/project/onnx-diagnostic .. image:: https://img.shields.io/badge/license-MIT-blue.svg :alt: MIT License diff --git a/_unittests/ut_reference/test_reference_ops.py b/_unittests/ut_reference/test_reference_ops.py index b290b2eb..6f2d7395 100644 --- a/_unittests/ut_reference/test_reference_ops.py +++ b/_unittests/ut_reference/test_reference_ops.py @@ -117,7 +117,7 @@ def test_quick_gelu(self): got = ref.run(None, {"X": a}) self.assertEqualArray(expected[0], got[0]) - def test_scatter_elements(self): + def test_scatter_elements_4d(self): model = oh.make_model( oh.make_graph( [ @@ -149,6 +149,42 @@ def test_scatter_elements(self): got = ref.run(None, {"data": data, "indices": indices, "updates": updates}) self.assertEqualArray(y, got[0]) + def test_scatter_elements_3d(self): + ys = [ + np.array([1, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32).reshape((2, 2, 2)), + np.array([1, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32).reshape((2, 2, 2)), + np.array([1, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32).reshape((2, 2, 2)), + ] + + for axis, y in zip([0, 1, 2], ys): + model = oh.make_model( + oh.make_graph( + [ + oh.make_node( + "ScatterElements", + ["data", "indices", "updates"], + ["Z"], + axis=axis, + reduction="add", + ) + ], + "name", + [ + oh.make_tensor_value_info("data", TensorProto.FLOAT, None), + oh.make_tensor_value_info("indices", TensorProto.INT64, None), + oh.make_tensor_value_info("updates", TensorProto.FLOAT, None), + ], + [make_tensor_value_info("Z", TensorProto.FLOAT, None)], + ), + opset_imports=[make_opsetid("", 18)], + ) + data = np.zeros(2**3, dtype=np.float32).reshape((2, 2, 2)) + indices = np.array([[[0]]], dtype=np.int64) + updates = np.array([[[1]]], dtype=np.float32) + ref = ExtendedReferenceEvaluator(model) + got = ref.run(None, {"data": data, "indices": indices, "updates": updates}) + self.assertEqualArray(y, got[0]) + def test_skip_layer_normalization_nobias(self): import onnxruntime diff --git a/_unittests/ut_xrun_doc/test_command_lines_exe.py b/_unittests/ut_xrun_doc/test_command_lines_exe.py index bd78aae7..8c1c4592 100644 --- a/_unittests/ut_xrun_doc/test_command_lines_exe.py +++ b/_unittests/ut_xrun_doc/test_command_lines_exe.py @@ -227,6 +227,29 @@ def test_m_parser_partition(self): text = st.getvalue() self.assertIn("-- done", text) + def test_n_parser_export_sample(self): + st = StringIO() + with redirect_stdout(st): + main(["exportsample", "-m", "arnir0/Tiny-LLM", "--run", "-v", "1"]) + text = st.getvalue() + self.assertIn("def get_model_with_inputs(", text) + st = StringIO() + with redirect_stdout(st): + main( + [ + "exportsample", + "-m", + "arnir0/Tiny-LLM", + "--run", + "-v", + "1", + "--export", + "custom", + ] + ) + text = st.getvalue() + self.assertIn("def get_model_with_inputs(", text) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/reference/evaluator.py b/onnx_diagnostic/reference/evaluator.py index bd13cbfb..b57caf54 100644 --- a/onnx_diagnostic/reference/evaluator.py +++ b/onnx_diagnostic/reference/evaluator.py @@ -4,18 +4,7 @@ from onnx.defs import get_schema from onnx.reference import ReferenceEvaluator from onnx.reference.op_run import OpRun -from .ops.op_add_add_mul_mul import ( - AddAdd, - AddMul, - AddSharedInput, - MulAdd, - MulMul, - MulSharedInput, - MulSub, - SubMul, -) from .ops.op_attention import Attention -from .ops.op_average_pool_grad import AveragePoolGrad from .ops.op_bias_softmax import BiasSoftmax from .ops.op_cast_like import CastLike_15, CastLike_19 from .ops.op_complex import ComplexModule, ToComplex @@ -24,10 +13,7 @@ from .ops.op_fused_matmul import FusedMatMul from .ops.op_gather import Gather from .ops.op_gather_elements import GatherElements -from .ops.op_gather_grad import GatherGrad from .ops.op_memcpy_host import MemcpyFromHost, MemcpyToHost -from .ops.op_mul_sigmoid import MulSigmoid -from .ops.op_negxplus1 import NegXplus1 from .ops.op_qlinear_average_pool import QLinearAveragePool from .ops.op_qlinear_conv import QLinearConv from .ops.op_quick_gelu import QuickGelu @@ -35,12 +21,9 @@ from .ops.op_rotary import Rotary from .ops.op_scan import Scan from .ops.op_scatter_elements import ScatterElements -from .ops.op_scatternd_of_shape import MaskedScatterNDOfShape, ScatterNDOfShape from .ops.op_simplified_layer_normalization import SimplifiedLayerNormalization from .ops.op_skip_layer_normalization import SkipLayerNormalization from .ops.op_slice import Slice_1, Slice_10 -from .ops.op_transpose_cast import Transpose2DCastFP16, Transpose2DCastFP32 -from .ops.op_tri_matrix import TriMatrix logger = getLogger("onnx-diagnostic-eval") @@ -70,11 +53,7 @@ class ExtendedReferenceEvaluator(ReferenceEvaluator): """ default_ops: List[type[OpRun]] = [ - AddAdd, - AddMul, - AddSharedInput, Attention, - AveragePoolGrad, BiasSoftmax, Concat, CastLike_15, @@ -84,16 +63,8 @@ class ExtendedReferenceEvaluator(ReferenceEvaluator): FusedMatMul, Gather, GatherElements, - GatherGrad, - MaskedScatterNDOfShape, MemcpyFromHost, MemcpyToHost, - MulAdd, - MulMul, - MulSharedInput, - MulSigmoid, - MulSub, - NegXplus1, QLinearConv, QLinearAveragePool, QuickGelu, @@ -101,16 +72,11 @@ class ExtendedReferenceEvaluator(ReferenceEvaluator): Rotary, Scan, ScatterElements, - ScatterNDOfShape, SimplifiedLayerNormalization, SkipLayerNormalization, Slice_1, Slice_10, - SubMul, ToComplex, - Transpose2DCastFP16, - Transpose2DCastFP32, - TriMatrix, ] @staticmethod diff --git a/onnx_diagnostic/reference/ops/op_add_add_mul_mul.py b/onnx_diagnostic/reference/ops/op_add_add_mul_mul.py deleted file mode 100644 index 963d4305..00000000 --- a/onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +++ /dev/null @@ -1,68 +0,0 @@ -import numpy as np -from onnx.reference.op_run import OpRun - - -class AddAdd(OpRun): - op_domain = "onnx_extended.ortops.optim.cuda" - - def _run(self, x, y, z): - return (x + y + z,) - - -class MulMul(OpRun): - op_domain = "onnx_extended.ortops.optim.cuda" - - def _run(self, x, y, z): - return (x * y * z,) - - -class AddMul(OpRun): - op_domain = "onnx_extended.ortops.optim.cuda" - - def _run(self, x, y, z, transposeMiddle=None): - res = (x + y) * z - if transposeMiddle: - res = np.transpose(res, axes=[0, 2, 1, 3]) - return (res,) - - -class MulAdd(OpRun): - op_domain = "onnx_extended.ortops.optim.cuda" - - def _run(self, x, y, z, transposeMiddle=None): - res = (x * y) + z - if transposeMiddle: - res = np.transpose(res, axes=[0, 2, 1, 3]) - return (res,) - - -class SubMul(OpRun): - op_domain = "onnx_extended.ortops.optim.cuda" - - def _run(self, x, y, z, negative=None): - if negative: - return ((y - x) * z,) - return ((x - y) * z,) - - -class MulSub(OpRun): - op_domain = "onnx_extended.ortops.optim.cuda" - - def _run(self, x, y, z, negative=None): - if negative: - return (z - (x * y),) - return ((x * y) - z,) - - -class AddSharedInput(OpRun): - op_domain = "onnx_extended.ortops.optim.cuda" - - def _run(self, x, y, z): - return (x + y, x + z) - - -class MulSharedInput(OpRun): - op_domain = "onnx_extended.ortops.optim.cuda" - - def _run(self, x, y, z): - return (x * y, x * z) diff --git a/onnx_diagnostic/reference/ops/op_average_pool_grad.py b/onnx_diagnostic/reference/ops/op_average_pool_grad.py deleted file mode 100644 index 95cc7854..00000000 --- a/onnx_diagnostic/reference/ops/op_average_pool_grad.py +++ /dev/null @@ -1,63 +0,0 @@ -import numpy as np -from onnx.reference.op_run import OpRun - - -class AveragePoolGrad(OpRun): - def _run( - self, - out, - auto_pad=None, - ceil_mode=None, - count_include_pad=None, - kernel_shape=None, - pads=None, - strides=None, - ): - assert auto_pad is not None, "auto_pad is None" - assert ceil_mode is not None, "ceil_mode is None" - assert count_include_pad is not None, "count_include_pad is None" - assert kernel_shape is not None, "kernel_shape is None" - assert pads is not None, "pads is None" - assert strides is not None, "strides is None" - - assert auto_pad == "NOTSET", f"Not implemented for autopad={auto_pad!r}" - assert ceil_mode == 0, f"Not implemented for ceil_mode={ceil_mode!r}" - assert ( - count_include_pad == 1 - ), f"Not implemented for count_include_pad={count_include_pad!r}" - - grad_shape = list(out.shape[:2]) - for i in range(len(kernel_shape)): - d = ( - out.shape[i + 2] * strides[i] - + kernel_shape[i] - - 1 - + sum(pads[i * 2 : i * 2 + 2]) - ) - grad_shape.append(d) - - grad = np.zeros(tuple(grad_shape), dtype=out.dtype) - scale = (1.0 / np.prod(kernel_shape)).astype(out.dtype) - if len(grad_shape) == 4: - # 2D - for batch in range(grad.shape[0]): - for channel in range(grad.shape[1]): - for i in range(out.shape[2]): - t = max(i * strides[0] - pads[0], 0) - b = min(i * strides[0] - pads[0] + kernel_shape[0], grad.shape[2]) - for j in range(out.shape[3]): - le = max(j * strides[1] - pads[2], 0) - ri = min( - j * strides[1] - pads[2] + kernel_shape[1], - grad.shape[3], - ) - - grad[batch, channel, t:b, le:ri] += ( - out[batch, channel, i, j] * scale - ) - else: - raise NotImplementedError( - f"AveragePoolGrad is not implemented for shape={out.shape}." - ) - - return (grad.astype(out.dtype),) diff --git a/onnx_diagnostic/reference/ops/op_gather_grad.py b/onnx_diagnostic/reference/ops/op_gather_grad.py deleted file mode 100644 index 76d95a3f..00000000 --- a/onnx_diagnostic/reference/ops/op_gather_grad.py +++ /dev/null @@ -1,12 +0,0 @@ -import numpy as np -from onnx.reference.op_run import OpRun -from onnx.reference.ops.op_scatternd import _scatter_nd_impl - - -class GatherGrad(OpRun): - op_domain = "com.microsoft" - - def _run(self, shape, indices, updates, reduction=None): - data = np.zeros(shape, dtype=updates.dtype) - y = _scatter_nd_impl(data, indices, updates, reduction=reduction) - return (y,) diff --git a/onnx_diagnostic/reference/ops/op_mul_sigmoid.py b/onnx_diagnostic/reference/ops/op_mul_sigmoid.py deleted file mode 100644 index b49a7fef..00000000 --- a/onnx_diagnostic/reference/ops/op_mul_sigmoid.py +++ /dev/null @@ -1,23 +0,0 @@ -import numpy as np -from onnx.reference.op_run import OpRun - - -def sigmoid(x): # type: ignore - if x > 0: - return 1 / (1 + np.exp(-x)) - return np.exp(x) / (1 + np.exp(x)) - - -class MulSigmoid(OpRun): - op_domain = "onnx_extended.ortops.optim.cuda" - - def __init__(self, onnx_node, run_params): # type: ignore - OpRun.__init__(self, onnx_node, run_params) - self.vf = np.vectorize(sigmoid) - - def _run(self, X): - if len(X.shape) == 0: - return ((X * sigmoid(X)).astype(X.dtype),) - if X.size == 0: - return (X,) - return ((X * self.vf(X)).astype(X.dtype),) diff --git a/onnx_diagnostic/reference/ops/op_negxplus1.py b/onnx_diagnostic/reference/ops/op_negxplus1.py deleted file mode 100644 index 60fd5458..00000000 --- a/onnx_diagnostic/reference/ops/op_negxplus1.py +++ /dev/null @@ -1,8 +0,0 @@ -from onnx.reference.op_run import OpRun - - -class NegXplus1(OpRun): - op_domain = "onnx_extended.ortops.optim.cuda" - - def _run(self, X): - return ((1 - X).astype(X.dtype),) diff --git a/onnx_diagnostic/reference/ops/op_scatternd_of_shape.py b/onnx_diagnostic/reference/ops/op_scatternd_of_shape.py deleted file mode 100644 index 1e378a7d..00000000 --- a/onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +++ /dev/null @@ -1,22 +0,0 @@ -import numpy as np -from onnx.reference.op_run import OpRun -from onnx.reference.ops.op_scatternd import _scatter_nd_impl - - -class ScatterNDOfShape(OpRun): - op_domain = "onnx_extended.ortops.optim.cuda" - - def _run(self, shape, indices, updates, reduction=None, strategy=None): - data = np.zeros(shape, dtype=updates.dtype) - y = _scatter_nd_impl(data, indices, updates, reduction=reduction) - return (y,) - - -class MaskedScatterNDOfShape(OpRun): - op_domain = "onnx_extended.ortops.optim.cuda" - - def _run(self, shape, indices, updates, reduction=None, maskedValue=None): - data = np.zeros(shape, dtype=updates.dtype) - new_updates = np.where(indices == maskedValue, 0, updates) - y = _scatter_nd_impl(data, indices, new_updates, reduction=reduction) - return (y,) diff --git a/onnx_diagnostic/reference/ops/op_transpose_cast.py b/onnx_diagnostic/reference/ops/op_transpose_cast.py deleted file mode 100644 index 8d738209..00000000 --- a/onnx_diagnostic/reference/ops/op_transpose_cast.py +++ /dev/null @@ -1,16 +0,0 @@ -import numpy as np -from onnx.reference.op_run import OpRun - - -class Transpose2DCastFP16(OpRun): - op_domain = "onnx_extended.ortops.optim.cuda" - - def _run(self, X): - return (X.T.astype(np.float16),) - - -class Transpose2DCastFP32(OpRun): - op_domain = "onnx_extended.ortops.optim.cuda" - - def _run(self, X): - return (X.T.astype(np.float32),) diff --git a/onnx_diagnostic/reference/ops/op_tri_matrix.py b/onnx_diagnostic/reference/ops/op_tri_matrix.py deleted file mode 100644 index 0a18bb98..00000000 --- a/onnx_diagnostic/reference/ops/op_tri_matrix.py +++ /dev/null @@ -1,17 +0,0 @@ -import numpy as np -from onnx.reference.op_run import OpRun - - -class TriMatrix(OpRun): - op_domain = "onnx_extended.ortops.optim.cuda" - - def _run(self, shape, csts): - lower, diag, upper = list(csts) - dtype = csts.dtype - mat = np.empty(tuple(shape), dtype=dtype) - i = np.arange(shape[0], dtype=np.int32).reshape((-1, 1)) - j = np.arange(shape[1], dtype=np.int32).reshape((1, -1)) - mat[i > j] = lower - mat[i < j] = upper - mat[i == j] = diag - return (mat,) diff --git a/onnx_diagnostic/torch_models/code_sample.py b/onnx_diagnostic/torch_models/code_sample.py index 31930481..e83830b1 100644 --- a/onnx_diagnostic/torch_models/code_sample.py +++ b/onnx_diagnostic/torch_models/code_sample.py @@ -309,7 +309,7 @@ def code_sample( dynamic_shapes=data["dynamic_shapes"], ) if exporter is not None - else ([], []) + else ("", "") ) input_code = make_code_for_inputs(data["inputs"]) cache_import = ( From c49b4d592d23cf84fc281b01baa8506df142f58f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 17 Feb 2026 16:22:23 +0100 Subject: [PATCH 2/4] fix documentation --- _doc/api/reference/ops/index.rst | 13 ------------- _doc/api/reference/ops/op_add_add_mul_mul.rst | 6 ------ .../api/reference/ops/op_average_pool_grad.rst | 6 ------ _doc/api/reference/ops/op_gather_grad.rst | 6 ------ _doc/api/reference/ops/op_mul_sigmoid.rst | 6 ------ _doc/api/reference/ops/op_negxplus1.rst | 6 ------ _doc/api/reference/ops/op_replace_zero.rst | 6 ------ .../reference/ops/op_scatternd_of_shape.rst | 6 ------ _doc/api/reference/ops/op_transpose_cast.rst | 5 ----- _doc/api/reference/ops/op_tri_matrix.rst | 6 ------ .../test_patch_module.py | 18 +++++++++++++----- onnx_diagnostic/reference/evaluator.py | 2 -- .../reference/ops/op_replace_zero.py | 13 ------------- 13 files changed, 13 insertions(+), 86 deletions(-) delete mode 100644 _doc/api/reference/ops/op_add_add_mul_mul.rst delete mode 100644 _doc/api/reference/ops/op_average_pool_grad.rst delete mode 100644 _doc/api/reference/ops/op_gather_grad.rst delete mode 100644 _doc/api/reference/ops/op_mul_sigmoid.rst delete mode 100644 _doc/api/reference/ops/op_negxplus1.rst delete mode 100644 _doc/api/reference/ops/op_replace_zero.rst delete mode 100644 _doc/api/reference/ops/op_scatternd_of_shape.rst delete mode 100644 _doc/api/reference/ops/op_transpose_cast.rst delete mode 100644 _doc/api/reference/ops/op_tri_matrix.rst delete mode 100644 onnx_diagnostic/reference/ops/op_replace_zero.py diff --git a/_doc/api/reference/ops/index.rst b/_doc/api/reference/ops/index.rst index 65ab437c..ce83426d 100644 --- a/_doc/api/reference/ops/index.rst +++ b/_doc/api/reference/ops/index.rst @@ -2,37 +2,24 @@ onnx_diagnostic.reference.ops ============================= - - .. toctree:: :maxdepth: 1 :caption: modules - - op_add_add_mul_mul - op_average_pool_grad op_cast_like op_complex op_concat op_constant_of_shape op_fused_matmul - op_gather_grad op_memcpy_host - op_mul_sigmoid - op_negxplus1 op_quick_gelu - op_replace_zero op_rotary op_qlinear_average_pool op_qlinear_conv op_scatter_elements - op_scatternd_of_shape op_simplified_layer_normalization op_skip_layer_normalization op_slice - op_transpose_cast - op_tri_matrix - .. automodule:: onnx_diagnostic.reference.ops :members: diff --git a/_doc/api/reference/ops/op_add_add_mul_mul.rst b/_doc/api/reference/ops/op_add_add_mul_mul.rst deleted file mode 100644 index f95f78b4..00000000 --- a/_doc/api/reference/ops/op_add_add_mul_mul.rst +++ /dev/null @@ -1,6 +0,0 @@ - -onnx_diagnostic.reference.ops.op_add_add_mul_mul -================================================ - -.. automodule:: onnx_diagnostic.reference.ops.op_add_add_mul_mul - :members: diff --git a/_doc/api/reference/ops/op_average_pool_grad.rst b/_doc/api/reference/ops/op_average_pool_grad.rst deleted file mode 100644 index ce02ba3e..00000000 --- a/_doc/api/reference/ops/op_average_pool_grad.rst +++ /dev/null @@ -1,6 +0,0 @@ - -onnx_diagnostic.reference.ops.op_average_pool_grad -================================================== - -.. automodule:: onnx_diagnostic.reference.ops.op_average_pool_grad - :members: diff --git a/_doc/api/reference/ops/op_gather_grad.rst b/_doc/api/reference/ops/op_gather_grad.rst deleted file mode 100644 index abf6aeb8..00000000 --- a/_doc/api/reference/ops/op_gather_grad.rst +++ /dev/null @@ -1,6 +0,0 @@ - -onnx_diagnostic.reference.ops.op_gather_grad -============================================ - -.. automodule:: onnx_diagnostic.reference.ops.op_gather_grad - :members: diff --git a/_doc/api/reference/ops/op_mul_sigmoid.rst b/_doc/api/reference/ops/op_mul_sigmoid.rst deleted file mode 100644 index ea0f800e..00000000 --- a/_doc/api/reference/ops/op_mul_sigmoid.rst +++ /dev/null @@ -1,6 +0,0 @@ - -onnx_diagnostic.reference.ops.op_mul_sigmoid -============================================ - -.. automodule:: onnx_diagnostic.reference.ops.op_mul_sigmoid - :members: diff --git a/_doc/api/reference/ops/op_negxplus1.rst b/_doc/api/reference/ops/op_negxplus1.rst deleted file mode 100644 index 53e637e4..00000000 --- a/_doc/api/reference/ops/op_negxplus1.rst +++ /dev/null @@ -1,6 +0,0 @@ - -onnx_diagnostic.reference.ops.op_negxplus1 -========================================== - -.. automodule:: onnx_diagnostic.reference.ops.op_negxplus1 - :members: diff --git a/_doc/api/reference/ops/op_replace_zero.rst b/_doc/api/reference/ops/op_replace_zero.rst deleted file mode 100644 index 40b5309d..00000000 --- a/_doc/api/reference/ops/op_replace_zero.rst +++ /dev/null @@ -1,6 +0,0 @@ - -onnx_diagnostic.reference.ops.op_replace_zero -============================================= - -.. automodule:: onnx_diagnostic.reference.ops.op_replace_zero - :members: diff --git a/_doc/api/reference/ops/op_scatternd_of_shape.rst b/_doc/api/reference/ops/op_scatternd_of_shape.rst deleted file mode 100644 index 0c24ff2e..00000000 --- a/_doc/api/reference/ops/op_scatternd_of_shape.rst +++ /dev/null @@ -1,6 +0,0 @@ - -onnx_diagnostic.reference.ops.op_scatternd_of_shape -=================================================== - -.. automodule:: onnx_diagnostic.reference.ops.op_scatternd_of_shape - :members: diff --git a/_doc/api/reference/ops/op_transpose_cast.rst b/_doc/api/reference/ops/op_transpose_cast.rst deleted file mode 100644 index 0b230a7d..00000000 --- a/_doc/api/reference/ops/op_transpose_cast.rst +++ /dev/null @@ -1,5 +0,0 @@ -onnx_diagnostic.reference.ops.op_transpose_cast -=============================================== - -.. automodule:: onnx_diagnostic.reference.ops.op_transpose_cast - :members: diff --git a/_doc/api/reference/ops/op_tri_matrix.rst b/_doc/api/reference/ops/op_tri_matrix.rst deleted file mode 100644 index f03b1e13..00000000 --- a/_doc/api/reference/ops/op_tri_matrix.rst +++ /dev/null @@ -1,6 +0,0 @@ - -onnx_diagnostic.reference.ops.op_tri_matrix -=========================================== - -.. automodule:: onnx_diagnostic.reference.ops.op_tri_matrix - :members: diff --git a/_unittests/ut_torch_export_patches/test_patch_module.py b/_unittests/ut_torch_export_patches/test_patch_module.py index 8b8c1ca4..8eb343f3 100644 --- a/_unittests/ut_torch_export_patches/test_patch_module.py +++ b/_unittests/ut_torch_export_patches/test_patch_module.py @@ -423,12 +423,20 @@ def filter_node(node) -> bool: filter_node=filter_node, pre_rewriter=ast_or_into_bitor, ) - self.assertIn( + self.assertInOr( ( - "torch.cond(hidden_states.dtype == torch.float16 and " - "torch.isinf(hidden_states).any()" - " | torch.isnan(hidden_states).any(), " - "branch_cond_then_1, branch_cond_else_1, [hidden_states])" + ( + "torch.cond(hidden_states.dtype == torch.float16 and " + "torch.isinf(hidden_states).any()" + " | torch.isnan(hidden_states).any(), " + "branch_cond_then_1, branch_cond_else_1, [hidden_states])" + ), + # transformers>=5.2 + ( + "torch.cond(hidden_states.dtype == torch.float16 and " + "(not torch.isfinite(hidden_states).all()), " + "branch_cond_then_1, branch_cond_else_1, [hidden_states])" + ), ), rewritten.code, ) diff --git a/onnx_diagnostic/reference/evaluator.py b/onnx_diagnostic/reference/evaluator.py index b57caf54..87d7d1f6 100644 --- a/onnx_diagnostic/reference/evaluator.py +++ b/onnx_diagnostic/reference/evaluator.py @@ -17,7 +17,6 @@ from .ops.op_qlinear_average_pool import QLinearAveragePool from .ops.op_qlinear_conv import QLinearConv from .ops.op_quick_gelu import QuickGelu -from .ops.op_replace_zero import ReplaceZero from .ops.op_rotary import Rotary from .ops.op_scan import Scan from .ops.op_scatter_elements import ScatterElements @@ -68,7 +67,6 @@ class ExtendedReferenceEvaluator(ReferenceEvaluator): QLinearConv, QLinearAveragePool, QuickGelu, - ReplaceZero, Rotary, Scan, ScatterElements, diff --git a/onnx_diagnostic/reference/ops/op_replace_zero.py b/onnx_diagnostic/reference/ops/op_replace_zero.py deleted file mode 100644 index 8349c502..00000000 --- a/onnx_diagnostic/reference/ops/op_replace_zero.py +++ /dev/null @@ -1,13 +0,0 @@ -from onnx.reference.op_run import OpRun - - -class ReplaceZero(OpRun): - op_domain = "onnx_extended.ortops.optim.cuda" - - def _run(self, X, by=None, equal=None): - x2 = X.copy().flatten() - if equal: - x2[x2 == 0] = by - else: - x2[x2 != 0] = by - return (x2.reshape(X.shape),) From 3986a85cec836ee9687791e0cbe3ccac4047325c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 17 Feb 2026 17:03:03 +0100 Subject: [PATCH 3/4] disable still not fixed examples --- _unittests/ut_torch_export_patches/test_patch_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/_unittests/ut_torch_export_patches/test_patch_torch.py b/_unittests/ut_torch_export_patches/test_patch_torch.py index 33905bfa..0c28ff3b 100644 --- a/_unittests/ut_torch_export_patches/test_patch_torch.py +++ b/_unittests/ut_torch_export_patches/test_patch_torch.py @@ -45,7 +45,7 @@ def forward(self, x, y): got = ep.module()(x, y) self.assertEqualArray(expected, got) - @requires_torch("2.11") + @requires_torch("2.12") def test_export_vmap(self): class Model(torch.nn.Module): def forward(self, x, y): @@ -510,7 +510,7 @@ def _batch1(t): got = ep.module()(**torch_deepcopy(inputs)) self.assertEqualArrayAny(expected, got) - @requires_torch("2.11", "Eq(s3, Max(s10, s3)) is inconsistent!, until we know more") + @requires_torch("2.12", "Eq(s3, Max(s10, s3)) is inconsistent!, until we know more") def test_patch_tiny_llm_dim_meta_level_1(self): class Model(torch.nn.Module): def forward(self, x, ind1, ind2): From 728c9e35f6d6840abb2136eece420e4bce6ded88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 17 Feb 2026 17:50:46 +0100 Subject: [PATCH 4/4] disable one more test --- _unittests/ut_torch_models/test_tiny_llms_onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_unittests/ut_torch_models/test_tiny_llms_onnx.py b/_unittests/ut_torch_models/test_tiny_llms_onnx.py index 70df68b5..45e98725 100644 --- a/_unittests/ut_torch_models/test_tiny_llms_onnx.py +++ b/_unittests/ut_torch_models/test_tiny_llms_onnx.py @@ -71,7 +71,7 @@ def test_onnx_export_tiny_llm_xdbg(self): @ignore_warnings((UserWarning, DeprecationWarning, FutureWarning)) @hide_stdout() - @requires_torch("2.11.99") # this test broke on CI but works locally + @requires_torch("2.12.99") # this test broke on CI but works locally def test_bypass_onnx_export_tiny_llm_official_nopositionids(self): data = get_tiny_llm() model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]