-
Notifications
You must be signed in to change notification settings - Fork 846
Cortex-M: Enable full MobileNetV2 lowering to CMSIS-NN backend via Aot Compiler script #17075
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -69,9 +69,117 @@ def _get_batch_size_from_conv(self, conv_node: torch.fx.Node): | |||||||||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def _get_addmm_replacement(self, node): | ||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||
| Handle aten.addmm (decomposed linear): | ||||||||||||||||||||||||||||||||
| addmm(bias, input, weight.T) = input @ weight.T + bias | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| input_qparams indices for addmm: | ||||||||||||||||||||||||||||||||
| [0] = bias (int32) | ||||||||||||||||||||||||||||||||
| [1] = input activation (int8) | ||||||||||||||||||||||||||||||||
| [2] = weight (int8) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| The weight qparams at index [2] are guaranteed to be present because | ||||||||||||||||||||||||||||||||
| CortexMQuantizer marks weight nodes as annotated, allowing | ||||||||||||||||||||||||||||||||
| FoldAndAnnotateQParamsPass to properly fold Q/DQ nodes and populate qparams. | ||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||
| # Validate addmm node structure: addmm(bias, input, weight) | ||||||||||||||||||||||||||||||||
| if len(node.args) < 3: | ||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| bias_node = node.args[0] | ||||||||||||||||||||||||||||||||
| input_node = node.args[1] | ||||||||||||||||||||||||||||||||
| weights_node = node.args[2] | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # Validate qparams are present with helpful error messages | ||||||||||||||||||||||||||||||||
| input_qparams = node.meta.get("input_qparams", {}) | ||||||||||||||||||||||||||||||||
| if 1 not in input_qparams: | ||||||||||||||||||||||||||||||||
| raise RuntimeError( | ||||||||||||||||||||||||||||||||
| f"Missing input activation qparams at index 1 for addmm node '{node.name}'. " | ||||||||||||||||||||||||||||||||
| f"Available input_qparams keys: {list(input_qparams.keys())}. " | ||||||||||||||||||||||||||||||||
| "Ensure the model is properly quantized and FoldAndAnnotateQParamsPass ran." | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
| if 2 not in input_qparams: | ||||||||||||||||||||||||||||||||
| raise RuntimeError( | ||||||||||||||||||||||||||||||||
| f"Missing weight qparams at index 2 for addmm node '{node.name}'. " | ||||||||||||||||||||||||||||||||
| f"Available input_qparams keys: {list(input_qparams.keys())}. " | ||||||||||||||||||||||||||||||||
| "Ensure CortexMQuantizer marked weight nodes and PropagateQParamsPass " | ||||||||||||||||||||||||||||||||
| "propagated qparams through any transpose/permute ops." | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # Get input activation qparams (index 1, not 0 which is bias!) | ||||||||||||||||||||||||||||||||
| input_scale = input_qparams[1].scale | ||||||||||||||||||||||||||||||||
| input_zp = input_qparams[1].zp | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # Get weight qparams (index 2) | ||||||||||||||||||||||||||||||||
| weight_scale = input_qparams[2].scale | ||||||||||||||||||||||||||||||||
| weight_zp = input_qparams[2].zp | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # Get output qparams | ||||||||||||||||||||||||||||||||
| output_scale = node.meta["output_qparams"][0].scale | ||||||||||||||||||||||||||||||||
| output_zp = node.meta["output_qparams"][0].zp | ||||||||||||||||||||||||||||||||
| output_min = node.meta["output_qparams"][0].qmin | ||||||||||||||||||||||||||||||||
| output_max = node.meta["output_qparams"][0].qmax | ||||||||||||||||||||||||||||||||
|
Comment on lines
+119
to
+122
|
||||||||||||||||||||||||||||||||
| output_scale = node.meta["output_qparams"][0].scale | |
| output_zp = node.meta["output_qparams"][0].zp | |
| output_min = node.meta["output_qparams"][0].qmin | |
| output_max = node.meta["output_qparams"][0].qmax | |
| output_qparams = node.meta.get("output_qparams", {}) | |
| if 0 not in output_qparams: | |
| raise RuntimeError( | |
| f"Missing output activation qparams at index 0 for addmm node '{node.name}'. " | |
| f"Available output_qparams keys: {list(output_qparams.keys()) if hasattr(output_qparams, 'keys') else output_qparams}. " | |
| "Ensure the model is properly quantized and that qparams were propagated to outputs." | |
| ) | |
| output_scale = output_qparams[0].scale | |
| output_zp = output_qparams[0].zp | |
| output_min = output_qparams[0].qmin | |
| output_max = output_qparams[0].qmax |
Uh oh!
There was an error while loading. Please reload this page.