From b2229116f1139343a3d0d2dedea93e748933af11 Mon Sep 17 00:00:00 2001
From: Github Executorch
Date: Fri, 30 Jan 2026 13:51:48 -0800
Subject: [PATCH] Summary: MobileNetV2 Fully Lowered to CMSIS-NN
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Cortex-M: Enable full MobileNetV2 lowering to CMSIS-NN backend

This PR enables end-to-end export of MobileNetV2 to the CMSIS-NN backend for
Cortex-M targets. All quantized operations (conv2d, depthwise conv2d,
linear/addmm, activations) are now properly lowered to cortex_m::quantized_*
operators, enabling efficient inference on resource-constrained
microcontrollers.
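
For reviewers, the flow this change wires up (see to_edge_cortex_m in
examples/arm/aot_arm_compiler.py below) is roughly the following sketch.
`model` is the exported GraphModule produced earlier in the script and
`example_inputs` its sample inputs; the prepare_pt2e/convert_pt2e import path
may differ between ExecuTorch versions, so treat this as illustrative rather
than the exact example code:

    import inspect

    import torch
    from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager
    from executorch.backends.cortex_m.quantizer.quantizer import CortexMQuantizer
    from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    # 1) PT2E INT8 quantization with the Cortex-M quantizer, then calibration.
    quantizer = CortexMQuantizer()
    prepared = prepare_pt2e(model, quantizer)
    prepared(*example_inputs)  # calibrate
    quantized = convert_pt2e(prepared)

    # 2) Re-export and lower to edge IR (no partitioner/delegate involved).
    exported = torch.export.export(quantized, example_inputs)
    edge = to_edge_transform_and_lower(
        exported, compile_config=EdgeCompileConfig(_check_ir_validity=False)
    )

    # 3) Run the Cortex-M pass pipeline so quantized aten ops become cortex_m::quantized_*.
    passes = []
    for pass_cls in CortexMPassManager.pass_list:
        needs_ep = "exported_program" in inspect.signature(pass_cls.__init__).parameters
        passes.append(pass_cls(edge.exported_program()) if needs_ep else pass_cls())
    edge = edge.transform(passes)
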
Test Plan:
python3 -m examples.arm.aot_arm_compiler -m mv2 --target=cortex-m --quantize --intermediates=./mv2_intermediates --output=./mv2_cortex_m.pte
cat ./mv2_intermediates/delegation_info.txt

Delegation info:
Total delegated subgraphs: 0
Number of delegated nodes: 0
Number of non-delegated nodes: 72

Delegation table:
╒════╤═════════════════════════════════════════════╤═══════════════════════════════════╤═══════════════════════════════════════╕
│    │ op_type                                     │   occurrences_in_delegated_graphs │   occurrences_in_non_delegated_graphs │
╞════╪═════════════════════════════════════════════╪═══════════════════════════════════╪═══════════════════════════════════════╡
│  0 │ aten_as_strided_copy_default                │                                 0 │                                     1 │
├────┼─────────────────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  1 │ aten_mean_dim                               │                                 0 │                                     1 │
├────┼─────────────────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  2 │ aten_view_copy_default                      │                                 0 │                                     1 │
├────┼─────────────────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  3 │ cortex_m_dequantize_per_tensor_default      │                                 0 │                                     2 │
├────┼─────────────────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  4 │ cortex_m_quantize_per_tensor_default        │                                 0 │                                     2 │
├────┼─────────────────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  5 │ cortex_m_quantized_add_default              │                                 0 │                                    10 │
├────┼─────────────────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  6 │ cortex_m_quantized_conv2d_default           │                                 0 │                                    35 │
├────┼─────────────────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  7 │ cortex_m_quantized_depthwise_conv2d_default │                                 0 │                                    17 │
├────┼─────────────────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  8 │ cortex_m_quantized_linear_default           │                                 0 │                                     1 │
├────┼─────────────────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  9 │ dim_order_ops__clone_dim_order_default      │                                 0 │                                     1 │
├────┼─────────────────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 10 │ Total                                       │                                 0 │                                    71 │
╘════╧═════════════════════════════════════════════╧═══════════════════════════════════╧═══════════════════════════════════════╛

Note: E2E inference was tested on an Alif E8 board.

Reviewers:

Subscribers:

Tasks:

Tags:
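
The delegation numbers above are written out by dump_delegation_info in the
example script; for reference, they can be reproduced from the resulting
EdgeProgramManager roughly like this (sketch; `edge` is the manager returned
by the lowering step):

    from executorch.devtools.backend_debug import get_delegation_info

    delegation_info = get_delegation_info(edge.exported_program().graph_module)
    print(delegation_info.get_summary())                        # counts shown above
    print(delegation_info.get_operator_delegation_dataframe())  # per-op table
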
---
 .../fold_qdq_with_annotated_qparams_pass.py   |  64 ++++++++
 .../passes/convert_to_cortex_m_pass.py        | 115 ++++++++++++-
 .../cortex_m/passes/cortex_m_pass_manager.py  |  22 ++-
 .../cortex_m/passes/propagate_qparams_pass.py | 152 ++++++++++++++++++
 backends/cortex_m/quantizer/quantizer.py      |  29 ++++
 examples/arm/aot_arm_compiler.py              | 146 ++++++++++++-----
 6 files changed, 486 insertions(+), 42 deletions(-)
 create mode 100644 backends/cortex_m/passes/propagate_qparams_pass.py
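
Reviewer note (illustration only, not part of the change): the cortex_m
quantized ops below fold the effective requantization scale
input_scale * weight_scale / output_scale into a fixed-point multiplier/shift
pair via quantize_multiplier_aot, following the usual CMSIS-NN convention. A
stand-alone sketch of that encoding, using a hypothetical
encode_effective_scale helper, for intuition only:

    import math

    def encode_effective_scale(effective_scale: float) -> tuple[int, int]:
        # Illustrative only; quantize_multiplier_aot in the backend is the source of truth.
        # Represent scale as multiplier * 2**(shift - 31), with multiplier in Q31.
        mantissa, shift = math.frexp(effective_scale)  # mantissa in [0.5, 1), scale = mantissa * 2**shift
        multiplier = round(mantissa * (1 << 31))
        if multiplier == (1 << 31):                    # rounding pushed the mantissa up to 1.0
            multiplier //= 2
            shift += 1
        return multiplier, shift

    # e.g. input_scale=0.02, weight_scale=0.005, output_scale=0.1
    print(encode_effective_scale(0.02 * 0.005 / 0.1))  # -> (Q31 multiplier, shift)
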
diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
index 0ecb7ff2070..d868cb3a910 100644
--- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
+++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
@@ -33,6 +33,14 @@
 from torch.fx import GraphModule, Node
 
 
+# Passthrough ops that preserve quantization parameters from input to output.
+# These ops should be foldable even without explicit annotation metadata.
+PASSTHROUGH_OPS = {
+    exir_ops.edge.aten.hardtanh.default,
+    exir_ops.edge.aten.relu.default,
+    exir_ops.edge.aten.clamp.default,
+}
+
 def _get_special_dtype(qspec: QuantArgs) -> TosaSpecialDtype | None:
     if qspec.dtype == torch.int8:
         if qspec.qmax == 7 and qspec.qmin == -7:
@@ -248,6 +256,26 @@ def _handle_control_flow_node(self, node: Node, graph_module: GraphModule):
                     submodule.graph.erase_node(node_to_remove)
         return
 
+    @staticmethod
+    def _has_dq_input_and_q_output(node: Node) -> bool:
+        """
+        Check if a node has dequantize input(s) and quantize output(s).
+        This indicates the node is part of a quantized computation path.
+        """
+        # Check if any input is from a dequantize op
+        has_dq_input = any(
+            isinstance(arg, Node) and arg.target in DQ_OPS
+            for arg in node.args
+            if isinstance(arg, Node)
+        )
+
+        # Check if any output goes to a quantize op
+        has_q_output = any(
+            user.target in Q_OPS
+            for user in node.users
+        )
+        return has_dq_input and has_q_output
+
     @staticmethod
     def is_foldable(node: Node) -> bool:
         if node.op != "call_function":
@@ -263,6 +291,13 @@ def is_foldable(node: Node) -> bool:
         ):
             return True
 
+        # Passthrough ops (hardtanh, relu, clamp) that have dq inputs and q outputs
+        # should be foldable even without explicit annotation. These ops preserve
+        # quantization parameters and are common in quantized models like MobileNetV2.
+        if node.target in PASSTHROUGH_OPS:
+            if FoldAndAnnotateQParamsPass._has_dq_input_and_q_output(node):
+                return True
+
         # We should not fold q-dq nodes into non-quantized nodes.
         if not (
             ArmAnnotationInfo.CUSTOM_META_KEY in node.meta.get("custom", {})
@@ -335,6 +370,35 @@ def call(self, graph_module: GraphModule) -> PassResult:  # noqa: C901
             ):
                 self._handle_control_flow_node(n, graph_module)
 
+        # Second pass: Propagate qparams through passthrough ops.
+        # For ops like hardtanh that share qparams with their input, we need to:
+        # 1. Copy output_qparams from the passthrough op to its input node
+        # 2. Set input_qparams on the passthrough op
+        for n in graph_module.graph.nodes:
+            n = cast(Node, n)
+            if n.target not in PASSTHROUGH_OPS:
+                continue
+
+            # Check if this passthrough op has output_qparams but missing input_qparams
+            has_output = "output_qparams" in n.meta and len(n.meta.get("output_qparams", {})) > 0
+            has_input = "input_qparams" in n.meta and len(n.meta.get("input_qparams", {})) > 0
+
+            if not has_output or has_input:
+                continue
+
+            # Get the input node
+            if len(n.args) == 0 or not isinstance(n.args[0], Node):
+                continue
+
+            input_node = n.args[0]
+
+            # Propagate: For passthrough ops, output qparams equal input qparams
+            if "output_qparams" not in input_node.meta:
+                input_node.meta["output_qparams"] = n.meta["output_qparams"]
+
+            # Set input_qparams from output_qparams (same for passthrough ops)
+            n.meta["input_qparams"] = {0: n.meta["output_qparams"][0]}
+
         # retrace the graph to update the fake tensor types
         graph_module = super().call(graph_module).graph_module
 
diff --git a/backends/cortex_m/passes/convert_to_cortex_m_pass.py b/backends/cortex_m/passes/convert_to_cortex_m_pass.py
index 8da0e720036..64447c12db7 100644
--- a/backends/cortex_m/passes/convert_to_cortex_m_pass.py
+++ b/backends/cortex_m/passes/convert_to_cortex_m_pass.py
@@ -69,9 +69,117 @@ def _get_batch_size_from_conv(self, conv_node: torch.fx.Node):
                 pass
         return None
 
+    def _get_addmm_replacement(self, node):
+        """
+        Handle aten.addmm (decomposed linear):
+        addmm(bias, input, weight.T) = input @ weight.T + bias
+
+        input_qparams indices for addmm:
+        [0] = bias (int32)
+        [1] = input activation (int8)
+        [2] = weight (int8)
+
+        The weight qparams at index [2] are guaranteed to be present because
+        CortexMQuantizer marks weight nodes as annotated, allowing
+        FoldAndAnnotateQParamsPass to properly fold Q/DQ nodes and populate qparams.
+        """
+        # Validate addmm node structure: addmm(bias, input, weight)
+        if len(node.args) < 3:
+            return None
+
+        bias_node = node.args[0]
+        input_node = node.args[1]
+        weights_node = node.args[2]
+
+        # Validate qparams are present with helpful error messages
+        input_qparams = node.meta.get("input_qparams", {})
+        if 1 not in input_qparams:
+            raise RuntimeError(
+                f"Missing input activation qparams at index 1 for addmm node '{node.name}'. "
+                f"Available input_qparams keys: {list(input_qparams.keys())}. "
+                "Ensure the model is properly quantized and FoldAndAnnotateQParamsPass ran."
+            )
+        if 2 not in input_qparams:
+            raise RuntimeError(
+                f"Missing weight qparams at index 2 for addmm node '{node.name}'. "
+                f"Available input_qparams keys: {list(input_qparams.keys())}. "
+                "Ensure CortexMQuantizer marked weight nodes and PropagateQParamsPass "
+                "propagated qparams through any transpose/permute ops."
+            )
+
+        # Get input activation qparams (index 1, not 0 which is bias!)
+        input_scale = input_qparams[1].scale
+        input_zp = input_qparams[1].zp
+
+        # Get weight qparams (index 2)
+        weight_scale = input_qparams[2].scale
+        weight_zp = input_qparams[2].zp
+
+        # Get output qparams
+        output_scale = node.meta["output_qparams"][0].scale
+        output_zp = node.meta["output_qparams"][0].zp
+        output_min = node.meta["output_qparams"][0].qmin
+        output_max = node.meta["output_qparams"][0].qmax
+
+        # Calculate quantization multiplier and shift
+        quantized_multiplier, quantized_shift = quantize_multiplier_aot(
+            (input_scale * weight_scale) / output_scale
+        )
+
+        # Get the original weight tensor
+        # Trace back through transpose/permute to find the placeholder
+        if weights_node.op == "call_function" and len(weights_node.args) > 0:
+            original_weight_node = weights_node.args[0]
+        else:
+            original_weight_node = weights_node
+
+        weights_tensor = get_param_tensor(self.exported_program, original_weight_node)
+        final_weights = weights_tensor.contiguous()
+
+        # Compute kernel_sum WITHOUT bias (pass None)
+        # Bias is passed separately to the C++ operator
+        kernel_sum_tensor = self._compute_kernel_sum(
+            final_weights, None, -input_zp, -weight_zp
+        )
+
+        # Create placeholders for weights and kernel_sum
+        with node.graph.inserting_after(original_weight_node):
+            weights_placeholder = create_constant_placeholder(
+                self.exported_program,
+                node.graph,
+                node.name + "_weights",
+                InputKind.PARAMETER,
+                final_weights,
+            )
+
+            kernel_sum = create_constant_placeholder(
+                self.exported_program,
+                node.graph,
+                node.name + "_kernel_sum",
+                InputKind.PARAMETER,
+                kernel_sum_tensor,
+            )
+
+        # Build args for cortex_m.quantized_linear
+        args = (
+            input_node,
+            weights_placeholder,
+            bias_node,
+            kernel_sum,
+            -input_zp,
+            -weight_zp,
+            output_zp,
+            [quantized_multiplier],
+            [quantized_shift],
+            output_max,
+            output_min,
+        )
+
+        return exir_ops.edge.cortex_m.quantized_linear.default, args
+
     def _get_linear_replacement(self, node):
         """
-        Let
+        Let
         - yi be the output activations (y1, ... yn)
         - xj be the input activations (x1, ... xm)
         - wij be the weights (w11, ... wnm)
@@ -386,6 +494,11 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
             match node.target:
                 case exir_ops.edge.aten.linear.default:
                     op, args = self._get_linear_replacement(node)
+                case exir_ops.edge.aten.addmm.default:
+                    result = self._get_addmm_replacement(node)
+                    if result is None:
+                        continue
+                    op, args = result
                 case exir_ops.edge.aten.convolution.default:
                     # Check if it's transposed convolution (arg index 6)
                     transposed = node.args[6] if len(node.args) > 6 else False
+ """ + modified = False + input_node = node.args[0] + + if not isinstance(input_node, Node): + return False + + # Propagate output_qparams from input to this node + if self._has_qparams(input_node, "output_qparams"): + if not self._has_qparams(node, "output_qparams"): + node.meta["output_qparams"] = input_node.meta["output_qparams"] + modified = True + + # Copy input_qparams to output_qparams (they're the same for passthrough) + if self._has_qparams(node, "input_qparams"): + if not self._has_qparams(node, "output_qparams"): + node.meta["output_qparams"] = {0: node.meta["input_qparams"][0]} + modified = True + + # Copy output_qparams to input_qparams if missing + if self._has_qparams(node, "output_qparams"): + if "input_qparams" not in node.meta: + node.meta["input_qparams"] = {} + if 0 not in node.meta["input_qparams"]: + node.meta["input_qparams"][0] = node.meta["output_qparams"][0] + modified = True + + return modified + + def _propagate_addmm_weight_qparams(self, node: Node) -> bool: + """ + Propagate weight qparams to addmm node at index 2. + + addmm(bias, input, weight.T) expects weight qparams at index 2. + Returns True if any modification was made. + """ + if len(node.args) < 3: + return False + + if "input_qparams" not in node.meta: + node.meta["input_qparams"] = {} + + if 2 in node.meta["input_qparams"]: + return False + + weight_node = node.args[2] + if not isinstance(weight_node, Node): + return False + + if not self._has_qparams(weight_node, "output_qparams"): + return False + + if 0 not in weight_node.meta["output_qparams"]: + return False + + node.meta["input_qparams"][2] = weight_node.meta["output_qparams"][0] + return True + + def call(self, graph_module): + modified = False + + # First pass: Propagate qparams through passthrough ops + for node in graph_module.graph.nodes: + node = cast(Node, node) + + if node.op != "call_function": + continue + if node.target not in self.PASSTHROUGH_OPS: + continue + if len(node.args) == 0: + continue + + if self._propagate_passthrough_qparams(node): + modified = True + + # Second pass: Propagate qparams to addmm nodes + for node in graph_module.graph.nodes: + node = cast(Node, node) + + if node.op != "call_function": + continue + if node.target != exir_ops.edge.aten.addmm.default: + continue + + if self._propagate_addmm_weight_qparams(node): + modified = True + + if modified: + graph_module.recompile() + + return PassResult(graph_module, modified) diff --git a/backends/cortex_m/quantizer/quantizer.py b/backends/cortex_m/quantizer/quantizer.py index f71e538b288..0bb4b7828be 100644 --- a/backends/cortex_m/quantizer/quantizer.py +++ b/backends/cortex_m/quantizer/quantizer.py @@ -355,6 +355,22 @@ def is_bias(self, node: Node, params: List[Node], model: GraphModule) -> bool: """Returns True if node is the second parameter of the given parameters""" return len(params) == 2 and node == params[1] + def _mark_param_node_as_annotated(self, node: Node) -> None: + """ + Mark a weight/bias parameter node as annotated. + + This is necessary for FoldAndAnnotateQParamsPass to recognize the node + as part of a quantized computation path. The ARM quantizer does this + via mark_annotated=True in _QuantProperty. 
diff --git a/backends/cortex_m/passes/propagate_qparams_pass.py b/backends/cortex_m/passes/propagate_qparams_pass.py
new file mode 100644
index 00000000000..7f13640a283
--- /dev/null
+++ b/backends/cortex_m/passes/propagate_qparams_pass.py
@@ -0,0 +1,152 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+PropagateQParamsPass - Propagate qparams to consumer nodes.
+
+This pass runs after FoldAndAnnotateQParamsPass to ensure ops like addmm
+can find weight qparams even when the weight goes through a transpose/permute.
+
+The issue: FoldAndAnnotateQParamsPass folds DQ into the passthrough node
+(e.g., permute), storing qparams as input_qparams. But:
+1. output_qparams is empty (no Q node after permute)
+2. addmm's input_qparams[2] expects to find the weight qparams
+
+This pass:
+1. For passthrough ops: copies input_qparams to output_qparams (they're equal)
+2. Propagates output_qparams from passthrough ops to addmm's input_qparams[2]
+"""
+
+from typing import cast
+
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
+from torch.fx import Node
+from torch.fx.passes.infra.pass_manager import PassResult
+
+
+class PropagateQParamsPass(ExportPass):
+    """
+    Propagates qparams from passthrough ops to their consumers.
+
+    Specifically handles the case where weight goes through transpose/permute
+    before reaching addmm, ensuring addmm has weight qparams at index 2.
+    """
+
+    PASSTHROUGH_OPS = {
+        exir_ops.edge.aten.t.default,
+        exir_ops.edge.aten.transpose.int,
+        exir_ops.edge.aten.permute.default,
+        exir_ops.edge.aten.permute_copy.default,
+        exir_ops.edge.aten.view.default,
+        exir_ops.edge.aten.view_copy.default,
+        exir_ops.edge.aten.reshape.default,
+        exir_ops.edge.aten.clone.default,
+        exir_ops.edge.aten.contiguous.default,
+    }
+
+    @staticmethod
+    def _has_qparams(node: Node, key: str) -> bool:
+        """Check if node has non-empty qparams for the given key."""
+        return key in node.meta and len(node.meta.get(key, {})) > 0
+
+    def _propagate_passthrough_qparams(self, node: Node) -> bool:
+        """
+        Propagate qparams through a passthrough op.
+
+        For passthrough ops, input and output qparams are the same.
+        Returns True if any modification was made.
+        """
+        modified = False
+        input_node = node.args[0]
+
+        if not isinstance(input_node, Node):
+            return False
+
+        # Propagate output_qparams from input to this node
+        if self._has_qparams(input_node, "output_qparams"):
+            if not self._has_qparams(node, "output_qparams"):
+                node.meta["output_qparams"] = input_node.meta["output_qparams"]
+                modified = True
+
+        # Copy input_qparams to output_qparams (they're the same for passthrough)
+        if self._has_qparams(node, "input_qparams"):
+            if not self._has_qparams(node, "output_qparams"):
+                node.meta["output_qparams"] = {0: node.meta["input_qparams"][0]}
+                modified = True
+
+        # Copy output_qparams to input_qparams if missing
+        if self._has_qparams(node, "output_qparams"):
+            if "input_qparams" not in node.meta:
+                node.meta["input_qparams"] = {}
+            if 0 not in node.meta["input_qparams"]:
+                node.meta["input_qparams"][0] = node.meta["output_qparams"][0]
+                modified = True
+
+        return modified
+
+    def _propagate_addmm_weight_qparams(self, node: Node) -> bool:
+        """
+        Propagate weight qparams to addmm node at index 2.
+
+        addmm(bias, input, weight.T) expects weight qparams at index 2.
+        Returns True if any modification was made.
+        """
+        if len(node.args) < 3:
+            return False
+
+        if "input_qparams" not in node.meta:
+            node.meta["input_qparams"] = {}
+
+        if 2 in node.meta["input_qparams"]:
+            return False
+
+        weight_node = node.args[2]
+        if not isinstance(weight_node, Node):
+            return False
+
+        if not self._has_qparams(weight_node, "output_qparams"):
+            return False
+
+        if 0 not in weight_node.meta["output_qparams"]:
+            return False
+
+        node.meta["input_qparams"][2] = weight_node.meta["output_qparams"][0]
+        return True
+
+    def call(self, graph_module):
+        modified = False
+
+        # First pass: Propagate qparams through passthrough ops
+        for node in graph_module.graph.nodes:
+            node = cast(Node, node)
+
+            if node.op != "call_function":
+                continue
+            if node.target not in self.PASSTHROUGH_OPS:
+                continue
+            if len(node.args) == 0:
+                continue
+
+            if self._propagate_passthrough_qparams(node):
+                modified = True
+
+        # Second pass: Propagate qparams to addmm nodes
+        for node in graph_module.graph.nodes:
+            node = cast(Node, node)
+
+            if node.op != "call_function":
+                continue
+            if node.target != exir_ops.edge.aten.addmm.default:
+                continue
+
+            if self._propagate_addmm_weight_qparams(node):
+                modified = True
+
+        if modified:
+            graph_module.recompile()
+
+        return PassResult(graph_module, modified)
diff --git a/backends/cortex_m/quantizer/quantizer.py b/backends/cortex_m/quantizer/quantizer.py
index f71e538b288..0bb4b7828be 100644
--- a/backends/cortex_m/quantizer/quantizer.py
+++ b/backends/cortex_m/quantizer/quantizer.py
@@ -355,6 +355,22 @@ def is_bias(self, node: Node, params: List[Node], model: GraphModule) -> bool:
         """Returns True if node is the second parameter of the given parameters"""
         return len(params) == 2 and node == params[1]
 
+    def _mark_param_node_as_annotated(self, node: Node) -> None:
+        """
+        Mark a weight/bias parameter node as annotated.
+
+        This is necessary for FoldAndAnnotateQParamsPass to recognize the node
+        as part of a quantized computation path. The ARM quantizer does this
+        via mark_annotated=True in _QuantProperty.
+        """
+        if Q_ANNOTATION_KEY not in node.meta:
+            node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation()
+        node.meta[Q_ANNOTATION_KEY]._annotated = True
+        annotation_info = ArmAnnotationInfo(quantized=True)
+        meta_custom = node.meta.get("custom", {})
+        meta_custom[ArmAnnotationInfo.CUSTOM_META_KEY] = dict(annotation_info)
+        node.meta["custom"] = meta_custom
+
     def annotate_match(
         self, match: List[Node], config: QuantizationConfig, model: GraphModule
     ) -> None:
@@ -372,6 +388,7 @@ def annotate_match(
         for node in match:
             input_qspec_map = {}
             output_qspec = None
+            param_nodes_to_mark = []  # Track weight/bias nodes to mark as annotated
 
             params = [n for n in node.all_input_nodes if self.is_parameter(n, model)]
             # Check that the assumptions on number of parameters hold to avoid silent errors
@@ -385,9 +402,11 @@ def annotate_match(
                     continue
                 if self.is_weight(input_node, params, model):
                     input_qspec_map[input_node] = config.weight if config else None
+                    param_nodes_to_mark.append(input_node)
                 elif self.is_bias(input_node, params, model):
                     # Bias qspec is derived from input + weight qspecs
                     input_qspec_map[input_node] = config.bias(node) if config else None
+                    param_nodes_to_mark.append(input_node)
                 elif input_node not in match:
                     input_qspec_map[input_node] = (
                         config.input_activation if config else None
@@ -398,6 +417,11 @@ def annotate_match(
 
             mark_node_as_annotated(node, input_qspec_map, output_qspec)
 
+            # Mark weight/bias parameter nodes as annotated so FoldAndAnnotateQParamsPass
+            # recognizes them and properly folds Q/DQ nodes around them
+            for param_node in param_nodes_to_mark:
+                self._mark_param_node_as_annotated(param_node)
+
     def annotate(self, model: GraphModule) -> None:
         matches = self.match_patterns(model, self.operator_config.operators)
         for match in matches:
@@ -452,6 +476,11 @@ class SharedQspecQuantizer(Quantizer):
         torch.ops.aten._unsafe_view.default,
         torch.ops.aten.unflatten.int,
        torch.ops.aten.flatten.using_ints,
+        # Additional passthrough ops for MobileNetV2 and similar architectures
+        torch.ops.aten.hardtanh.default,
+        torch.ops.aten.hardtanh_.default,
+        torch.ops.aten.max_pool2d.default,
+        torch.ops.aten.dropout.default,
     ]
 
     def __init__(self, targets: Optional[List[OpOverload]] = None) -> None:
diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py
index 509b0e10e5c..419cdc96831 100644
--- a/examples/arm/aot_arm_compiler.py
+++ b/examples/arm/aot_arm_compiler.py
@@ -9,6 +9,7 @@
 
 import argparse
 import copy
+import inspect
 import logging
 import os
 import sys
@@ -36,19 +37,8 @@
 from executorch.backends.arm.util._factory import create_partitioner, create_quantizer
 from executorch.backends.arm.vgf import VgfCompileSpec
-
-# To use Cortex-M backend
-from executorch.backends.cortex_m.passes.convert_to_cortex_m_pass import (
-    ConvertToCortexMPass,
-)
-
-from executorch.backends.cortex_m.passes.quantized_op_fusion_pass import (
-    QuantizedOpFusionPass,
-)
-
-from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import (
-    ReplaceQuantNodesPass,
-)
+from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager
+from executorch.backends.cortex_m.quantizer.quantizer import CortexMQuantizer
 
 from executorch.devtools import generate_etrecord
 from executorch.devtools.backend_debug import get_delegation_info
 
@@ -399,6 +389,7 @@ def forward(self, x):
     "TOSA-1.0+INT",
     "TOSA-1.0+FP",
     "TOSA-1.0+INT+int16",
+    "cortex-m",
 ]
 
 
@@ -795,6 +786,103 @@ def to_edge_TOSA_delegate(
     return model_quant, edge
 
 
+def to_edge_cortex_m(
+    exported_program: ExportedProgram,
+    args,
+    model: GraphModule,
+    example_inputs: Tuple[torch.Tensor],
+):
+    """
+    Export and lower model for Cortex-M target using CMSIS-NN portable kernels.
+
+    This function:
+    1. Quantizes the model using CortexMQuantizer
+    2. Re-exports the quantized model
+    3. Lowers to edge IR
+    4. Applies CortexMPassManager transforms to convert ops to cortex_m::* ops
+
+    No delegation is used - all ops run as portable kernels on the Cortex-M target.
+    """
+    logging.info("Using Cortex-M/CMSIS-NN compilation path (no delegation)")
+
+    model_quant = None
+
+    if args.quantize:
+        logging.info("Quantizing with CortexMQuantizer")
+
+        # Convert model to channels_last memory format for optimal Cortex-M performance
+        model_channels_last = model.to(memory_format=torch.channels_last)
+        example_inputs_cl = tuple(
+            x.to(memory_format=torch.channels_last) if x.dim() == 4 else x
+            for x in example_inputs
+        )
+
+        # Use CortexMQuantizer for INT8 quantization
+        quantizer = CortexMQuantizer()
+        prepared = prepare_pt2e(model_channels_last, quantizer)
+
+        dataset = get_calibration_data(
+            args.model_name, example_inputs_cl, args.evaluate, args.evaluate_config
+        )
+
+        if isinstance(dataset, DataLoader):
+            for sample, _ in dataset:
+                if isinstance(sample, torch.Tensor) and sample.dim() == 4:
+                    sample = sample.to(memory_format=torch.channels_last)
+                prepared(sample)
+        else:
+            dataset_cl = tuple(
+                (
+                    x.to(memory_format=torch.channels_last)
+                    if isinstance(x, torch.Tensor) and x.dim() == 4
+                    else x
+                )
+                for x in dataset
+            )
+            prepared(*dataset_cl)
+
+        model_quant = convert_pt2e(prepared)
+
+        exported_program = torch.export.export(
+            model_quant, example_inputs_cl, strict=args.strict_export
+        )
+    else:
+        logging.warning(
+            "Quantization is DISABLED. Cortex-M typically requires quantization."
+        )
+
+    edge = to_edge_transform_and_lower(
+        exported_program,
+        compile_config=EdgeCompileConfig(_check_ir_validity=False),
+    )
+
+    # Build pass instances from CortexMPassManager.pass_list
+    pass_instances = []
+    for pass_cls in CortexMPassManager.pass_list:
+        sig = inspect.signature(pass_cls.__init__)
+        if "exported_program" in sig.parameters:
+            pass_instances.append(pass_cls(edge.exported_program()))
+        else:
+            pass_instances.append(pass_cls())
+
+    # Apply transforms
+    edge = edge.transform(pass_instances)
+
+    # Log cortex_m ops summary
+    cortex_m_ops = {}
+    for node in edge.exported_program().graph.nodes:
+        target_str = str(node.target)
+        if "cortex_m" in target_str:
+            op_name = target_str.split(".")[-1] if "." in target_str else target_str
+            cortex_m_ops[op_name] = cortex_m_ops.get(op_name, 0) + 1
+
+    logging.info("Cortex-M ops summary:")
+    for op_name, count in sorted(cortex_m_ops.items()):
+        logging.info(f"  - {op_name}: {count}")
+
+    return model_quant, edge
+
+
 def to_edge_no_delegate(
     exported_program: ExportedProgram,
     args,
@@ -830,26 +918,6 @@ def to_edge_no_delegate(
     return model_quant, edge
 
 
-def transform_for_cortex_m_backend(edge_program_manager, args):
-    # Let's make sure we are using optimized Cortex M backend
-    # NB: If we can't find and replace ops those are expected to be replaced,
-    # bad things will happen at runtime, like "missing operator" errors!
-
-    # Instantiate the mandatory ReplaceQuantNodesPass
-    passes = [ReplaceQuantNodesPass]
-    if args.enable_qdq_fusion_pass:
-        passes += [ConvertToCortexMPass, QuantizedOpFusionPass]
-    current_edge = edge_program_manager
-    for pass_cls in passes:
-        transform_pass = (
-            pass_cls(current_edge.exported_program())
-            if pass_cls.__name__ == "QuantizedLinearFusionPass"
-            else pass_cls()
-        )
-        current_edge = current_edge.transform([transform_pass])
-    return current_edge
-
-
 if __name__ == "__main__":  # noqa: C901
     args = get_args()
 
@@ -881,7 +949,12 @@ def transform_for_cortex_m_backend(edge_program_manager, args):
 
     # Quantize if required
     model_quant = None
-    if args.delegate:
+    if args.target == "cortex-m":
+        # Cortex-M path: CMSIS-NN portable kernels, no delegation
+        model_quant, edge = to_edge_cortex_m(
+            exported_program, args, model, example_inputs
+        )
+    elif args.delegate:
         model_quant, edge = to_edge_TOSA_delegate(
             exported_program, args, model, example_inputs
         )
@@ -890,11 +963,6 @@ def transform_for_cortex_m_backend(edge_program_manager, args):
             exported_program, args, model, example_inputs
         )
 
-    # Cortex-m ops are never included in vgf or direct-drive
-    if args.target != "vgf" and not args.direct_drive:
-        # Transform so we can use ops from the Cortex M backend
-        edge = transform_for_cortex_m_backend(edge, args)
-
     dump_delegation_info(edge, args.intermediates)
 
     edge_program_manager_copy = copy.deepcopy(edge)