From 5380d5576449cbbfc641a8d628c706201697e68a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 18 Feb 2026 11:52:27 +0100 Subject: [PATCH 1/4] improves documentation --- .../plot_export_gemma3_tiny_input_observer.py | 2 +- .../ut_investigate/test_input_observer.py | 4 +- .../test_input_observer_transformers.py | 4 +- onnx_diagnostic/investigate/input_observer.py | 147 +++++++++++++++--- 4 files changed, 130 insertions(+), 27 deletions(-) diff --git a/_doc/final/plot_export_gemma3_tiny_input_observer.py b/_doc/final/plot_export_gemma3_tiny_input_observer.py index ce675056..c5c08aa2 100644 --- a/_doc/final/plot_export_gemma3_tiny_input_observer.py +++ b/_doc/final/plot_export_gemma3_tiny_input_observer.py @@ -53,7 +53,7 @@ # %% # Captures inputs and outputs for the model. observer = InputObserver( - missing=dict(pixel_values=torch.empty((0, 3, 896, 896), dtype=torch.float16)) + value_if_missing=dict(pixel_values=torch.empty((0, 3, 896, 896), dtype=torch.float16)) ) with ( register_additional_serialization_functions(patch_transformers=True), diff --git a/_unittests/ut_investigate/test_input_observer.py b/_unittests/ut_investigate/test_input_observer.py index 4338985b..0eb1d3e8 100644 --- a/_unittests/ut_investigate/test_input_observer.py +++ b/_unittests/ut_investigate/test_input_observer.py @@ -929,7 +929,9 @@ def forward( ] model = Model() - observer = InputObserver(missing=dict(pixel_values=torch.empty((0, 3, 896, 896)))) + observer = InputObserver( + value_if_missing=dict(pixel_values=torch.empty((0, 3, 896, 896))) + ) with observer(model): for kwargs in inputs: model(**kwargs) diff --git a/_unittests/ut_investigate/test_input_observer_transformers.py b/_unittests/ut_investigate/test_input_observer_transformers.py index 7b1b75af..81acda79 100644 --- a/_unittests/ut_investigate/test_input_observer_transformers.py +++ b/_unittests/ut_investigate/test_input_observer_transformers.py @@ -279,7 +279,9 @@ def forward( ] model = Model() - observer = InputObserver(missing=dict(pixel_values=torch.empty((0, 3, 896, 896)))) + observer = InputObserver( + value_if_missing=dict(pixel_values=torch.empty((0, 3, 896, 896))) + ) with ( register_additional_serialization_functions(patch_transformers=True), observer(model), diff --git a/onnx_diagnostic/investigate/input_observer.py b/onnx_diagnostic/investigate/input_observer.py index 2b03121f..33bfcf68 100644 --- a/onnx_diagnostic/investigate/input_observer.py +++ b/onnx_diagnostic/investigate/input_observer.py @@ -25,7 +25,7 @@ def _flatten_unflatten_for_dynamic_shapes( the context gives the dictionary keys but it is not expressed in the dynamic shapes, these specifications seems to be different for the strict and non strict mode. It also preserves tuple. - change_function: If not empty, this function is called to modify the tensors + change_function: If not None, this function is called to modify the tensors in the structure itself, like replace them by a shape. Returns: @@ -290,7 +290,7 @@ class InputObserverInfo: to be the same in the ordered dictionaries `add_inputs` receive. default_values: Default values defined by the signature of the function, any value equal to that is ignored to simplify the export. - missing: If a named argument (in kwargs) is missing, + value_if_missing: If a named argument (in kwargs) is missing, a default value will be taken in this dictionary, this is used when after the prefill step, an argument disappears (such as `pixel_values`) and another one @@ -299,7 +299,7 @@ class InputObserverInfo: not to run the model. args_name_and_position: Name of parameter `*args` and its position if it exists. - kwargs_name: Name of parameter `**kwargs` if it exists. + kwargs_name: Name of the variable keyword parameter `**kwargs` if it exists. This is used by class :class:`InputObserver`. """ @@ -308,12 +308,12 @@ def __init__( self, signature_names: list[str], default_values: dict[str, int | bool | str | float], - missing: dict[str, Any], + value_if_missing: dict[str, Any], args_name_and_position: tuple[str, int] | None, kwargs_name: str | None, ): self.default_values = default_values - self.missing = missing + self.value_if_missing = value_if_missing self.inputs: list[InputCandidate] = [] self.outputs_specs: list[torch.utils._pytree.PyTreeSpec] = [] self.flat_outputs: list[list[torch.Tensor | None]] = [] @@ -349,12 +349,25 @@ def add_inputs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): if v is not None and not isinstance(v, (int, float, bool, str)) } - # adds missing attributes - for k, v in self.missing.items(): + # adds value_if_missing attributes + for k, v in self.value_if_missing.items(): if k not in kwargs: + # Validate that `value_if_missing` keys are compatible + # with the observed signature. + # If the function does not accept **kwargs, + # all value_if_missing keys must be + # present in the observed signature names. + if k not in self.signature_names and not self.kwargs_name: + raise ValueError( + f"Unexpected keyword argument '{k}' " + f"provided as a value_if_missing input " + "for a function that does not accept it. " + f"All value_if_missing keys must " + f"be in the observed signature: {tuple(self.signature_names)}." + ) kwargs[k] = v - # kwargs may come in a different ordeer teach. + # kwargs may come in a different order each time. # dictionaries are ordered and torch.export.export expects # dynamic shapes and kwargs to follow the same order. @@ -515,7 +528,8 @@ def _set_batch_dimension_for_flat_index(index) -> bool: **dict(zip(pos_names, flat_dynamic_shapes[:n_args])), **dict( zip( - list(self._best_candidate.kwargs), flat_dynamic_shapes[n_args:] + list(self._best_candidate.kwargs), + flat_dynamic_shapes[n_args:], ) ), **dict.fromkeys(self._best_candidate.cst_kwargs, None), @@ -531,7 +545,10 @@ def _set_batch_dimension_for_flat_index(index) -> bool: **dict(zip(pos_names, flat_dynamic_shapes[:n_args])), var_pos: tuple(flat_dynamic_shapes[n_args:i_kwargs]), **dict( - zip(list(self._best_candidate.kwargs), flat_dynamic_shapes[i_kwargs:]) + zip( + list(self._best_candidate.kwargs), + flat_dynamic_shapes[i_kwargs:], + ) ), **dict.fromkeys(self._best_candidate.cst_kwargs, None), } @@ -602,12 +619,34 @@ def infer_arguments( flat: bool = False, as_args_kwargs: bool = False, ) -> ( - list[torch.Tensor] + list[torch.Tensor | None] | tuple[torch.Tensor, ...] | dict[str, torch.Tensor] | tuple[list[torch.Tensor] | tuple[torch.Tensor, ...], dict[str, torch.Tensor]] ): - """Infers arguments based on the collected tensors.""" + """Infers arguments based on the collected tensors. + + Args: + index_or_candidate: If missing, the method selects one set of inputs + among the available ones, usually the set of inputs containing + with the highest number of tensors. + It then replaces None values and missing tensors with empty tensors. + If not missing, it can be an integer to fetch one of the stored set + or some inputs. + flat: If True, it returns a flattened list of tensors, + if False, it returns a tuple or a dictionary preserving + the nested structures. The flat version is used internally. + It produces a single list of tensors easier to process or modify + rather than a nested structure holding the same tensors. + The original structure can be restored with + ``torch.utils._pytree.tree_unflatten(flat_list, self.aligned_spec)``. + This mechanism is used to replace None values by empty tensors. + as_args_kwargs: If True, the method always returns `(args, kwargs)`, + otherwise, it returns either a tuple (only args) or a dictionary + (only kwargs) or raises an exception if it cannot do so. + Returns: + Inferred arguments, every optional tensor is replaced by an empty tensor. + """ # This is already checked by _build_inputs_completed_with_none_values # but this is not always well captured by tools checking types. self.align_inputs_none_values() @@ -616,8 +655,8 @@ def infer_arguments( if index_or_candidate is None: for cand in self.inputs: args, kwargs = cand.args, cand.kwargs - if len(args) == len(self._best_candidate.args) and len(kwargs) == len( - self._best_candidate.kwargs + if len(args) == len(self._best_candidate.args or ()) and len(kwargs) == len( + self._best_candidate.kwargs or {} ): candidate = cand break @@ -724,10 +763,11 @@ def infer_arguments( return tuple(args), kwargs def _post_process_for_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]: - """:func:`torch.export.export` requires to have dynamic shapes - and keyword arguments wrapped into `'kwargs': { 'param': shape or tensor }` - if 'param' is not part of the signature but is caught through `**kwargs`. - This function ensures this is the case. + """:func:`torch.export.export` requires dynamic shapes and keyword arguments + that are not part of the explicit function signature but are collected via + ``**`` to be wrapped under the corresponding parameter name + (``self.kwargs_name``) as ``{: {'param': shape or tensor}}``. + This function ensures this wrapping is performed when ``self.kwargs_name`` is set. """ if not self.kwargs_name: # Nothing to do here. @@ -737,6 +777,13 @@ def _post_process_for_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]: return kwargs keywords = {k: v for k, v in kwargs.items() if k in to_be_moved} new_kwargs = {k: v for k, v in kwargs.items() if k not in to_be_moved} + if self.kwargs_name in new_kwargs: + raise ValueError( + f"Keyword argument name collision: received a keyword argument " + f"'{self.kwargs_name}' which conflicts with the **{self.kwargs_name} " + "parameter used to collect extra keyword arguments. " + "Passing a keyword argument with this name is not supported." + ) return {**new_kwargs, self.kwargs_name: keywords} @@ -746,7 +793,7 @@ class InputObserver: export arguments. Args: - missing: If a named argument (in kwargs) is missing, + value_if_missing: If a named argument (in kwargs) is missing, a default value will be taken in this dictionary, this is used when after the prefill step, an argument disappears (such as `pixel_values`) and another one @@ -778,14 +825,61 @@ class InputObserver: >>> dynamic_shapes.input_observer.infer_dynamic_shapes(), >>> ) + The last example considers an LLM taking images and text as inputs. + The first call to the forward method which we try to export has `pixel_values` + but no `past_key_values`. The next calls do not have `pixel_values` but + `past_key_values`. The observer understands `pixel_values` and `past_key_values` + are needed but they may not be both specified at the same time. + Since `pixel_values` only appears in the first call, the observer cannot + tell how to infer an empty tensor for this argument. That's what the argument + `value_if_missing` is for. The following example is more than a dummy example + but shows how to use it with ``transformers``. + + .. code-block:: python + + from transformers import pipeline + + model_id = "tiny-random/gemma-3" + pipe = pipeline( + "image-text-to-text", + model=model_id, + device="cpu", + trust_remote_code=True, + max_new_tokens=3, + dtype=torch.float16, + ) + messages = [ + { + "role": "system", + "content": [{"type": "text", "text": "You are a helpful assistant."}], + }, + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG", + }, + {"type": "text", "text": "What animal is on the candy?"}, + ], + }, + ] + observer = InputObserver( + value_if_missing=dict( + pixel_values=torch.empty((0, 3, 896, 896), dtype=torch.float16) + ) + ) + with observer(pipe.model): + pipe(text=messages, max_new_tokens=4) + Examples can be found in :ref:`l-plot-tiny-llm-export-input-observer`, :ref:`l-plot-whisper-tiny-export-input-observer`, :ref:`l-plot-gemma3-tiny-export-input-observer`. """ - def __init__(self, missing: dict[str, Any] | None = None): + def __init__(self, value_if_missing: dict[str, Any] | None = None): self.info: InputObserverInfo | None = None # type: ignore[annotation-unchecked] - self.missing = missing or {} + self.value_if_missing = value_if_missing or {} def _replaced_method( self, @@ -851,7 +945,7 @@ def __call__( if p.default != inspect.Parameter.empty and isinstance(p.default, (int, bool, str, float)) }, - missing=self.missing, + value_if_missing=self.value_if_missing, args_name_and_position=args_names[0] if args_names else None, kwargs_name=kwargs_names[0] if kwargs_names else None, ) @@ -906,7 +1000,7 @@ def infer_arguments( flat: bool = False, as_args_kwargs: bool = False, ) -> ( - list[torch.Tensor] + list[torch.Tensor | None] | tuple[torch.Tensor, ...] | dict[str, torch.Tensor] | tuple[list[torch.Tensor] | tuple[torch.Tensor, ...], dict[str, torch.Tensor]] @@ -924,7 +1018,12 @@ def infer_arguments( flat: If True, it returns a flattened list of tensors, if False, it returns a tuple or a dictionary preserving - the nested structures. + the nested structures. The flat version is used internally. + It produces a single list of tensors easier to process or modify + rather than a nested structure holding the same tensors. + The original structure can be restored with + ``torch.utils._pytree.tree_unflatten(flat_list, self.aligned_spec)``. + This mechanism is used to replace None values by empty tensors. as_args_kwargs: If True, the method always returns `(args, kwargs)`, otherwise, it returns either a tuple (only args) or a dictionary From 6980eb0e953eec640eb86162c440fa1c83d07b56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 18 Feb 2026 13:19:22 +0100 Subject: [PATCH 2/4] supports integers for value_if_missing --- .../ut_investigate/test_input_observer.py | 150 ++++++++++++++++-- .../test_input_observer_transformers.py | 4 +- onnx_diagnostic/investigate/input_observer.py | 57 +++++-- 3 files changed, 182 insertions(+), 29 deletions(-) diff --git a/_unittests/ut_investigate/test_input_observer.py b/_unittests/ut_investigate/test_input_observer.py index 0eb1d3e8..056c4765 100644 --- a/_unittests/ut_investigate/test_input_observer.py +++ b/_unittests/ut_investigate/test_input_observer.py @@ -887,7 +887,7 @@ def forward(self, x=None, y=None): # self.assertEqual(2, len(args)) # self.assertEqual(len([v for v in args.values() if v is not None]), 2) - def test_infer_dynamic_shapes_missing(self): + def test_infer_dynamic_shapes_missing_kwargs(self): class Model(torch.nn.Module): def forward( self, @@ -903,26 +903,26 @@ def forward( inputs = [ dict( - input_ids=torch.ones((1, 282), dtype=torch.int64), - pixel_values=torch.ones((1, 3, 896, 896), dtype=torch.int64), - attention_mask=torch.ones((1, 282), dtype=torch.int64), - position_ids=torch.ones((1, 282), dtype=torch.int64), - token_type_ids=torch.ones((1, 282), dtype=torch.int64), - cache_position=torch.ones((282,), dtype=torch.int64), + input_ids=torch.ones((1, 28), dtype=torch.int64), + pixel_values=torch.ones((1, 3, 112, 112), dtype=torch.int64), + attention_mask=torch.ones((1, 28), dtype=torch.int64), + position_ids=torch.ones((1, 28), dtype=torch.int64), + token_type_ids=torch.ones((1, 28), dtype=torch.int64), + cache_position=torch.ones((28,), dtype=torch.int64), ), dict( input_ids=torch.ones((1, 1), dtype=torch.int64), - attention_mask=torch.ones((1, 283), dtype=torch.int64), + attention_mask=torch.ones((1, 29), dtype=torch.int64), position_ids=torch.ones((1, 1), dtype=torch.int64), - past_key_values=torch.rand((1, 1, 282, 32)), + past_key_values=torch.rand((1, 1, 28, 32)), token_type_ids=torch.ones((1, 1), dtype=torch.int64), cache_position=torch.ones((1,), dtype=torch.int64), ), dict( input_ids=torch.ones((1, 1), dtype=torch.int64), - attention_mask=torch.ones((1, 284), dtype=torch.int64), + attention_mask=torch.ones((1, 30), dtype=torch.int64), position_ids=torch.ones((1, 1), dtype=torch.int64), - past_key_values=torch.rand((1, 1, 283, 32)), + past_key_values=torch.rand((1, 1, 29, 32)), token_type_ids=torch.ones((1, 1), dtype=torch.int64), cache_position=torch.ones((1,), dtype=torch.int64), ), @@ -930,7 +930,7 @@ def forward( model = Model() observer = InputObserver( - value_if_missing=dict(pixel_values=torch.empty((0, 3, 896, 896))) + value_if_missing=dict(pixel_values=torch.empty((0, 3, 112, 112))) ) with observer(model): for kwargs in inputs: @@ -948,6 +948,132 @@ def forward( "cache_position": {0: cst}, } self.assertEqual(expected, shapes) + kwargs = observer.infer_arguments() + self.assertEqual(list(expected), list(kwargs)) + self.assertEqual((0, 3, 112, 112), kwargs["pixel_values"].shape) + + def test_infer_dynamic_shapes_missing_args(self): + class Model(torch.nn.Module): + def forward( + self, + input_ids=None, + pixel_values=None, + attention_mask=None, + past_key_values=None, + ): + return input_ids + + inputs = [ + ( + torch.ones((1, 28), dtype=torch.int64), + torch.ones((1, 3, 112, 112), dtype=torch.int64), + torch.ones((1, 28), dtype=torch.int64), + ), + ( + torch.ones((1, 1), dtype=torch.int64), + None, + torch.ones((1, 29), dtype=torch.int64), + torch.rand((1, 1, 28, 32)), + ), + ( + torch.ones((1, 1), dtype=torch.int64), + None, + torch.ones((1, 30), dtype=torch.int64), + torch.rand((1, 1, 29, 32)), + ), + ] + + model = Model() + observer = InputObserver( + value_if_missing={1: torch.empty((0, 3, 112, 112), dtype=torch.int64)} + ) + with observer(model): + for args in inputs: + model(*args) + + shapes = observer.infer_dynamic_shapes(set_batch_dimension_for=True) + cst = torch.export.Dim.DYNAMIC + expected = ({0: cst, 1: cst}, {0: cst}, {0: cst, 1: cst}, {0: cst, 2: cst}) + self.assertEqual(expected, shapes) + args = observer.infer_arguments() + self.assertEqual(len(expected), len(args)) + self.assertEqual((0, 3, 112, 112), args[1].shape) + + def test_infer_dynamic_shapes_missing_kwargs_nested(self): + class Model(torch.nn.Module): + def forward( + self, + input_ids=None, + pixel_values=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + token_type_ids=None, + cache_position=None, + ): + return input_ids + + inputs = [ + dict( + input_ids=torch.ones((1, 28), dtype=torch.int64), + pixel_values=( + torch.ones((1, 3, 112, 112), dtype=torch.int64), + torch.ones((1, 3, 112, 112), dtype=torch.int64), + ), + attention_mask=torch.ones((1, 28), dtype=torch.int64), + position_ids=torch.ones((1, 28), dtype=torch.int64), + token_type_ids=torch.ones((1, 28), dtype=torch.int64), + cache_position=torch.ones((28,), dtype=torch.int64), + ), + dict( + input_ids=torch.ones((1, 1), dtype=torch.int64), + attention_mask=torch.ones((1, 29), dtype=torch.int64), + position_ids=torch.ones((1, 1), dtype=torch.int64), + past_key_values=torch.rand((1, 1, 28, 32)), + token_type_ids=torch.ones((1, 1), dtype=torch.int64), + cache_position=torch.ones((1,), dtype=torch.int64), + ), + dict( + input_ids=torch.ones((1, 1), dtype=torch.int64), + attention_mask=torch.ones((1, 30), dtype=torch.int64), + position_ids=torch.ones((1, 1), dtype=torch.int64), + past_key_values=torch.rand((1, 1, 29, 32)), + token_type_ids=torch.ones((1, 1), dtype=torch.int64), + cache_position=torch.ones((1,), dtype=torch.int64), + ), + ] + + model = Model() + observer = InputObserver( + value_if_missing=dict( + pixel_values=( + torch.empty((0, 3, 112, 112), dtype=torch.int64), + torch.empty((0, 3, 112, 112), dtype=torch.int64), + ) + ) + ) + with observer(model): + for kwargs in inputs: + model(**kwargs) + + shapes = observer.infer_dynamic_shapes(set_batch_dimension_for=True) + cst = torch.export.Dim.DYNAMIC + expected = { + "input_ids": {0: cst, 1: cst}, + "pixel_values": ({0: cst}, {0: cst}), + "attention_mask": {0: cst, 1: cst}, + "position_ids": {0: cst, 1: cst}, + "past_key_values": {0: cst, 2: cst}, + "token_type_ids": {0: cst, 1: cst}, + "cache_position": {0: cst}, + } + self.assertEqual(expected, shapes) + kwargs = observer.infer_arguments() + self.assertEqual(list(expected), list(kwargs)) + self.assertIsInstance(kwargs["pixel_values"], tuple) + self.assertEqual(2, len(kwargs["pixel_values"])) + self.assertEqual((0, 3, 112, 112), kwargs["pixel_values"][0].shape) + self.assertEqual((0, 3, 112, 112), kwargs["pixel_values"][1].shape) def test_io_captured_kwargs_kwargs(self): class Model(torch.nn.Module): diff --git a/_unittests/ut_investigate/test_input_observer_transformers.py b/_unittests/ut_investigate/test_input_observer_transformers.py index 81acda79..7a88222c 100644 --- a/_unittests/ut_investigate/test_input_observer_transformers.py +++ b/_unittests/ut_investigate/test_input_observer_transformers.py @@ -280,7 +280,9 @@ def forward( model = Model() observer = InputObserver( - value_if_missing=dict(pixel_values=torch.empty((0, 3, 896, 896))) + value_if_missing=dict( + pixel_values=torch.empty((0, 3, 896, 896), dtype=torch.int64) + ) ) with ( register_additional_serialization_functions(patch_transformers=True), diff --git a/onnx_diagnostic/investigate/input_observer.py b/onnx_diagnostic/investigate/input_observer.py index 33bfcf68..aa2195e9 100644 --- a/onnx_diagnostic/investigate/input_observer.py +++ b/onnx_diagnostic/investigate/input_observer.py @@ -87,7 +87,7 @@ def _infer_dynamic_dimensions( unique_ranks = {len(shape) for shape in shape_list} torch._check( len(unique_ranks) == 1, - lambda: "all shapes in shape_list must have the same rank", + lambda: f"All shapes in shape_list must have the same rank but {shape_list=}.", ) rank = unique_ranks.pop() dynamic = [] @@ -129,7 +129,6 @@ def __init__( self.args = args self.kwargs = kwargs self.flat_list, self.spec = torch.utils._pytree.tree_flatten((args, kwargs)) - self.n_tensors = sum(t is not None for t in self.flat_list) self._position_to_args_kwargs: list[int | str] | None = None self._n_tensors_for_args_kwargs: dict[int | str, int] | None = None self.cst_kwargs = cst_kwargs.copy() @@ -290,7 +289,7 @@ class InputObserverInfo: to be the same in the ordered dictionaries `add_inputs` receive. default_values: Default values defined by the signature of the function, any value equal to that is ignored to simplify the export. - value_if_missing: If a named argument (in kwargs) is missing, + value_if_missing: If an argument is missing, a default value will be taken in this dictionary, this is used when after the prefill step, an argument disappears (such as `pixel_values`) and another one @@ -308,7 +307,7 @@ def __init__( self, signature_names: list[str], default_values: dict[str, int | bool | str | float], - value_if_missing: dict[str, Any], + value_if_missing: dict[str | int, Any], args_name_and_position: tuple[str, int] | None, kwargs_name: str | None, ): @@ -351,21 +350,47 @@ def add_inputs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): # adds value_if_missing attributes for k, v in self.value_if_missing.items(): - if k not in kwargs: - # Validate that `value_if_missing` keys are compatible - # with the observed signature. - # If the function does not accept **kwargs, - # all value_if_missing keys must be - # present in the observed signature names. - if k not in self.signature_names and not self.kwargs_name: + if isinstance(k, str): + if k not in kwargs: + # Validate that `value_if_missing` keys are compatible + # with the observed signature. + # If the function does not accept **kwargs, + # all value_if_missing keys must be + # present in the observed signature names. + if k not in self.signature_names and not self.kwargs_name: + raise ValueError( + f"Unexpected keyword argument {k!r} " + f"provided as a value_if_missing input " + "for a function that does not accept it. " + f"All value_if_missing keys must " + f"be in the observed signature: {tuple(self.signature_names)}." + ) + kwargs[k] = v + elif isinstance(k, int): + if k >= len(self.signature_names): raise ValueError( - f"Unexpected keyword argument '{k}' " + f"Unexpected keyword argument {k=} " f"provided as a value_if_missing input " "for a function that does not accept it. " - f"All value_if_missing keys must " + f"All value_if_missing indices must " f"be in the observed signature: {tuple(self.signature_names)}." ) - kwargs[k] = v + if k >= len(args): + raise NotImplementedError( + f"Unexpected keyword argument {k=} " + f"provided as a value_if_missing input " + "for a function that does not accept it. " + f"All value_if_missing indices must " + f"be in the observed signature: {tuple(self.signature_names)}, " + f"only {len(args)} were given." + ) + list_args = list(args) + list_args[k] = v + args = tuple(list_args) + else: + raise TypeError( + f"Unexepcted type {type(k)} for a missing value. The key is {k!r}." + ) # kwargs may come in a different order each time. # dictionaries are ordered and torch.export.export expects @@ -793,7 +818,7 @@ class InputObserver: export arguments. Args: - value_if_missing: If a named argument (in kwargs) is missing, + value_if_missing: If an argument is missing, a default value will be taken in this dictionary, this is used when after the prefill step, an argument disappears (such as `pixel_values`) and another one @@ -877,7 +902,7 @@ class InputObserver: :ref:`l-plot-gemma3-tiny-export-input-observer`. """ - def __init__(self, value_if_missing: dict[str, Any] | None = None): + def __init__(self, value_if_missing: dict[str | int, Any] | None = None): self.info: InputObserverInfo | None = None # type: ignore[annotation-unchecked] self.value_if_missing = value_if_missing or {} From 9ac2b45ddf951cad493fca489a8822d4f8e28dc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 18 Feb 2026 13:31:18 +0100 Subject: [PATCH 3/4] spell --- onnx_diagnostic/investigate/input_observer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx_diagnostic/investigate/input_observer.py b/onnx_diagnostic/investigate/input_observer.py index aa2195e9..5004f466 100644 --- a/onnx_diagnostic/investigate/input_observer.py +++ b/onnx_diagnostic/investigate/input_observer.py @@ -389,7 +389,7 @@ def add_inputs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): args = tuple(list_args) else: raise TypeError( - f"Unexepcted type {type(k)} for a missing value. The key is {k!r}." + f"Unexpected type {type(k)} for a missing value. The key is {k!r}." ) # kwargs may come in a different order each time. From d20d0cc2c46bb70eaefc40da7d61b384022bf6f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 18 Feb 2026 13:46:51 +0100 Subject: [PATCH 4/4] disable more tests --- .github/workflows/check-release.yml | 2 +- .github/workflows/ci.yml | 8 +++++--- _unittests/ut_tasks/test_tasks.py | 2 +- _unittests/ut_tasks/test_tasks_image_text_to_text.py | 6 +++--- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/.github/workflows/check-release.yml b/.github/workflows/check-release.yml index 2d9be2fa..fd2206cd 100644 --- a/.github/workflows/check-release.yml +++ b/.github/workflows/check-release.yml @@ -16,7 +16,7 @@ jobs: matrix: os: [ubuntu-latest, macOS-latest, windows-latest] python: ['3.13'] - transformers: ['5.1.0', 'main'] + transformers: ['5.2.0', 'main'] torch: ['2.10', 'main'] steps: diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ed4f416b..e613de03 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,7 +17,7 @@ jobs: matrix: os: [ubuntu-latest] python: ['3.10', '3.11', '3.12', '3.13'] - transformers: ['4.48.3', '4.51.3', '4.55.4', '4.57.6', '5.1.0', 'main'] + transformers: ['4.48.3', '4.51.3', '4.55.4', '4.57.6', '5.2.0', 'main'] torch: ['2.10', 'main'] exclude: # 3.10 - torch @@ -29,7 +29,7 @@ jobs: - python: '3.10' transformers: '4.57.6' - python: '3.10' - transformers: '5.1.0' + transformers: '5.2.0' - python: '3.10' transformers: 'main' # 3.11 - torch @@ -41,7 +41,7 @@ jobs: - python: '3.11' transformers: '4.57.6' - python: '3.11' - transformers: '5.1.0' + transformers: '5.2.0' - python: '3.11' transformers: 'main' # 3.13 - torch @@ -54,6 +54,8 @@ jobs: transformers: '4.51.3' - python: '3.13' transformers: '4.55.4' + - python: '3.13' + transformers: '4.57.6' steps: - uses: actions/checkout@v3 diff --git a/_unittests/ut_tasks/test_tasks.py b/_unittests/ut_tasks/test_tasks.py index c0c666ae..b44eed47 100644 --- a/_unittests/ut_tasks/test_tasks.py +++ b/_unittests/ut_tasks/test_tasks.py @@ -266,7 +266,7 @@ def test_falcon_mamba_dev(self): model(**inputs) model(**data["inputs2"]) self.assertIn((data["size"], data["n_weights"]), [(274958336, 68739584)]) - if not has_transformers("5.2.99"): + if not has_transformers("5.3.99"): raise unittest.SkipTest("The model has control flow.") with torch_export_patches(patch_transformers=True, verbose=10, stop_if_static=1): torch.export.export( diff --git a/_unittests/ut_tasks/test_tasks_image_text_to_text.py b/_unittests/ut_tasks/test_tasks_image_text_to_text.py index 3d487180..a1a2a3b5 100644 --- a/_unittests/ut_tasks/test_tasks_image_text_to_text.py +++ b/_unittests/ut_tasks/test_tasks_image_text_to_text.py @@ -15,7 +15,7 @@ class TestTasksImageTextToText(ExtTestCase): @hide_stdout() - @requires_transformers("5.2.99") + @requires_transformers("5.3.99") @requires_torch("2.7.99") def test_image_text_to_text_idefics(self): mid = "HuggingFaceM4/tiny-random-idefics" @@ -32,7 +32,7 @@ def test_image_text_to_text_idefics(self): self.assertEqualAny(expected, ep.module()(**inputs), atol=1) @hide_stdout() - @requires_transformers("5.2.99") + @requires_transformers("5.3.99") @requires_torch("2.7.99") def test_image_text_to_text_tiny_gemma3(self): """ @@ -88,7 +88,7 @@ def test_image_text_to_text_gemma3_4b_it(self): self.assertEqualAny(expected, ep.module()(**inputs)) @hide_stdout() - @requires_transformers("5.2.99") + @requires_transformers("5.3.99") @requires_torch("2.7.99") def test_image_text_to_text_zai_glm(self): """