Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,7 +1193,10 @@ def _gen_edge_manager_for_partitioners(

# Decompose by default if there are no partitioners for the method
if not partitioners_for_program:
program = program.run_decompositions(_default_decomposition_table())
table = _default_decomposition_table()
for op in config.preserve_ops:
table.pop(op, None)
program = program.run_decompositions(table)

# Process each partitioner individually using their specific requirements
for curr_partitioner in partitioners_for_program:
Expand Down
39 changes: 39 additions & 0 deletions exir/program/test/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,45 @@ def test_to_edge_with_preserved_ops_not_in_model(self):
program, ops_not_to_decompose, expected_non_decomposed_edge_ops
)

def test_to_edge_transform_and_lower_with_preserve_ops_no_partitioner(self):
"""Test that preserve_ops works with to_edge_transform_and_lower when
no partitioner is provided. This ensures ops like aten.linear.default
are not decomposed to addmm even without a partitioner."""
model = TestLinear()
program = torch.export.export(model, model._get_random_inputs(), strict=True)

preserved_ops = [torch.ops.aten.linear.default]
expected_edge_ops = [exir_ops.edge.aten.linear.default]

def count_nodes(graph_module, targets):
count = 0
for node in graph_module.graph.nodes:
if node.op == "call_function" and node.target in targets:
count += 1
return count

# Verify ops exist in the original program before the API call
ops_before = count_nodes(program.graph_module, preserved_ops)
self.assertEqual(ops_before, 1)

edge = to_edge_transform_and_lower(
program,
compile_config=EdgeCompileConfig(
preserve_ops=preserved_ops,
),
)

# Verify preserved ops survive after the API call (as edge dialect ops)
ops_after = count_nodes(edge.exported_program().graph_module, expected_edge_ops)
self.assertEqual(ops_before, ops_after)

# Verify linear was NOT decomposed to addmm
addmm_count = count_nodes(
edge.exported_program().graph_module,
[exir_ops.edge.aten.addmm.default],
)
self.assertEqual(addmm_count, 0)

def test_save_fails(self):
model = TestLinear()
program = torch.export.export(model, model._get_random_inputs(), strict=True)
Expand Down
Loading