64 changes: 64 additions & 0 deletions backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
@@ -33,6 +33,14 @@
from torch.fx import GraphModule, Node


# Passthrough ops that preserve quantization parameters from input to output.
# These ops should be foldable even without explicit annotation metadata.
PASSTHROUGH_OPS = {
exir_ops.edge.aten.hardtanh.default,
exir_ops.edge.aten.relu.default,
exir_ops.edge.aten.clamp.default,
}

def _get_special_dtype(qspec: QuantArgs) -> TosaSpecialDtype | None:
if qspec.dtype == torch.int8:
if qspec.qmax == 7 and qspec.qmin == -7:
@@ -248,6 +256,26 @@ def _handle_control_flow_node(self, node: Node, graph_module: GraphModule):
submodule.graph.erase_node(node_to_remove)
return

@staticmethod
def _has_dq_input_and_q_output(node: Node) -> bool:
"""
Check if a node has dequantize input(s) and quantize output(s).
This indicates the node is part of a quantized computation path.
"""
# Check if any input is from a dequantize op
has_dq_input = any(
isinstance(arg, Node) and arg.target in DQ_OPS
for arg in node.args
if isinstance(arg, Node)
)

# Check if any output goes to a quantize op
has_q_output = any(
user.target in Q_OPS
for user in node.users
)
return has_dq_input and has_q_output

@staticmethod
def is_foldable(node: Node) -> bool:
if node.op != "call_function":
@@ -263,6 +291,13 @@ def is_foldable(node: Node) -> bool:
):
return True

# Passthrough ops (hardtanh, relu, clamp) that have dq inputs and q outputs
# should be foldable even without explicit annotation. These ops preserve
# quantization parameters and are common in quantized models like MobileNetV2.
if node.target in PASSTHROUGH_OPS:
if FoldAndAnnotateQParamsPass._has_dq_input_and_q_output(node):
return True

# We should not fold q-dq nodes into non-quantized nodes.
if not (
ArmAnnotationInfo.CUSTOM_META_KEY in node.meta.get("custom", {})
@@ -335,6 +370,35 @@ def call(self, graph_module: GraphModule) -> PassResult: # noqa: C901
):
self._handle_control_flow_node(n, graph_module)

# Second pass: Propagate qparams through passthrough ops.
# For ops like hardtanh that share qparams with their input, we need to:
# 1. Copy output_qparams from the passthrough op to its input node
# 2. Set input_qparams on the passthrough op
for n in graph_module.graph.nodes:
n = cast(Node, n)
if n.target not in PASSTHROUGH_OPS:
continue

# Check if this passthrough op has output_qparams but missing input_qparams
has_output = "output_qparams" in n.meta and len(n.meta.get("output_qparams", {})) > 0
has_input = "input_qparams" in n.meta and len(n.meta.get("input_qparams", {})) > 0

if not has_output or has_input:
continue

# Get the input node
if len(n.args) == 0 or not isinstance(n.args[0], Node):
continue

input_node = n.args[0]

# Propagate: For passthrough ops, output qparams equal input qparams
if "output_qparams" not in input_node.meta:
input_node.meta["output_qparams"] = n.meta["output_qparams"]

# Set input_qparams from output_qparams (same for passthrough ops)
n.meta["input_qparams"] = {0: n.meta["output_qparams"][0]}

# retrace the graph to update the fake tensor types
graph_module = super().call(graph_module).graph_module

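The folding change above hinges on relu/hardtanh/clamp commuting with requantization: when the output reuses the input's scale and zero-point, the op can run directly in the integer domain, so no new qparams are needed. A minimal sketch of that argument for relu (not part of this PR; the scale and zero-point values below are made up):

import torch

scale, zp = 0.05, -3  # assumed example int8 qparams
x_int = torch.randint(-128, 128, (8,), dtype=torch.int8)

# Reference path: dequantize, apply relu in float, requantize with the SAME qparams.
x_fp = (x_int.to(torch.int32) - zp) * scale
ref = torch.clamp(torch.round(torch.relu(x_fp) / scale) + zp, -128, 127).to(torch.int8)

# Folded path: relu becomes a clamp at the zero-point in the integer domain,
# so the op needs no quantization parameters of its own.
folded = torch.clamp(x_int.to(torch.int32), min=zp, max=127).to(torch.int8)

assert torch.equal(ref, folded)
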
115 changes: 114 additions & 1 deletion backends/cortex_m/passes/convert_to_cortex_m_pass.py
@@ -69,9 +69,117 @@ def _get_batch_size_from_conv(self, conv_node: torch.fx.Node):
pass
return None

def _get_addmm_replacement(self, node):
"""
Handle aten.addmm (decomposed linear):
addmm(bias, input, weight.T) = input @ weight.T + bias

input_qparams indices for addmm:
[0] = bias (int32)
[1] = input activation (int8)
[2] = weight (int8)

The weight qparams at index [2] are guaranteed to be present because
CortexMQuantizer marks weight nodes as annotated, allowing
FoldAndAnnotateQParamsPass to properly fold Q/DQ nodes and populate qparams.
"""
# Validate addmm node structure: addmm(bias, input, weight)
if len(node.args) < 3:
return None

bias_node = node.args[0]
input_node = node.args[1]
weights_node = node.args[2]

# Validate qparams are present with helpful error messages
input_qparams = node.meta.get("input_qparams", {})
if 1 not in input_qparams:
raise RuntimeError(
f"Missing input activation qparams at index 1 for addmm node '{node.name}'. "
f"Available input_qparams keys: {list(input_qparams.keys())}. "
"Ensure the model is properly quantized and FoldAndAnnotateQParamsPass ran."
)
if 2 not in input_qparams:
raise RuntimeError(
f"Missing weight qparams at index 2 for addmm node '{node.name}'. "
f"Available input_qparams keys: {list(input_qparams.keys())}. "
"Ensure CortexMQuantizer marked weight nodes and PropagateQParamsPass "
"propagated qparams through any transpose/permute ops."
)

# Get input activation qparams (index 1, not 0 which is bias!)
input_scale = input_qparams[1].scale
input_zp = input_qparams[1].zp

# Get weight qparams (index 2)
weight_scale = input_qparams[2].scale
weight_zp = input_qparams[2].zp

# Get output qparams
output_scale = node.meta["output_qparams"][0].scale
output_zp = node.meta["output_qparams"][0].zp
output_min = node.meta["output_qparams"][0].qmin
output_max = node.meta["output_qparams"][0].qmax
Comment on lines +119 to +122
Copilot AI (Feb 6, 2026):

The code accesses node.meta["output_qparams"][0] without first verifying that "output_qparams" exists or that it contains an entry at index 0. This could raise a KeyError if output_qparams is missing or empty. Consider adding validation similar to the input_qparams validation above (lines 96-108) with a helpful error message to aid debugging.

Suggested change
output_scale = node.meta["output_qparams"][0].scale
output_zp = node.meta["output_qparams"][0].zp
output_min = node.meta["output_qparams"][0].qmin
output_max = node.meta["output_qparams"][0].qmax
output_qparams = node.meta.get("output_qparams", {})
if 0 not in output_qparams:
raise RuntimeError(
f"Missing output activation qparams at index 0 for addmm node '{node.name}'. "
f"Available output_qparams keys: {list(output_qparams.keys()) if hasattr(output_qparams, 'keys') else output_qparams}. "
"Ensure the model is properly quantized and that qparams were propagated to outputs."
)
output_scale = output_qparams[0].scale
output_zp = output_qparams[0].zp
output_min = output_qparams[0].qmin
output_max = output_qparams[0].qmax


# Calculate quantization multiplier and shift
quantized_multiplier, quantized_shift = quantize_multiplier_aot(
(input_scale * weight_scale) / output_scale
)

# Get the original weight tensor
# Trace back through transpose/permute to find the placeholder
if weights_node.op == "call_function" and len(weights_node.args) > 0:
original_weight_node = weights_node.args[0]
else:
original_weight_node = weights_node

weights_tensor = get_param_tensor(self.exported_program, original_weight_node)
final_weights = weights_tensor.contiguous()

# Compute kernel_sum WITHOUT bias (pass None)
# Bias is passed separately to the C++ operator
kernel_sum_tensor = self._compute_kernel_sum(
final_weights, None, -input_zp, -weight_zp
)

# Create placeholders for weights and kernel_sum
with node.graph.inserting_after(original_weight_node):
weights_placeholder = create_constant_placeholder(
self.exported_program,
node.graph,
node.name + "_weights",
InputKind.PARAMETER,
final_weights,
)

kernel_sum = create_constant_placeholder(
self.exported_program,
node.graph,
node.name + "_kernel_sum",
InputKind.PARAMETER,
kernel_sum_tensor,
)

# Build args for cortex_m.quantized_linear
args = (
input_node,
weights_placeholder,
bias_node,
kernel_sum,
-input_zp,
-weight_zp,
output_zp,
[quantized_multiplier],
[quantized_shift],
output_max,
output_min,
)

return exir_ops.edge.cortex_m.quantized_linear.default, args
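As the _get_addmm_replacement docstring notes, aten.addmm is the decomposed form of aten.linear, which is why the args arrive as (bias, input, weight.T) and why the activation and weight qparams sit at indices 1 and 2. A quick sketch of that identity (illustrative only; the shapes are made up, and this snippet is not part of the PR):

import torch
import torch.nn.functional as F

x = torch.randn(4, 16)   # assumed input activations
w = torch.randn(8, 16)   # assumed weight, shape (out_features, in_features)
b = torch.randn(8)       # assumed bias

linear_out = F.linear(x, w, b)        # x @ w.T + b
addmm_out = torch.addmm(b, x, w.t())  # same computation, decomposed

assert torch.allclose(linear_out, addmm_out, atol=1e-5)
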

def _get_linear_replacement(self, node):
"""
Let
- yi be the output activations (y1, ... yn)
- xj be the input activations (x1, ... xm)
- wij be the weights (w11, ... wnm)
@@ -386,6 +494,11 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
match node.target:
case exir_ops.edge.aten.linear.default:
op, args = self._get_linear_replacement(node)
case exir_ops.edge.aten.addmm.default:
result = self._get_addmm_replacement(node)
if result is None:
continue
op, args = result
case exir_ops.edge.aten.convolution.default:
# Check if it's transposed convolution (arg index 6)
transposed = node.args[6] if len(node.args) > 6 else False
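quantize_multiplier_aot above folds the combined rescale factor (input_scale * weight_scale) / output_scale into an integer multiplier and shift so the kernel avoids floating point at runtime. The backend's actual helper lives elsewhere; the sketch below shows the common gemmlowp/CMSIS-NN-style decomposition only as an assumed illustration, and its rounding details may differ from quantize_multiplier_aot:

import math


def quantize_multiplier(real_multiplier: float) -> tuple[int, int]:
    """Return (q, shift) with real_multiplier ~= q * 2**(shift - 31), q a Q31 integer."""
    if real_multiplier == 0.0:
        return 0, 0
    mantissa, shift = math.frexp(real_multiplier)  # mantissa in [0.5, 1)
    q = round(mantissa * (1 << 31))
    if q == (1 << 31):  # mantissa rounded up to 1.0
        q //= 2
        shift += 1
    return q, shift


# Example: requantize an int32 accumulator with assumed scales.
input_scale, weight_scale, output_scale = 0.02, 0.005, 0.1
q, shift = quantize_multiplier(input_scale * weight_scale / output_scale)
acc = 12345
requantized = round(acc * q * 2.0 ** (shift - 31))
assert abs(requantized - acc * input_scale * weight_scale / output_scale) <= 1
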
22 changes: 20 additions & 2 deletions backends/cortex_m/passes/cortex_m_pass_manager.py
@@ -10,6 +10,12 @@
FoldAndAnnotateQParamsPass,
ScalarsToAttributePass,
)
from executorch.backends.arm._passes.decompose_adaptive_avg_pool2d_pass import (
DecomposeAdaptiveAvgPool2dPass,
)
from executorch.backends.cortex_m.passes.propagate_qparams_pass import (
PropagateQParamsPass,
)
from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass
from executorch.backends.transforms.replace_scalar_with_tensor import (
ReplaceScalarWithTensorArgPass,
@@ -34,9 +40,11 @@ class CortexMPassManager(PassManager):
# Run before folding so qparams attach to max_pool2d values, not tuple + getitem.
RemoveGetItemPass,
FoldAndAnnotateQParamsPass,
PropagateQParamsPass,
ReplaceScalarWithTensorArgPass,
ReplaceQuantNodesPass,
ActivationFusionPass,
DecomposeAdaptiveAvgPool2dPass,
DecomposeHardswishPass,
QuantizedOpFusionPass,
ConvertToCortexMPass,
Expand All @@ -49,12 +57,22 @@ class CortexMPassManager(PassManager):
DecomposeMeanPass,
]

def __init__(self, exported_program, passes=None):
def __init__(self, exported_program, passes=None, skip_passes=None):
"""
Initialize CortexMPassManager.

Args:
exported_program: The ExportedProgram to transform.
passes: Optional custom pass list. Uses default pass_list if None.
skip_passes: Optional list of pass classes to skip.
"""
self.exported_program = exported_program
if passes is not None:
self.passes = passes
else:
self.passes = list(self.pass_list)
if skip_passes:
self.passes = [p for p in self.passes if p not in skip_passes]

def transform_for_annotation(self, model):
passes = self.pass_list_transform_for_annotation
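The new skip_passes argument lets callers drop individual passes from the default pipeline without restating the whole list. A hedged usage sketch (the surrounding export flow and the choice of pass to skip are assumptions for illustration, not part of this PR):

from executorch.backends.arm._passes.decompose_adaptive_avg_pool2d_pass import (
    DecomposeAdaptiveAvgPool2dPass,
)
from executorch.backends.cortex_m.passes.cortex_m_pass_manager import (
    CortexMPassManager,
)


def build_pass_manager(exported_program):
    # Use the default pass_list but leave adaptive-avg-pool decomposition to a
    # later stage (purely illustrative choice).
    return CortexMPassManager(
        exported_program,
        skip_passes=[DecomposeAdaptiveAvgPool2dPass],
    )
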