Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions exir/emit/test/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,15 @@ fbcode_target(_kind = runtime.python_test,
"//executorch/exir:schema",
"//executorch/exir/backend/test/demos/rpc:executor_backend_partitioner",
"//executorch/exir/backend:backend_api",
"//executorch/exir/backend:compile_spec_schema",
"//executorch/exir/backend:partitioner",
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
"//executorch/exir/backend/test:backend_with_compiler_demo",
"//executorch/exir/emit:lib",
"//executorch/exir/passes:const_prop_pass",
"//executorch/exir/passes:constant_prop_pass",
"//executorch/exir/passes:init_mutable_pass",
"//executorch/exir/passes:propagate_device_pass",
"//executorch/exir/tests:lib",
"//executorch/exir/tests:models",
"//executorch/extension/pybindings:portable_lib",
Expand Down
125 changes: 125 additions & 0 deletions exir/emit/test/test_emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2518,3 +2518,128 @@ def forward(self):
for j in range(2):
expected_storage.append(j * 16 + i)
self.assertEqual([int(v) for v in storage_values], expected_storage)

def test_emit_device_info_propagated_to_serialized_tensor(self) -> None:
    """Verify that device info from PropagateDevicePass flows through
    the emitter into ExtraTensorInfo.device_type on serialized tensors.

    End-to-end: partition add() to the demo backend with a
    target_device=cuda:0 CompileSpec, emit the program, and check that the
    serialized tensors touching the delegate carry CUDA device metadata.
    """
    # Local imports keep the backend/partitioner machinery scoped to this test.
    from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
        generate_pattern_op_partitions,
    )
    from executorch.exir.backend.compile_spec_schema import CompileSpec
    from executorch.exir.backend.partitioner import (
        DelegationSpec,
        Partitioner,
        PartitionResult,
    )
    from executorch.exir.backend.test.backend_with_compiler_demo import (
        BackendWithCompilerDemo,
    )
    from executorch.exir.passes.propagate_device_pass import (
        TARGET_DEVICE_COMPILE_SPEC_KEY,
    )
    from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

    class AddSupport(OperatorSupportBase):
        # Operator-support predicate: only aten.add.Tensor nodes are delegated.
        def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
            return node.op == "call_function" and node.target in [
                exir_ops.edge.aten.add.Tensor,
            ]

    class DevicePartitioner(Partitioner):
        # Tags supported nodes for the demo backend and requests cuda:0
        # placement via the target_device CompileSpec that
        # PropagateDevicePass consumes.
        def __init__(self):
            super().__init__()
            self.delegation_spec = DelegationSpec(
                BackendWithCompilerDemo.__name__,
                [
                    CompileSpec("max_value", bytes([4])),
                    # Read by PropagateDevicePass to tag TensorSpecs.
                    CompileSpec(TARGET_DEVICE_COMPILE_SPEC_KEY, b"cuda:0"),
                ],
            )

        def partition(self, exported_program) -> PartitionResult:
            partition_tags = {}
            partition_list = generate_pattern_op_partitions(
                exported_program.graph_module,
                op_support=any_chain(AddSupport()),
            )
            for partition in partition_list:
                for node in partition.nodes:
                    tag = f"tag{partition.id}"
                    node.meta["delegation_tag"] = tag
                    partition_tags[tag] = self.delegation_spec
            return PartitionResult(
                tagged_exported_program=exported_program,
                partition_tags=partition_tags,
            )

    class Model(torch.nn.Module):
        def forward(self, a, b):
            return torch.add(a, b)

    model = Model()
    inputs = (torch.randn(2, 2), torch.randn(2, 2))

    edge = to_edge(
        export(model, inputs),
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
    # Lower the add to the demo backend, then emit the final program.
    lowered = edge.to_backend(DevicePartitioner())
    et_prog = lowered.to_executorch()
    program = et_prog._emitter_output.program

    plan = program.execution_plan[0]
    self.assertGreater(len(plan.delegates), 0)

    # Collect serialized tensors whose ExtraTensorInfo reports CUDA.
    tensor_values = [v.val for v in plan.values if isinstance(v.val, Tensor)]
    cuda_tensors = [
        t
        for t in tensor_values
        if t.extra_tensor_info is not None
        and t.extra_tensor_info.device_type == schema.DeviceType.CUDA
    ]
    # add(a, b) has 2 delegate inputs + 1 delegate output = 3 CUDA tensors
    self.assertEqual(
        len(cuda_tensors),
        3,
        f"Expected exactly 3 CUDA tensors (2 inputs + 1 output for delegated add), got {len(cuda_tensors)}",
    )
    # Verify device_index is also correctly serialized (cuda:0 → index 0)
    for t in cuda_tensors:
        self.assertEqual(
            t.extra_tensor_info.device_index,
            0,
            "CUDA tensor device_index should be 0 for cuda:0",
        )

def test_emit_cpu_tensors_no_extra_device_info(self) -> None:
    """When every tensor lives on the default CPU device, the emitter must
    not allocate ExtraTensorInfo just to record device information — the
    field stays None for activation tensors.
    """

    class Model(torch.nn.Module):
        def forward(self, a, b):
            return torch.add(a, b)

    example_inputs = (torch.randn(2, 2), torch.randn(2, 2))

    edge_program = to_edge(
        export(Model(), example_inputs),
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
    executorch_program = edge_program.to_executorch()
    plan = executorch_program._emitter_output.program.execution_plan[0]

    # Scan every serialized tensor; none may carry device metadata.
    offending = []
    for value in plan.values:
        candidate = value.val
        if not isinstance(candidate, Tensor):
            continue
        info = candidate.extra_tensor_info
        if info is not None and info.device_type is not None:
            offending.append(candidate)

    self.assertEqual(
        len(offending),
        0,
        "No tensor should have extra device info when model runs entirely on CPU",
    )
14 changes: 14 additions & 0 deletions exir/passes/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -451,3 +451,17 @@ fbcode_target(_kind = runtime.python_library,
"//caffe2:torch",
],
)

# Graph pass that propagates partitioner-specified target devices
# (via the "target_device" CompileSpec) onto TensorSpecs after to_backend.
fbcode_target(_kind = runtime.python_library,
    name = "propagate_device_pass",
    srcs = [
        "propagate_device_pass.py",
    ],
    deps = [
        "//caffe2:torch",
        "//executorch/exir:delegate",
        "//executorch/exir:lowered_backend_module",
        "//executorch/exir:schema",
        "//executorch/exir:tensor",
    ],
)
214 changes: 214 additions & 0 deletions exir/passes/propagate_device_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging
from typing import Optional

import executorch.exir.schema as schema

import torch
from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.lowered_backend_module import LoweredBackendModule
from executorch.exir.tensor import TensorSpec
from torch.fx.passes.infra.pass_base import PassBase, PassResult

logger: logging.Logger = logging.getLogger(__name__)

# CompileSpec key convention for specifying the target device.
# Partitioners that target a specific device should include a CompileSpec entry
# with this key and a value encoding the device string (e.g., b"cuda:0").
TARGET_DEVICE_COMPILE_SPEC_KEY = "target_device"


def _parse_device_spec_value(value: bytes) -> tuple[schema.DeviceType, int]:
    """
    Parse a target_device CompileSpec value (e.g., b"cuda:0") into
    (DeviceType, device_index).

    The type portion is matched case-insensitively against schema.DeviceType
    member names (e.g., "cpu", "cuda").

    Args:
        value: UTF-8 encoded device string, either "<type>" or "<type>:<index>".

    Returns:
        A (DeviceType, device_index) tuple; index defaults to 0 when no
        ":<index>" suffix is present.

    Raises:
        ValueError: for an unknown device type, or when the index portion is
            not a non-negative integer.
    """
    device_str = value.decode("utf-8").strip().lower()
    if ":" in device_str:
        type_str, index_str = device_str.split(":", 1)
        # Surface a clear error for malformed indices (e.g. b"cuda:abc")
        # instead of letting a bare int() traceback escape the pass.
        try:
            device_index = int(index_str)
        except ValueError:
            raise ValueError(
                f"Invalid device index '{index_str}' in device spec "
                f"'{device_str}'; expected a non-negative integer "
                "(e.g., 'cuda:0')."
            ) from None
        if device_index < 0:
            raise ValueError(
                f"Device index must be non-negative, got {device_index} "
                f"in device spec '{device_str}'."
            )
    else:
        type_str = device_str
        device_index = 0
    device_type = next(
        (dt for dt in schema.DeviceType if dt.name.lower() == type_str),
        None,
    )
    if device_type is None:
        valid = ", ".join(dt.name for dt in schema.DeviceType)
        raise ValueError(f"Unknown device type '{type_str}'. Valid types: {valid}")
    return device_type, device_index


def _get_lowered_module(
    graph_module: torch.fx.GraphModule,
    delegate_call_node: torch.fx.Node,
) -> Optional[LoweredBackendModule]:
    """
    Resolve the LoweredBackendModule referenced by an executorch_call_delegate
    node. Returns None unless the node's first argument is a get_attr node
    whose target names a LoweredBackendModule attribute on graph_module.
    """
    args = delegate_call_node.args
    if not args:
        return None
    attr_node = args[0]
    if isinstance(attr_node, torch.fx.Node) and attr_node.op == "get_attr":
        candidate = getattr(graph_module, attr_node.target, None)
        if isinstance(candidate, LoweredBackendModule):
            return candidate
    return None


def _get_target_device_from_compile_specs(
    lowered_module: LoweredBackendModule,
) -> Optional[tuple[schema.DeviceType, int]]:
    """
    Return the (DeviceType, device_index) encoded by the first CompileSpec
    whose key is TARGET_DEVICE_COMPILE_SPEC_KEY, or None when no such spec
    exists on the lowered module.
    """
    matching = next(
        (
            spec
            for spec in lowered_module.compile_specs
            if spec.key == TARGET_DEVICE_COMPILE_SPEC_KEY
        ),
        None,
    )
    if matching is None:
        return None
    return _parse_device_spec_value(matching.value)


def _set_device_on_spec(
    spec: TensorSpec,
    device_type: schema.DeviceType,
    device_index: int = 0,
) -> None:
    """Record the target device (type and index) on the given TensorSpec."""
    spec.device, spec.device_index = device_type, device_index


def _tag_specs_with_device(
    specs: object,
    device_type: schema.DeviceType,
    device_index: int = 0,
) -> bool:
    """Apply a device annotation to a TensorSpec or a collection of them.

    Args:
        specs: A TensorSpec, a tuple/list of TensorSpecs, or None.
        device_type: The target device type to set.
        device_index: The device index (e.g., 0 for cuda:0, 1 for cuda:1).

    Returns:
        True if at least one spec was modified, False otherwise.
    """
    if isinstance(specs, TensorSpec):
        _set_device_on_spec(specs, device_type, device_index)
        return True
    if not isinstance(specs, (tuple, list)):
        # Covers None and any other unexpected payload: nothing to tag.
        return False
    tensor_specs = [s for s in specs if isinstance(s, TensorSpec)]
    for tensor_spec in tensor_specs:
        _set_device_on_spec(tensor_spec, device_type, device_index)
    return bool(tensor_specs)


class PropagateDevicePass(PassBase):
    """
    After to_backend, walk the graph and set device metadata on TensorSpecs
    based on partitioner-assigned delegation info.

    Rules:
    1. Delegated nodes: Input and output tensors of a delegate call are marked
       with the target device derived from the delegate's CompileSpec
       (key="target_device").
    2. Non-delegated nodes: Remain on CPU (default).
    3. Getitem nodes that extract from a delegate call inherit the device from
       the delegate call's output spec at the corresponding index.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Annotate delegate input/output TensorSpecs with their target device.

        Args:
            graph_module: The post-to_backend graph, annotated in place.

        Returns:
            PassResult wrapping graph_module; its modified flag is True when
            at least one spec was tagged.

        Raises:
            RuntimeError: if an executorch_call_delegate node does not
                reference a valid LoweredBackendModule.
        """
        changed = False
        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target == executorch_call_delegate:
                lowered_module = _get_lowered_module(graph_module, node)
                if lowered_module is None:
                    raise RuntimeError(
                        f"executorch_call_delegate node '{node.name}' does not reference "
                        "a valid LoweredBackendModule. The first argument must be a "
                        "get_attr node pointing to a LoweredBackendModule attribute."
                    )

                result = _get_target_device_from_compile_specs(lowered_module)
                if result is None:
                    # No target_device CompileSpec: leave specs on the default device.
                    continue

                target_device_type, device_index = result

                # Tag delegate input tensors.
                # args[0] is the get_attr node for the lowered module; skip it.
                for arg in node.args[1:]:
                    if isinstance(arg, torch.fx.Node):
                        changed |= _tag_specs_with_device(
                            arg.meta.get("spec"),
                            target_device_type,
                            device_index,
                        )

                # Tag delegate output tensors.
                changed |= _tag_specs_with_device(
                    node.meta.get("spec"),
                    target_device_type,
                    device_index,
                )

                logger.debug(
                    "PropagateDevicePass: set device=%s on delegate node %s "
                    "(backend=%s)",
                    target_device_type,
                    node.name,
                    lowered_module.backend_id,
                )

        # Second pass: propagate device through getitem nodes that extract
        # individual outputs from a delegate call.
        for node in graph_module.graph.nodes:
            # Use getattr with a default: not every call_function target has
            # __name__ (e.g. functools.partial objects), so a bare attribute
            # access could raise AttributeError mid-pass.
            if (
                node.op == "call_function"
                and getattr(node.target, "__name__", None) == "getitem"
            ):
                source_node = node.args[0]
                if (
                    isinstance(source_node, torch.fx.Node)
                    and source_node.op == "call_function"
                    and source_node.target == executorch_call_delegate
                ):
                    spec = node.meta.get("spec")
                    source_specs = source_node.meta.get("spec")
                    idx = node.args[1]
                    # Lower bound on idx prevents a negative index from
                    # silently aliasing a spec from the end of the list.
                    if (
                        isinstance(spec, TensorSpec)
                        and isinstance(source_specs, (tuple, list))
                        and isinstance(idx, int)
                        and 0 <= idx < len(source_specs)
                    ):
                        source_spec = source_specs[idx]
                        if isinstance(source_spec, TensorSpec):
                            _set_device_on_spec(
                                spec,
                                source_spec.device,
                                source_spec.device_index,
                            )
                            changed = True

        return PassResult(graph_module, changed)
2 changes: 2 additions & 0 deletions exir/passes/replace_view_copy_with_view_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def __init__(self, base: TensorSpec, shape: List[int]) -> None:
"mem_offset",
"dtype", # property
"extra_tensor_info", # property
"device",
"device_index",
]

# Make sure _self_fields and _base_fields are disjoint
Expand Down
1 change: 1 addition & 0 deletions exir/program/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ fbcode_target(_kind = runtime.python_library,
"//executorch/exir/passes:insert_write_back_for_buffers_pass",
"//executorch/exir/passes:lib",
"//executorch/exir/passes:normalize_view_copy_base_pass",
"//executorch/exir/passes:propagate_device_pass",
"//executorch/exir/passes:remove_graph_asserts_pass",
"//executorch/exir/passes:remove_mixed_type_operators",
"//executorch/exir/passes:replace_aten_with_edge_pass",
Expand Down
Loading
Loading