diff --git a/exir/program/_program.py b/exir/program/_program.py index 8e825f6f85b..abf413918e5 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -1193,7 +1193,10 @@ def _gen_edge_manager_for_partitioners( # Decompose by default if there are no partitioners for the method if not partitioners_for_program: - program = program.run_decompositions(_default_decomposition_table()) + table = _default_decomposition_table() + for op in config.preserve_ops: + table.pop(op, None) + program = program.run_decompositions(table) # Process each partitioner individually using their specific requirements for curr_partitioner in partitioners_for_program: diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py index 2e788ef5c74..1e2e9f824c7 100644 --- a/exir/program/test/test_program.py +++ b/exir/program/test/test_program.py @@ -893,6 +893,45 @@ def test_to_edge_with_preserved_ops_not_in_model(self): program, ops_not_to_decompose, expected_non_decomposed_edge_ops ) + def test_to_edge_transform_and_lower_with_preserve_ops_no_partitioner(self): + """Test that preserve_ops works with to_edge_transform_and_lower when + no partitioner is provided. This ensures ops like aten.linear.default + are not decomposed to addmm even without a partitioner.""" + model = TestLinear() + program = torch.export.export(model, model._get_random_inputs(), strict=True) + + preserved_ops = [torch.ops.aten.linear.default] + expected_edge_ops = [exir_ops.edge.aten.linear.default] + + def count_nodes(graph_module, targets): + count = 0 + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target in targets: + count += 1 + return count + + # Verify ops exist in the original program before the API call + ops_before = count_nodes(program.graph_module, preserved_ops) + self.assertEqual(ops_before, 1) + + edge = to_edge_transform_and_lower( + program, + compile_config=EdgeCompileConfig( + preserve_ops=preserved_ops, + ), + ) + + # Verify preserved ops survive after the API call (as edge dialect ops) + ops_after = count_nodes(edge.exported_program().graph_module, expected_edge_ops) + self.assertEqual(ops_before, ops_after) + + # Verify linear was NOT decomposed to addmm + addmm_count = count_nodes( + edge.exported_program().graph_module, + [exir_ops.edge.aten.addmm.default], + ) + self.assertEqual(addmm_count, 0) + def test_save_fails(self): model = TestLinear() program = torch.export.export(model, model._get_random_inputs(), strict=True)