11 changes: 5 additions & 6 deletions backends/qualcomm/builders/node_visitor.py
@@ -417,12 +417,11 @@ def get_tensor_name(

```python
        elif is_graph_output(node):
            tensor_name = f"output_{tensor_name}"

        # Only add qcom_tensor_name when tensor dump is enabled.
        # Do this only in qnn_preprocess, since that is where the final naming
        # happens; enable_tensor_dump is set to true only in qnn_preprocess,
        # not during op validation.
        if self.enable_tensor_dump:
            node.meta.setdefault(QCOM_TENSOR_NAME, {})[wrapper_idx] = tensor_name

        return tensor_name

    def define_custom_tensor_wrapper(
```
162 changes: 101 additions & 61 deletions backends/qualcomm/debugger/README.md
Generate optrace and QHAS files using QNN tools under $QNN_SDK_ROOT. After finis
```python
adb = SimpleADB(
    qnn_config=qnn_config,
    pte_path=f"{args.artifact}/{pte_filename}.pte",
    workspace=f"/data/local/tmp/executorch/{pte_filename}",
)
binaries_trace = generate_optrace(
    args, adb, f"{args.artifact}/{pte_filename}.pte", example_input
)
```
```mermaid
flowchart TB;
    debug --> output["Output Results"]
```

## Prerequisites
1. Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch.
2. Follow the [tutorial](https://pytorch.org/executorch/stable/build-run-qualcomm-ai-engine-direct-backend.html) to build Qualcomm AI Engine Direct Backend.

## Instructions

### 1. Initialize debugger and build binary

Create a `QNNIntermediateDebugger` with a sample input and pass it to `build_executorch_binary`. The `--dump_intermediate_outputs` flag tells QNN to dump all intermediate tensors during execution.

#### Example:
```python
from executorch.backends.qualcomm.export_utils import build_executorch_binary
from executorch.backends.qualcomm.debugger.qnn_intermediate_debugger import (
    OutputFormat,
    QNNIntermediateDebugger,
)

qnn_intermediate_debugger = QNNIntermediateDebugger(sample_input=inputs[0])
build_executorch_binary(
    model=MyModel(),
    qnn_config=qnn_config,
    # ... remaining arguments (including the initialized debugger) elided
)
```

After `build_executorch_binary()`, the debugger holds:
- `edge_ep` — edge `ExportedProgram` for CPU golden inference.
- `etrecord_file_path` — path to the generated ETRecord file.

### 2. Execute on device

Ensure `dump_intermediate_outputs` is enabled in your `QnnConfig` (or pass `--dump_intermediate_outputs` via the CLI). Run only **one inference** for debugging, and make sure the CPU and QNN executions use the same input; otherwise, the comparison results may be inaccurate.

#### Example:
```python
from executorch.examples.qualcomm.utils import SimpleADB

adb = SimpleADB(
    qnn_config=qnn_config,
    pte_path=f"{args.artifact}/{pte_filename}.pte",
    workspace=f"/data/local/tmp/executorch/{pte_filename}",
)
adb.push(inputs=inputs)
adb.execute()
```

### 3. Pull results and compare

After execution, pull `etdump.etdp` and `debug_output.bin` from the device. Use `setup_inspector()` to create the `Inspector`, then create comparators and generate results.

Before comparing per-layer outputs, it is highly recommended to verify that the edge program's final output aligns with the original `nn.Module`. The debugger uses the edge program as the CPU golden reference, so if the edge graph itself has diverged (e.g., due to weights quantization or pass transformations), per-layer comparisons against it may be misleading.
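Cosine similarity is the alignment metric used throughout this check. As a point of reference, it can be sketched in pure Python (illustrative only; the real code operates on `torch` tensors, not plain lists):

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two flat vectors; a value near 1.0
    means the two outputs point in (almost) the same direction."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# Identical vectors score (approximately) 1.0; orthogonal vectors score 0.0.
print(cosine_similarity([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))
print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))
```

Note that cosine similarity ignores scale, so a layer whose output is uniformly scaled still scores near 1.0; pair it with an error metric such as MSE when absolute magnitudes matter.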

```python
from executorch.backends.qualcomm.debugger.qcom_numerical_comparator_sample import (
    QcomCosineSimilarityComparator,
    QcomMSEComparator,
)

def validate_intermediate_tensor():
    qnn_intermediate_debugger.setup_inspector(
        etdump_path=f"{args.artifact}/etdump.etdp",
        debug_buffer_path=f"{args.artifact}/debug_output.bin",
    )
    qnn_intermediate_debugger.intermediate_output_module(*(inputs[0]))

    # Verify edge program output aligns with the original nn.Module.
    # This ensures the edge graph is a reliable golden reference.
    edge_result = qnn_intermediate_debugger.edge_ep.module()(*(inputs[0]))
    with torch.no_grad():
        source_result = source_model(*(inputs[0]))
    score = torch.nn.functional.cosine_similarity(
        edge_result.flatten(), source_result.flatten(), dim=0
    ).item()
    print("Cosine similarity between nn.Module and edge CPU:", score)

    cos_comparator = qnn_intermediate_debugger.create_comparator(
        QcomCosineSimilarityComparator, threshold=0.9
    )
    qnn_intermediate_debugger.generate_results(
        title="debug_cos_similarity",
        path=args.artifact,
        output_format=OutputFormat.SVG_GRAPH,
        comparator=cos_comparator,
    )

adb.pull_debug_output(
    args.artifact, args.artifact, callback=validate_intermediate_tensor
)
```

## Comparators

Create comparators via the `create_comparator()` factory, which automatically injects the `edge_ep`. Two sample comparators are provided under [qcom_numerical_comparator_sample.py](./qcom_numerical_comparator_sample.py):

```python
cos = qnn_intermediate_debugger.create_comparator(
    QcomCosineSimilarityComparator, threshold=0.9
)
mse = qnn_intermediate_debugger.create_comparator(QcomMSEComparator, threshold=0.1)
```
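For reference, the metric underlying `QcomMSEComparator` is mean squared error. A minimal pure-Python sketch of the computation (illustrative only, not the comparator's actual implementation):

```python
def mse(a, b):
    """Mean squared error between two flat vectors; 0.0 means identical."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

# One element off by 1 over three elements yields roughly 0.333.
print(mse([1.0, 2.0, 3.0], [1.0, 2.0, 4.0]))
```

Note the threshold semantics differ between the two samples: cosine similarity is a higher-is-better score, while MSE is lower-is-better, so each comparator's validity check presumably compares against its threshold in the appropriate direction.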

### Custom comparators

You can also define your own comparator by subclassing [QcomNumericalComparatorBase](./qcom_numerical_comparator_base.py). A derived class must implement `metric_name()`, `is_valid_score()`, and `element_compare()`. The base class handles QNN-specific preprocessing (dequantization, layout conversion) internally, so `preprocessing` cannot be overridden.
```python
from executorch.backends.qualcomm.debugger.qcom_numerical_comparator_base import (
    QcomNumericalComparatorBase,
)

class MyComparator(QcomNumericalComparatorBase):
    def __init__(self, edge_ep, threshold=0.5):
        super().__init__(edge_ep)
        self.threshold = threshold

    def metric_name(self) -> str:
        return "my_metric"

    def is_valid_score(self, score: float) -> bool:
        return score >= self.threshold

    def element_compare(self, a, b) -> float:
        # your comparison logic here
        ...
```
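As a concrete example of a metric an `element_compare` could compute, here is root mean squared error in pure Python (a sketch with plain lists; the real method receives whatever tensor representation the base class passes it, and a valid score would be one *below* the threshold):

```python
import math

def rmse(a, b):
    """Root mean squared error between two flat vectors; 0.0 means identical."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)) / len(a))

# sqrt((9 + 16) / 2) = sqrt(12.5), roughly 3.536
print(rmse([0.0, 0.0], [3.0, 4.0]))
```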

## Output formats

| Format | Enum | Output |
|--------|------|--------|
| SVG graph | `OutputFormat.SVG_GRAPH` | Color-coded computation graph (green=pass, red=fail) |
| CSV file | `OutputFormat.CSV_FILE` | Per-node tabular results |

## Example Script

An Inception_V3 demo script is provided at [qnn_intermediate_debugger_demo.py](../../../examples/qualcomm/util_scripts/qnn_intermediate_debugger_demo.py).

Before running, ensure the dataset is downloaded. An example dataset can be retrieved [here](https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000).

To execute the model:
```bash
python -m examples.qualcomm.util_scripts.qnn_intermediate_debugger_demo -b build-android -s $DEVICE_SERIAL -m $SOC_MODEL -d path/to/imagenet/val --dump_intermediate_outputs
```

## Limitations
1. Only one execution per debug session — multiple executions may cause unknown behavior.
2. If you write your own runner instead of using `qnn_executor_runner`, follow the [tutorial](https://pytorch.org/executorch/stable/etdump.html) to integrate ETDump into it.
3. Does not support graphs with partitions (partial delegation).
4. Does not support LLM models.
5. Does not support graphs with multiple methods.