From caab6da05245ea78bf313e453ae462cc0df0ea32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 23 Feb 2026 11:13:51 +0100 Subject: [PATCH 1/6] fix cache dynamic dimension in image-text-to-text --- onnx_diagnostic/tasks/image_text_to_text.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnx_diagnostic/tasks/image_text_to_text.py b/onnx_diagnostic/tasks/image_text_to_text.py index f26b65e5..f414480f 100644 --- a/onnx_diagnostic/tasks/image_text_to_text.py +++ b/onnx_diagnostic/tasks/image_text_to_text.py @@ -156,7 +156,7 @@ def _get_inputs_gemma3( }, "position_ids": {0: batch, 1: seq_length}, "cache_position": {0: seq_length}, - "past_key_values": [{0: batch} for _ in range(num_hidden_layers * 2)], + "past_key_values": [{0: batch, 2: seq_length} for _ in range(num_hidden_layers * 2)], "pixel_values": {0: batch}, "use_cache": None, } @@ -280,7 +280,7 @@ def get_inputs_default( "past_key_values": list( itertools.chain.from_iterable( zip( - [{0: batch} for _ in range(num_hidden_layers)], + [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], ) ) From 8989eb7bfc2bfbd2ddee8ebc55791c03eee3f3c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 23 Feb 2026 12:44:52 +0100 Subject: [PATCH 2/6] disable patches --- _unittests/ut_ci_models/test_ci_export.py | 1 + .../test_patch_transformers.py | 2 + onnx_diagnostic/ci_models/export_qwen25_vl.py | 10 +- .../_patch_transformers_generation_mixin.py | 190 +----------------- 4 files changed, 11 insertions(+), 192 deletions(-) diff --git a/_unittests/ut_ci_models/test_ci_export.py b/_unittests/ut_ci_models/test_ci_export.py index 0491adac..f99df4dd 100644 --- a/_unittests/ut_ci_models/test_ci_export.py +++ b/_unittests/ut_ci_models/test_ci_export.py @@ -20,6 +20,7 @@ def test_main_qwen25_tiny_llm(self): pretrained=False, part="", 
output_folder=self.get_dump_folder("test_main_qwen25_tiny_llm"), + opset=24, ) self.clean_dump() diff --git a/_unittests/ut_torch_export_patches/test_patch_transformers.py b/_unittests/ut_torch_export_patches/test_patch_transformers.py index aa654a69..47be0589 100644 --- a/_unittests/ut_torch_export_patches/test_patch_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_transformers.py @@ -996,6 +996,8 @@ def test_prepare_inputs_for_generation_decoder_llm(self): with self.subTest(case="case5"): if not has_transformers("4.57"): raise unittest.SkipTest("transformers 4.57+.") + if has_transformers("5.2.99"): + raise unittest.SkipTest("transformers 5.2+.") with self.assertRaises((AttributeError, TypeError)): model_inputs = model.prepare_inputs_for_generation( input_ids, past_key_values=dynamic_cache diff --git a/onnx_diagnostic/ci_models/export_qwen25_vl.py b/onnx_diagnostic/ci_models/export_qwen25_vl.py index 9279f57d..e611ff6e 100644 --- a/onnx_diagnostic/ci_models/export_qwen25_vl.py +++ b/onnx_diagnostic/ci_models/export_qwen25_vl.py @@ -60,7 +60,7 @@ import sys import time import warnings -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple from .ci_helpers import ( check_for_discrepancies_and_log_everything_into_a_json_file, compute_expected_outputs, @@ -199,6 +199,7 @@ def main( atol: float = 0.01, mismatch01: float = 0.1, profile_exporter: bool = False, + opset: Optional[int] = None, ): """ Exports model Qwen/Qwen2.5-VL-7B-Instruct or pieces of it. 
@@ -221,6 +222,8 @@ def main( :param atol: raises an exception if tolerance is above that threshold :param mismatch01: raises an exception if the ratio of mismatches is above that threshold + :param opset: opset, if not specified, a value is chosen based on the + proposed rewriting :param profile_exporter: profiles the exporter """ prefix = simplify_model_id_for_a_filename(model_id) @@ -243,6 +246,7 @@ def main( print(f"-- make_zip={make_zip}") print(f"-- output_folder={output_folder}") print(f"-- atol={atol}") + print(f"-- opset={opset}") print(f"-- mismatch01={mismatch01}") print(f"-- profile_exporter={profile_exporter}") print("------------------------------------------------------------------") @@ -473,7 +477,7 @@ def process_image(inputs_embeds, image_features): begin = time.perf_counter() - target_opset = 22 + target_opset = opset or 22 if ( exporter == "onnx-dynamo" and device == "cuda" @@ -481,7 +485,7 @@ def process_image(inputs_embeds, image_features): ): os.environ["QWEN25ATTENTION"] = "PACKED" elif "QWEN25ATTENTION" in os.environ and os.environ["QWEN25ATTENTION"] == "LOOPA23": - target_opset = 23 + target_opset = opset or 23 with torch_export_patches( patch_torch=False, diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py index dde80d22..9813491e 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py @@ -19,6 +19,7 @@ class patched_GenerationMixin: ( None if pv.Version(transformers.__version__) >= pv.Version("4.56") + and pv.Version(transformers.__version__) < pv.Version("5.2.99") else "prepare_inputs_for_generation" ), # ( @@ -297,192 +298,3 @@ def prepare_inputs_for_generation( # pragma: no cover # 8. 
Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples) model_inputs.pop("labels", None) return model_inputs - - ''' - # drops a patch since it is for a very specific version. - def _sample( - self, - input_ids: torch.LongTensor, - logits_processor: "LogitsProcessorList", # noqa: F821 - stopping_criteria: "StoppingCriteriaList", # noqa: F821 - generation_config: "GenerationConfig", # noqa: F821 - synced_gpus: bool = False, - streamer: Optional["BaseStreamer"] = None, # noqa: F821 - **model_kwargs, - ) -> Union["GenerateNonBeamOutput", torch.LongTensor]: # noqa: F821 - """ - 2025/09/29: updates for Gemma3 models, fix for eager mode as well as the export. - """ - # init values - pad_token_id = generation_config._pad_token_tensor - output_attentions = generation_config.output_attentions - output_hidden_states = generation_config.output_hidden_states - output_scores = generation_config.output_scores - output_logits = generation_config.output_logits - return_dict_in_generate = generation_config.return_dict_in_generate - has_eos_stopping_criteria = any( - hasattr(criteria, "eos_token_id") for criteria in stopping_criteria - ) - do_sample = generation_config.do_sample - - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - raw_logits = () if (return_dict_in_generate and output_logits) else None - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = ( - () if (return_dict_in_generate and output_hidden_states) else None - ) - - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = ( - model_kwargs["encoder_outputs"].get("attentions") - if output_attentions - else None - ) - encoder_hidden_states = ( - 
model_kwargs["encoder_outputs"].get("hidden_states") - if output_hidden_states - else None - ) - - # keep track of which sequences are already finished - batch_size, cur_len = input_ids.shape[:2] - this_peer_finished = False - unfinished_sequences = torch.ones( - batch_size, dtype=torch.long, device=input_ids.device - ) - model_kwargs = self._get_initial_cache_position( - cur_len, input_ids.device, model_kwargs - ) - - model_forward = self.__call__ - compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config) - if compile_forward: - os.environ["TOKENIZERS_PARALLELISM"] = "0" - # If we use FA2 and a static cache, we cannot compile with fullgraph - if self.config._attn_implementation == "flash_attention_2": - # only raise warning if the user passed an explicit compile-config - if ( - generation_config.compile_config is not None - and generation_config.compile_config.fullgraph - ): - generation_config.compile_config.fullgraph = False - model_forward = self.get_compiled_call(generation_config.compile_config) - - if generation_config.prefill_chunk_size is not None: - model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs) - is_prefill = False - else: - is_prefill = True - - while self._has_unfinished_sequences( - this_peer_finished, synced_gpus, device=input_ids.device - ): - # prepare model inputs - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - - if is_prefill: - outputs = self(**model_inputs, return_dict=True) - is_prefill = False - else: - outputs = model_forward(**model_inputs, return_dict=True) - - model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder, - ) - if synced_gpus and this_peer_finished: - continue - - next_token_logits = outputs.logits[:, -1, :].to( - copy=True, dtype=torch.float32, device=input_ids.device - ) - - # pre-process distribution - next_token_scores = logits_processor(input_ids, 
next_token_logits) - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_scores: - scores += (next_token_scores,) - if output_logits: - raw_logits += (next_token_logits,) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) - if self.config.is_encoder_decoder - else (outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - # token selection - if do_sample: - probs = torch.nn.functional.softmax(next_token_scores, dim=-1) - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - else: - next_tokens = torch.argmax(next_token_scores, dim=-1) - - # finished sentences should have their next token be a padding token - if has_eos_stopping_criteria: - next_tokens = next_tokens * unfinished_sequences + pad_token_id * ( - 1 - unfinished_sequences - ) - - # update generated ids, model inputs, and length for next step - # PATCHED: the two following lines, next_tokens can 2D already for this model - next_tokens_2d = ( - next_tokens if len(next_tokens.shape) == 2 else next_tokens[:, None] - ) - input_ids = torch.cat([input_ids, next_tokens_2d], dim=-1) - if streamer is not None: - streamer.put(next_tokens.cpu()) - - unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) - this_peer_finished = unfinished_sequences.max() == 0 - cur_len += 1 - - # This is needed to properly delete outputs.logits which may be very large - # for first iteration - # Otherwise a reference to outputs is kept which keeps - # the logits alive in the next iteration - del outputs - - if streamer is not None: - streamer.end() - - if return_dict_in_generate: - if self.config.is_encoder_decoder: - return transformers.generation.utils.GenerateEncoderDecoderOutput( - 
sequences=input_ids, - scores=scores, - logits=raw_logits, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return transformers.generation.utils.GenerateDecoderOnlyOutput( - sequences=input_ids, - scores=scores, - logits=raw_logits, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return input_ids - ''' From de5958e38acbce5f6c2fc49c2982aef24d02d720 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 23 Feb 2026 13:30:49 +0100 Subject: [PATCH 3/6] improves serialization --- .../ut_tasks/test_tasks_image_text_to_text.py | 4 +++- .../test_patch_transformers.py | 16 ++++++++++------ onnx_diagnostic/helpers/cache_helper.py | 18 ++++++++++++------ .../patches/_patch_transformers_attention.py | 12 ++++++++++++ .../serialization/transformers_impl.py | 2 +- 5 files changed, 38 insertions(+), 14 deletions(-) diff --git a/_unittests/ut_tasks/test_tasks_image_text_to_text.py b/_unittests/ut_tasks/test_tasks_image_text_to_text.py index a1a2a3b5..efcbe64f 100644 --- a/_unittests/ut_tasks/test_tasks_image_text_to_text.py +++ b/_unittests/ut_tasks/test_tasks_image_text_to_text.py @@ -61,7 +61,9 @@ def test_image_text_to_text_tiny_gemma3(self): def test_image_text_to_text_gemma3_4b_it(self): make_hybrid_cache = get_make_hybrid_cache() if make_hybrid_cache is None: - raise unittest.SkipTest("not implemented yet for transformers>=5") + raise unittest.SkipTest( + "not implemented yet for transformers>=5 (make_hybrid_cache is None)" + ) mid = "google/gemma-3-4b-it" data = get_untrained_model_with_inputs( mid, diff --git a/_unittests/ut_torch_export_patches/test_patch_transformers.py b/_unittests/ut_torch_export_patches/test_patch_transformers.py 
index 47be0589..66bbc6fa 100644 --- a/_unittests/ut_torch_export_patches/test_patch_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_transformers.py @@ -586,7 +586,7 @@ def forward( for exporter in ("custom", "onnx-dynamo"): # onnx-dynamo needs OpOverload(op='aten.sym_storage_offset' (transformers>=5.0?) if exporter == "onnx-dynamo" and not has_onnxscript("0.5.7"): - raise unittest.SkipTest("needs onnxscript>=0.5.7") + self.skipTest("needs onnxscript>=0.5.7") filename = self.get_dump_file( f"test_patched_qwen2_5_vl_vision_attention_forward.{exporter}.onnx" ) @@ -640,7 +640,7 @@ def test_qwen2_5_vl_vision_attention_iteration(self): ) for exporter in ("custom", "onnx-dynamo"): if exporter == "onnx-dynamo" and aten_sym_storage_offset is None: - raise unittest.SkipTest("update onnxscript to make this test run") + self.skipTest("update onnxscript to make this test run") # onnx-dynamo needs OpOverload(op='aten.sym_storage_offset' (transformers>=5.0?) filename = self.get_dump_file( f"test_qwen2_5_vl_vision_attention_iteration.{exporter}.onnx" @@ -909,7 +909,7 @@ def test_cache_dependant_input_preparation_exporting(self): torch.testing.assert_close(eager2, export2) with self.subTest(case="case2"): - raise unittest.SkipTest("torch 2.10+ has probably a bug here.") + self.skipTest("torch 2.10+ has probably a bug here.") input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64) inputs_embeds = torch.rand((2, 8), dtype=torch.float32) cache_position = torch.arange(0, 8, dtype=torch.int64) @@ -995,15 +995,17 @@ def test_prepare_inputs_for_generation_decoder_llm(self): with self.subTest(case="case5"): if not has_transformers("4.57"): - raise unittest.SkipTest("transformers 4.57+.") + self.skipTest("This test only works with transformers>=4.57, <5.3.") if has_transformers("5.2.99"): - raise unittest.SkipTest("transformers 5.2+.") + self.skipTest("This test is no longer valid with transformers>=5.3.") with self.assertRaises((AttributeError, TypeError)): 
model_inputs = model.prepare_inputs_for_generation( input_ids, past_key_values=dynamic_cache ) with self.subTest(case="case6"): + if has_transformers("5.2.99"): + self.skipTest("This test is no longer valid with transformers>=5.3.") cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long).to( torch_device ) @@ -1025,6 +1027,8 @@ def test_prepare_inputs_for_generation_decoder_llm(self): ) # we still need the full attention mask! with self.subTest(case="case6.2"): + if has_transformers("5.2.99"): + self.skipTest("This test is no longer valid with transformers>=5.3.") max_cache_len = 10 batch_size = 2 query_length = input_ids.shape[-1] - init_input_ids.shape[-1] @@ -1048,7 +1052,7 @@ def test_prepare_inputs_for_generation_decoder_llm(self): with self.subTest(case="case7"): if not has_transformers("4.57"): - raise unittest.SkipTest("transformers 4.57+.") + self.skipTest("This test only works with transformers>=4.57.") init_inputs_embeds = model.get_input_embeddings()(init_input_ids) model_inputs = model.prepare_inputs_for_generation( input_ids, diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index 7f1cd81b..d8499092 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -742,16 +742,22 @@ def make_hybrid_cache( not max_batch_size and not max_cache_len ), "key_value_pairs is not empty, do not specify max_cache_len and max_batch_size" max_batch_size = key_value_pairs[0][0].shape[0] + assert max_cache_len is not None or all( + isinstance(kv[0].shape[2], int) for kv in key_value_pairs + ), ( + f"Cannot determine max_cache_len with " + f"shapes={[kv[0].shape for kv in key_value_pairs]}" + ) sets_of_dim = set(kv[0].shape[2] for kv in key_value_pairs) if len(sets_of_dim) == 1: - max_cache_len = sets_of_dim.pop() - sliding_window = max_cache_len + if max_cache_len is None: + max_cache_len = sets_of_dim.pop() else: assert ( len(sets_of_dim) == 2 ), f"Not implemented for 
more than 2 dimensions {sets_of_dim}" - max_cache_len = max(sets_of_dim) - sliding_window = min(sets_of_dim) + if max_cache_len is None: + max_cache_len = max(sets_of_dim) layer_types = [ "full_attention" if i == max_cache_len else "sliding_attention" for i in [kv[0].shape[2] for kv in key_value_pairs] @@ -760,8 +766,8 @@ def make_hybrid_cache( assert ( max_batch_size and max_cache_len ), "key_value_pairs is empty, max_batch_size and max_cache_len are required" - if sliding_window is None: - sliding_window = max_cache_len + if sliding_window is None: + sliding_window = max_cache_len _max_cache_len = max_cache_len _sliding_window = sliding_window diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py index 316a6dba..fd82fbc3 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py @@ -139,6 +139,18 @@ def patched_sdpa_attention_forward( if is_causal is None and attention_mask is not None: is_causal = False if is_causal is not None: + torch._check(query.shape[0] > 0) + torch._check(query.shape[1] > 0) + torch._check(query.shape[2] > 0) + torch._check(query.shape[3] > 0) + torch._check(key.shape[0] > 0) + torch._check(key.shape[1] > 0) + torch._check(key.shape[2] > 0) + torch._check(key.shape[3] > 0) + torch._check(value.shape[0] > 0) + torch._check(value.shape[1] > 0) + torch._check(value.shape[2] > 0) + torch._check(value.shape[3] > 0) return ( torch.nn.functional.scaled_dot_product_attention( query, diff --git a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py index b93c4601..4b5947d1 100644 --- a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +++ 
b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py @@ -61,7 +61,7 @@ def _flatten_key_value_cache(cache: Cache) -> Tuple[List[Any], torch.utils._pytr flat = list(itertools.chain.from_iterable(zip(ca.key_cache, ca.value_cache))) unique = set(ca.cls_layers) if ca.cls_layers else None if ( - cache.__class__.__name__ != "DynamicCache" + cache.__class__.__name__ not in ("DynamicCache", "HybridCache") or unique is None or (len(unique) == 1 and unique.pop().__name__ == "DynamicLayer") ): From 6a34e817ddb23782be9d7091a503c436bb7b1c13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 23 Feb 2026 13:58:29 +0100 Subject: [PATCH 4/6] fix cache --- .../serialization/transformers_impl.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py index 4b5947d1..cb716720 100644 --- a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +++ b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py @@ -43,9 +43,13 @@ KWARGS_LAYER_NAMES = { "DynamicLayer": lambda layer: "", - "DynamicSlidingWindowLayer": lambda layer: str(layer.sliding_window), + "DynamicSlidingWindowLayer": lambda layer: str( + getattr(layer, "sliding_window", getattr(layer, "max_cache_len", 0)) + ), "StaticLayer": lambda layer: "", - "StaticSlidingWindowLayer": lambda layer: str(layer.sliding_window), + "StaticSlidingWindowLayer": lambda layer: str( + getattr(layer, "sliding_window", getattr(layer, "max_cache_len", 0)) + ), } PARSE_LAYER_NAMES = { From 90679ba81b4b6581d92d079221decd2960e03330 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 23 Feb 2026 15:19:19 +0100 Subject: [PATCH 5/6] a few fixes --- onnx_diagnostic/ci_models/ci_helpers.py | 6 ++++++ onnx_diagnostic/ci_models/export_phi4_mm.py | 5 ++++- onnx_diagnostic/ci_models/export_qwen25_vl.py 
| 1 + onnx_diagnostic/helpers/cache_helper.py | 3 +++ onnx_diagnostic/tasks/image_text_to_text.py | 2 +- .../patches/_patch_transformers_qwen3.py | 5 ++++- 6 files changed, 19 insertions(+), 3 deletions(-) diff --git a/onnx_diagnostic/ci_models/ci_helpers.py b/onnx_diagnostic/ci_models/ci_helpers.py index 6a5611fb..a2cc2411 100644 --- a/onnx_diagnostic/ci_models/ci_helpers.py +++ b/onnx_diagnostic/ci_models/ci_helpers.py @@ -128,6 +128,12 @@ def get_parser(name: str, epilog: str = "") -> ArgumentParser: help="Profiles the exporter and outputs an html document from pyinstrument", action=BooleanOptionalAction, ) + parser.add_argument( + "--opset", + type=int, + default=0, + help="default opsets, 0 to let the exporter choose", + ) return parser diff --git a/onnx_diagnostic/ci_models/export_phi4_mm.py b/onnx_diagnostic/ci_models/export_phi4_mm.py index d68e301e..4063e8af 100644 --- a/onnx_diagnostic/ci_models/export_phi4_mm.py +++ b/onnx_diagnostic/ci_models/export_phi4_mm.py @@ -711,6 +711,7 @@ def main( atol: float = 2, mismatch01: float = 0.01, profile_exporter: bool = False, + opset: Optional[int] = None, ): """ Exports model Qwen/Qwen2.5-VL-7B-Instruct or pieces of it. 
@@ -733,6 +734,7 @@ def main( :param atol: raises an exception if tolerance is above that threshold :param mismatch01: raises an exception if the ratio of mismatches is above that threshold + :param opset: opset to choose :param profile_exporter: profiles the exporter """ prefix = simplify_model_id_for_a_filename(model_id) @@ -947,7 +949,7 @@ def forward( begin = time.perf_counter() - target_opset = 22 + target_opset = opset or 22 details = PatchDetails() with torch_export_patches( @@ -1062,4 +1064,5 @@ def forward( atol=args.atol, mismatch01=args.mismatch01, profile_exporter=args.profile_exporter, + opset=args.opset if args.opset > 0 else None, ) diff --git a/onnx_diagnostic/ci_models/export_qwen25_vl.py b/onnx_diagnostic/ci_models/export_qwen25_vl.py index e611ff6e..bdc83bfc 100644 --- a/onnx_diagnostic/ci_models/export_qwen25_vl.py +++ b/onnx_diagnostic/ci_models/export_qwen25_vl.py @@ -569,4 +569,5 @@ def process_image(inputs_embeds, image_features): atol=args.atol, mismatch01=args.mismatch01, profile_exporter=args.profile_exporter, + opset=args.opset if args.opset > 0 else None, ) diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index d8499092..e0416b3f 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -654,6 +654,7 @@ def make_hybrid_cache( max_cache_len: Optional[int] = None, max_batch_size: Optional[int] = None, sliding_window: Optional[int] = None, + cls_layers: Optional[List[type]] = None, ) -> transformers.cache_utils.HybridCache: """ Creates an instance of :class:`transformers.cache_utils.HybridCache`. @@ -662,6 +663,8 @@ def make_hybrid_cache( :param key_value_pairs: list of pairs of (key, values) :return: :class:`transformers.cache_utils.HybridCache` + `cls_layers` is unused. + Example: .. 
runpython:: diff --git a/onnx_diagnostic/tasks/image_text_to_text.py b/onnx_diagnostic/tasks/image_text_to_text.py index f414480f..bbce6936 100644 --- a/onnx_diagnostic/tasks/image_text_to_text.py +++ b/onnx_diagnostic/tasks/image_text_to_text.py @@ -156,7 +156,7 @@ def _get_inputs_gemma3( }, "position_ids": {0: batch, 1: seq_length}, "cache_position": {0: seq_length}, - "past_key_values": [{0: batch, 2: seq_length} for _ in range(num_hidden_layers * 2)], + "past_key_values": [{0: batch} for _ in range(num_hidden_layers * 2)], "pixel_values": {0: batch}, "use_cache": None, } diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py index 3b793f87..64e9e09f 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py @@ -1,9 +1,12 @@ +import packaging.version as pv import torch +import transformers try: import transformers.models.qwen3_moe - patch_qwen3 = True + # Experts were refactored in transformers>=5.3 to use grouped_mm. 
+ patch_qwen3 = pv.Version(transformers.__version__) < pv.Version("5.2.99") except ImportError: patch_qwen3 = False From f64fb83823fec0b94fb3b39157e7ba6644d21656 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 23 Feb 2026 16:22:01 +0100 Subject: [PATCH 6/6] cache --- onnx_diagnostic/helpers/cache_helper.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index e0416b3f..a672e411 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -539,8 +539,10 @@ def make_encoder_decoder_cache( def make_mamba_cache( key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]], + cls_layers: Optional[Union[str, List[type]]] = None, + cls_kwargs: Optional[Union[Dict[str, int], List[Dict[str, int]]]] = None, ) -> "MambaCache": # noqa: F821 - "Creates a ``MambaCache``." + """Creates a ``MambaCache``. `cls_layers`, `cls_kwargs` are unused.""" # import is moved here because this part is slow. try: from transformers.models.mamba.modeling_mamba import MambaCache @@ -591,8 +593,13 @@ def get_text_config(self, *args, **kwargs): def make_sliding_window_cache( key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]], + cls_layers: Optional[Union[str, List[type]]] = None, + cls_kwargs: Optional[Union[Dict[str, int], List[Dict[str, int]]]] = None, ) -> transformers.cache_utils.SlidingWindowCache: - "Creates a :class:`transformers.cache_utils.SlidingWindowCache`." + """ + Creates a :class:`transformers.cache_utils.SlidingWindowCache`. + `cls_layers`, `cls_kwargs` are unused. 
+ """ key_value_pairs = _preprocess_key_value_pairs(key_value_pairs) class _config: @@ -654,7 +661,8 @@ def make_hybrid_cache( max_cache_len: Optional[int] = None, max_batch_size: Optional[int] = None, sliding_window: Optional[int] = None, - cls_layers: Optional[List[type]] = None, + cls_layers: Optional[Union[str, List[type]]] = None, + cls_kwargs: Optional[Union[Dict[str, int], List[Dict[str, int]]]] = None, ) -> transformers.cache_utils.HybridCache: """ Creates an instance of :class:`transformers.cache_utils.HybridCache`. @@ -663,7 +671,7 @@ def make_hybrid_cache( :param key_value_pairs: list of pairs of (key, values) :return: :class:`transformers.cache_utils.HybridCache` - `cls_layers` is unused. + `cls_layers`, `cls_kwargs` are unused. Example: