Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/check-release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
matrix:
os: [ubuntu-latest, macOS-latest, windows-latest]
python: ['3.13']
transformers: ['5.1.0', 'main']
transformers: ['5.2.0', 'main']
torch: ['2.10', 'main']

steps:
Expand Down
8 changes: 5 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
matrix:
os: [ubuntu-latest]
python: ['3.10', '3.11', '3.12', '3.13']
transformers: ['4.48.3', '4.51.3', '4.55.4', '4.57.6', '5.1.0', 'main']
transformers: ['4.48.3', '4.51.3', '4.55.4', '4.57.6', '5.2.0', 'main']
torch: ['2.10', 'main']
exclude:
# 3.10 - torch
Expand All @@ -29,7 +29,7 @@ jobs:
- python: '3.10'
transformers: '4.57.6'
- python: '3.10'
transformers: '5.1.0'
transformers: '5.2.0'
- python: '3.10'
transformers: 'main'
# 3.11 - torch
Expand All @@ -41,7 +41,7 @@ jobs:
- python: '3.11'
transformers: '4.57.6'
- python: '3.11'
transformers: '5.1.0'
transformers: '5.2.0'
- python: '3.11'
transformers: 'main'
# 3.13 - torch
Expand All @@ -54,6 +54,8 @@ jobs:
transformers: '4.51.3'
- python: '3.13'
transformers: '4.55.4'
- python: '3.13'
transformers: '4.57.6'
steps:
- uses: actions/checkout@v3

