Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ onnx-diagnostic: investigate onnx models
.. image:: https://github.com/sdpython/onnx-diagnostic/actions/workflows/documentation.yml/badge.svg
:target: https://github.com/sdpython/onnx-diagnostic/actions/workflows/documentation.yml

.. image:: https://badge.fury.io/py/onnx-diagnostic.svg
:target: http://badge.fury.io/py/onnx-diagnostic
.. image:: https://img.shields.io/pypi/v/onnx-diagnostic.svg
:target: https://pypi.org/project/onnx-diagnostic

.. image:: https://img.shields.io/badge/license-MIT-blue.svg
:alt: MIT License
Expand All @@ -19,6 +19,9 @@ onnx-diagnostic: investigate onnx models
:target: https://github.com/sdpython/onnx-diagnostic/
:alt: size

.. image:: https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json
:target: https://github.com/astral-sh/ruff

.. image:: https://img.shields.io/badge/code%20style-black-000000.svg
:target: https://github.com/psf/black

Expand Down
13 changes: 0 additions & 13 deletions _doc/api/reference/ops/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,24 @@
onnx_diagnostic.reference.ops
=============================



.. toctree::
:maxdepth: 1
:caption: modules


op_add_add_mul_mul
op_average_pool_grad
op_cast_like
op_complex
op_concat
op_constant_of_shape
op_fused_matmul
op_gather_grad
op_memcpy_host
op_mul_sigmoid
op_negxplus1
op_quick_gelu
op_replace_zero
op_rotary
op_qlinear_average_pool
op_qlinear_conv
op_scatter_elements
op_scatternd_of_shape
op_simplified_layer_normalization
op_skip_layer_normalization
op_slice
op_transpose_cast
op_tri_matrix


.. automodule:: onnx_diagnostic.reference.ops
:members:
Expand Down
6 changes: 0 additions & 6 deletions _doc/api/reference/ops/op_add_add_mul_mul.rst

This file was deleted.

6 changes: 0 additions & 6 deletions _doc/api/reference/ops/op_average_pool_grad.rst

This file was deleted.

6 changes: 0 additions & 6 deletions _doc/api/reference/ops/op_gather_grad.rst

This file was deleted.

6 changes: 0 additions & 6 deletions _doc/api/reference/ops/op_mul_sigmoid.rst

This file was deleted.

6 changes: 0 additions & 6 deletions _doc/api/reference/ops/op_negxplus1.rst

This file was deleted.

6 changes: 0 additions & 6 deletions _doc/api/reference/ops/op_replace_zero.rst

This file was deleted.

6 changes: 0 additions & 6 deletions _doc/api/reference/ops/op_scatternd_of_shape.rst

This file was deleted.

5 changes: 0 additions & 5 deletions _doc/api/reference/ops/op_transpose_cast.rst

This file was deleted.

6 changes: 0 additions & 6 deletions _doc/api/reference/ops/op_tri_matrix.rst

This file was deleted.

4 changes: 2 additions & 2 deletions _doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ onnx-diagnostic: investigate onnx models
.. image:: https://github.com/sdpython/onnx-diagnostic/actions/workflows/documentation.yml/badge.svg
:target: https://github.com/sdpython/onnx-diagnostic/actions/workflows/documentation.yml

.. image:: https://badge.fury.io/py/onnx-diagnostic.svg
:target: http://badge.fury.io/py/onnx-diagnostic
.. image:: https://img.shields.io/pypi/v/onnx-diagnostic.svg
:target: https://pypi.org/project/onnx-diagnostic

.. image:: https://img.shields.io/badge/license-MIT-blue.svg
:alt: MIT License
Expand Down
38 changes: 37 additions & 1 deletion _unittests/ut_reference/test_reference_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def test_quick_gelu(self):
got = ref.run(None, {"X": a})
self.assertEqualArray(expected[0], got[0])

def test_scatter_elements(self):
def test_scatter_elements_4d(self):
model = oh.make_model(
oh.make_graph(
[
Expand Down Expand Up @@ -149,6 +149,42 @@ def test_scatter_elements(self):
got = ref.run(None, {"data": data, "indices": indices, "updates": updates})
self.assertEqualArray(y, got[0])

def test_scatter_elements_3d(self):
ys = [
np.array([1, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32).reshape((2, 2, 2)),
np.array([1, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32).reshape((2, 2, 2)),
np.array([1, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32).reshape((2, 2, 2)),
]

for axis, y in zip([0, 1, 2], ys):
model = oh.make_model(
oh.make_graph(
[
oh.make_node(
"ScatterElements",
["data", "indices", "updates"],
["Z"],
axis=axis,
reduction="add",
)
],
"name",
[
oh.make_tensor_value_info("data", TensorProto.FLOAT, None),
oh.make_tensor_value_info("indices", TensorProto.INT64, None),
oh.make_tensor_value_info("updates", TensorProto.FLOAT, None),
],
[make_tensor_value_info("Z", TensorProto.FLOAT, None)],
),
opset_imports=[make_opsetid("", 18)],
)
data = np.zeros(2**3, dtype=np.float32).reshape((2, 2, 2))
indices = np.array([[[0]]], dtype=np.int64)
updates = np.array([[[1]]], dtype=np.float32)
ref = ExtendedReferenceEvaluator(model)
got = ref.run(None, {"data": data, "indices": indices, "updates": updates})
self.assertEqualArray(y, got[0])

def test_skip_layer_normalization_nobias(self):
import onnxruntime

Expand Down
18 changes: 13 additions & 5 deletions _unittests/ut_torch_export_patches/test_patch_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,12 +423,20 @@ def filter_node(node) -> bool:
filter_node=filter_node,
pre_rewriter=ast_or_into_bitor,
)
self.assertIn(
self.assertInOr(
(
"torch.cond(hidden_states.dtype == torch.float16 and "
"torch.isinf(hidden_states).any()"
" | torch.isnan(hidden_states).any(), "
"branch_cond_then_1, branch_cond_else_1, [hidden_states])"
(
"torch.cond(hidden_states.dtype == torch.float16 and "
"torch.isinf(hidden_states).any()"
" | torch.isnan(hidden_states).any(), "
"branch_cond_then_1, branch_cond_else_1, [hidden_states])"
),
# transformers>=5.2
(
"torch.cond(hidden_states.dtype == torch.float16 and "
"(not torch.isfinite(hidden_states).all()), "
"branch_cond_then_1, branch_cond_else_1, [hidden_states])"
),
),
rewritten.code,
)
Expand Down
4 changes: 2 additions & 2 deletions _unittests/ut_torch_export_patches/test_patch_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def forward(self, x, y):
got = ep.module()(x, y)
self.assertEqualArray(expected, got)

@requires_torch("2.11")
@requires_torch("2.12")
def test_export_vmap(self):
class Model(torch.nn.Module):
def forward(self, x, y):
Expand Down Expand Up @@ -510,7 +510,7 @@ def _batch1(t):
got = ep.module()(**torch_deepcopy(inputs))
self.assertEqualArrayAny(expected, got)

@requires_torch("2.11", "Eq(s3, Max(s10, s3)) is inconsistent!, until we know more")
@requires_torch("2.12", "Eq(s3, Max(s10, s3)) is inconsistent!, until we know more")
def test_patch_tiny_llm_dim_meta_level_1(self):
class Model(torch.nn.Module):
def forward(self, x, ind1, ind2):
Expand Down
2 changes: 1 addition & 1 deletion _unittests/ut_torch_models/test_tiny_llms_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_onnx_export_tiny_llm_xdbg(self):

@ignore_warnings((UserWarning, DeprecationWarning, FutureWarning))
@hide_stdout()
@requires_torch("2.11.99") # this test broke on CI but works locally
@requires_torch("2.12.99") # this test broke on CI but works locally
def test_bypass_onnx_export_tiny_llm_official_nopositionids(self):
data = get_tiny_llm()
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
Expand Down
23 changes: 23 additions & 0 deletions _unittests/ut_xrun_doc/test_command_lines_exe.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,29 @@ def test_m_parser_partition(self):
text = st.getvalue()
self.assertIn("-- done", text)

def test_n_parser_export_sample(self):
st = StringIO()
with redirect_stdout(st):
main(["exportsample", "-m", "arnir0/Tiny-LLM", "--run", "-v", "1"])
text = st.getvalue()
self.assertIn("def get_model_with_inputs(", text)
st = StringIO()
with redirect_stdout(st):
main(
[
"exportsample",
"-m",
"arnir0/Tiny-LLM",
"--run",
"-v",
"1",
"--export",
"custom",
]
)
text = st.getvalue()
self.assertIn("def get_model_with_inputs(", text)


if __name__ == "__main__":
unittest.main(verbosity=2)
36 changes: 0 additions & 36 deletions onnx_diagnostic/reference/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,7 @@
from onnx.defs import get_schema
from onnx.reference import ReferenceEvaluator
from onnx.reference.op_run import OpRun
from .ops.op_add_add_mul_mul import (
AddAdd,
AddMul,
AddSharedInput,
MulAdd,
MulMul,
MulSharedInput,
MulSub,
SubMul,
)
from .ops.op_attention import Attention
from .ops.op_average_pool_grad import AveragePoolGrad
from .ops.op_bias_softmax import BiasSoftmax
from .ops.op_cast_like import CastLike_15, CastLike_19
from .ops.op_complex import ComplexModule, ToComplex
Expand All @@ -24,23 +13,16 @@
from .ops.op_fused_matmul import FusedMatMul
from .ops.op_gather import Gather
from .ops.op_gather_elements import GatherElements
from .ops.op_gather_grad import GatherGrad
from .ops.op_memcpy_host import MemcpyFromHost, MemcpyToHost
from .ops.op_mul_sigmoid import MulSigmoid
from .ops.op_negxplus1 import NegXplus1
from .ops.op_qlinear_average_pool import QLinearAveragePool
from .ops.op_qlinear_conv import QLinearConv
from .ops.op_quick_gelu import QuickGelu
from .ops.op_replace_zero import ReplaceZero
from .ops.op_rotary import Rotary
from .ops.op_scan import Scan
from .ops.op_scatter_elements import ScatterElements
from .ops.op_scatternd_of_shape import MaskedScatterNDOfShape, ScatterNDOfShape
from .ops.op_simplified_layer_normalization import SimplifiedLayerNormalization
from .ops.op_skip_layer_normalization import SkipLayerNormalization
from .ops.op_slice import Slice_1, Slice_10
from .ops.op_transpose_cast import Transpose2DCastFP16, Transpose2DCastFP32
from .ops.op_tri_matrix import TriMatrix

logger = getLogger("onnx-diagnostic-eval")

Expand Down Expand Up @@ -70,11 +52,7 @@ class ExtendedReferenceEvaluator(ReferenceEvaluator):
"""

default_ops: List[type[OpRun]] = [
AddAdd,
AddMul,
AddSharedInput,
Attention,
AveragePoolGrad,
BiasSoftmax,
Concat,
CastLike_15,
Expand All @@ -84,33 +62,19 @@ class ExtendedReferenceEvaluator(ReferenceEvaluator):
FusedMatMul,
Gather,
GatherElements,
GatherGrad,
MaskedScatterNDOfShape,
MemcpyFromHost,
MemcpyToHost,
MulAdd,
MulMul,
MulSharedInput,
MulSigmoid,
MulSub,
NegXplus1,
QLinearConv,
QLinearAveragePool,
QuickGelu,
ReplaceZero,
Rotary,
Scan,
ScatterElements,
ScatterNDOfShape,
SimplifiedLayerNormalization,
SkipLayerNormalization,
Slice_1,
Slice_10,
SubMul,
ToComplex,
Transpose2DCastFP16,
Transpose2DCastFP32,
TriMatrix,
]

@staticmethod
Expand Down
Loading
Loading