Expand Down
2 changes: 1 addition & 1 deletion _doc/final/plot_export_gemma3_tiny_input_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
# %%
# Captures inputs and outputs for the model.
observer = InputObserver(
missing=dict(pixel_values=torch.empty((0, 3, 896, 896), dtype=torch.float16))
value_if_missing=dict(pixel_values=torch.empty((0, 3, 896, 896), dtype=torch.float16))
)
with (
register_additional_serialization_functions(patch_transformers=True),
Expand Down
152 changes: 140 additions & 12 deletions _unittests/ut_investigate/test_input_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,7 +887,7 @@ def forward(self, x=None, y=None):
# self.assertEqual(2, len(args))
# self.assertEqual(len([v for v in args.values() if v is not None]), 2)

def test_infer_dynamic_shapes_missing(self):
def test_infer_dynamic_shapes_missing_kwargs(self):
class Model(torch.nn.Module):
def forward(
self,
Expand All @@ -903,33 +903,35 @@ def forward(

inputs = [
dict(
input_ids=torch.ones((1, 282), dtype=torch.int64),
pixel_values=torch.ones((1, 3, 896, 896), dtype=torch.int64),
attention_mask=torch.ones((1, 282), dtype=torch.int64),
position_ids=torch.ones((1, 282), dtype=torch.int64),
token_type_ids=torch.ones((1, 282), dtype=torch.int64),
cache_position=torch.ones((282,), dtype=torch.int64),
input_ids=torch.ones((1, 28), dtype=torch.int64),
pixel_values=torch.ones((1, 3, 112, 112), dtype=torch.int64),
attention_mask=torch.ones((1, 28), dtype=torch.int64),
position_ids=torch.ones((1, 28), dtype=torch.int64),
token_type_ids=torch.ones((1, 28), dtype=torch.int64),
cache_position=torch.ones((28,), dtype=torch.int64),
),
dict(
input_ids=torch.ones((1, 1), dtype=torch.int64),
attention_mask=torch.ones((1, 283), dtype=torch.int64),
attention_mask=torch.ones((1, 29), dtype=torch.int64),
position_ids=torch.ones((1, 1), dtype=torch.int64),
past_key_values=torch.rand((1, 1, 282, 32)),
past_key_values=torch.rand((1, 1, 28, 32)),
token_type_ids=torch.ones((1, 1), dtype=torch.int64),
cache_position=torch.ones((1,), dtype=torch.int64),
),
dict(
input_ids=torch.ones((1, 1), dtype=torch.int64),
attention_mask=torch.ones((1, 284), dtype=torch.int64),
attention_mask=torch.ones((1, 30), dtype=torch.int64),
position_ids=torch.ones((1, 1), dtype=torch.int64),
past_key_values=torch.rand((1, 1, 283, 32)),
past_key_values=torch.rand((1, 1, 29, 32)),
token_type_ids=torch.ones((1, 1), dtype=torch.int64),
cache_position=torch.ones((1,), dtype=torch.int64),
),
]

model = Model()
observer = InputObserver(missing=dict(pixel_values=torch.empty((0, 3, 896, 896))))
observer = InputObserver(
value_if_missing=dict(pixel_values=torch.empty((0, 3, 112, 112)))
)
with observer(model):
for kwargs in inputs:
model(**kwargs)
Expand All @@ -946,6 +948,132 @@ def forward(
"cache_position": {0: cst},
}
self.assertEqual(expected, shapes)
kwargs = observer.infer_arguments()
self.assertEqual(list(expected), list(kwargs))
self.assertEqual((0, 3, 112, 112), kwargs["pixel_values"].shape)

def test_infer_dynamic_shapes_missing_args(self):
class Model(torch.nn.Module):
def forward(
self,
input_ids=None,
pixel_values=None,
attention_mask=None,
past_key_values=None,
):
return input_ids

inputs = [
(
torch.ones((1, 28), dtype=torch.int64),
torch.ones((1, 3, 112, 112), dtype=torch.int64),
torch.ones((1, 28), dtype=torch.int64),
),
(
torch.ones((1, 1), dtype=torch.int64),
None,
torch.ones((1, 29), dtype=torch.int64),
torch.rand((1, 1, 28, 32)),
),
(
torch.ones((1, 1), dtype=torch.int64),
None,
torch.ones((1, 30), dtype=torch.int64),
torch.rand((1, 1, 29, 32)),
),
]

model = Model()
observer = InputObserver(
value_if_missing={1: torch.empty((0, 3, 112, 112), dtype=torch.int64)}
)
with observer(model):
for args in inputs:
model(*args)

shapes = observer.infer_dynamic_shapes(set_batch_dimension_for=True)
cst = torch.export.Dim.DYNAMIC
expected = ({0: cst, 1: cst}, {0: cst}, {0: cst, 1: cst}, {0: cst, 2: cst})
self.assertEqual(expected, shapes)
args = observer.infer_arguments()
self.assertEqual(len(expected), len(args))
self.assertEqual((0, 3, 112, 112), args[1].shape)

def test_infer_dynamic_shapes_missing_kwargs_nested(self):
class Model(torch.nn.Module):
def forward(
self,
input_ids=None,
pixel_values=None,
attention_mask=None,
position_ids=None,
past_key_values=None,
token_type_ids=None,
cache_position=None,
):
return input_ids

inputs = [
dict(
input_ids=torch.ones((1, 28), dtype=torch.int64),
pixel_values=(
torch.ones((1, 3, 112, 112), dtype=torch.int64),
torch.ones((1, 3, 112, 112), dtype=torch.int64),
),
attention_mask=torch.ones((1, 28), dtype=torch.int64),
position_ids=torch.ones((1, 28), dtype=torch.int64),
token_type_ids=torch.ones((1, 28), dtype=torch.int64),
cache_position=torch.ones((28,), dtype=torch.int64),
),
dict(
input_ids=torch.ones((1, 1), dtype=torch.int64),
attention_mask=torch.ones((1, 29), dtype=torch.int64),
position_ids=torch.ones((1, 1), dtype=torch.int64),
past_key_values=torch.rand((1, 1, 28, 32)),
token_type_ids=torch.ones((1, 1), dtype=torch.int64),
cache_position=torch.ones((1,), dtype=torch.int64),
),
dict(
input_ids=torch.ones((1, 1), dtype=torch.int64),
attention_mask=torch.ones((1, 30), dtype=torch.int64),
position_ids=torch.ones((1, 1), dtype=torch.int64),
past_key_values=torch.rand((1, 1, 29, 32)),
token_type_ids=torch.ones((1, 1), dtype=torch.int64),
cache_position=torch.ones((1,), dtype=torch.int64),
),
]

model = Model()
observer = InputObserver(
value_if_missing=dict(
pixel_values=(
torch.empty((0, 3, 112, 112), dtype=torch.int64),
torch.empty((0, 3, 112, 112), dtype=torch.int64),
)
)
)
with observer(model):
for kwargs in inputs:
model(**kwargs)

shapes = observer.infer_dynamic_shapes(set_batch_dimension_for=True)
cst = torch.export.Dim.DYNAMIC
expected = {
"input_ids": {0: cst, 1: cst},
"pixel_values": ({0: cst}, {0: cst}),
"attention_mask": {0: cst, 1: cst},
"position_ids": {0: cst, 1: cst},
"past_key_values": {0: cst, 2: cst},
"token_type_ids": {0: cst, 1: cst},
"cache_position": {0: cst},
}
self.assertEqual(expected, shapes)
kwargs = observer.infer_arguments()
self.assertEqual(list(expected), list(kwargs))
self.assertIsInstance(kwargs["pixel_values"], tuple)
self.assertEqual(2, len(kwargs["pixel_values"]))
self.assertEqual((0, 3, 112, 112), kwargs["pixel_values"][0].shape)
self.assertEqual((0, 3, 112, 112), kwargs["pixel_values"][1].shape)

def test_io_captured_kwargs_kwargs(self):
class Model(torch.nn.Module):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,11 @@ def forward(
]

model = Model()
observer = InputObserver(missing=dict(pixel_values=torch.empty((0, 3, 896, 896))))
observer = InputObserver(
value_if_missing=dict(
pixel_values=torch.empty((0, 3, 896, 896), dtype=torch.int64)
)
)
with (
register_additional_serialization_functions(patch_transformers=True),
observer(model),
Expand Down
2 changes: 1 addition & 1 deletion _unittests/ut_tasks/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def test_falcon_mamba_dev(self):
model(**inputs)
model(**data["inputs2"])
self.assertIn((data["size"], data["n_weights"]), [(274958336, 68739584)])
if not has_transformers("5.2.99"):
if not has_transformers("5.3.99"):
raise unittest.SkipTest("The model has control flow.")
with torch_export_patches(patch_transformers=True, verbose=10, stop_if_static=1):
torch.export.export(
Expand Down
6 changes: 3 additions & 3 deletions _unittests/ut_tasks/test_tasks_image_text_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

class TestTasksImageTextToText(ExtTestCase):
@hide_stdout()
@requires_transformers("5.2.99")
@requires_transformers("5.3.99")
@requires_torch("2.7.99")
def test_image_text_to_text_idefics(self):
mid = "HuggingFaceM4/tiny-random-idefics"
Expand All @@ -32,7 +32,7 @@ def test_image_text_to_text_idefics(self):
self.assertEqualAny(expected, ep.module()(**inputs), atol=1)

@hide_stdout()
@requires_transformers("5.2.99")
@requires_transformers("5.3.99")
@requires_torch("2.7.99")
def test_image_text_to_text_tiny_gemma3(self):
"""
Expand Down Expand Up @@ -88,7 +88,7 @@ def test_image_text_to_text_gemma3_4b_it(self):
self.assertEqualAny(expected, ep.module()(**inputs))

@hide_stdout()
@requires_transformers("5.2.99")
@requires_transformers("5.3.99")
@requires_torch("2.7.99")
def test_image_text_to_text_zai_glm(self):
"""
Expand Down
Loading
Loading