Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 44 additions & 9 deletions backends/cadence/aot/memory_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,45 @@ def is_cat_along_outermost_dim(
return False
return True

def _has_duplicate_resolved_sources(
self, cat_tensors: Sequence[torch.fx.Node]
) -> bool:
"""Return True if two cat inputs resolve to the same underlying tensor."""
if len(cat_tensors) != len(set(cat_tensors)):
return True
resolved_sources = set()
for arg in cat_tensors:
resolved = arg
while (
info := self.constraint.get_relative_placement_source(resolved)
) is not None:
if self.constraint.is_alias_of(info.source, resolved):
resolved = info.source
else:
break
if id(resolved) in resolved_sources:
return True
resolved_sources.add(id(resolved))
return False

def _has_unaligned_cat_tensors(
    self,
    graph: torch.fx.Graph,
    node: torch.fx.Node,
    cat_tensors: Sequence[torch.fx.Node],
) -> bool:
    """Return True if any non-placeholder cat tensor has misaligned offset.

    Many HiFi ops require 8-byte-aligned inputs, so a non-placeholder cat
    input whose relative offset is not a multiple of 8 blocks the
    cat-removal optimization. Graph outputs are exempt from this check.
    """
    # The alignment requirement does not apply when the cat feeds the
    # graph's flattened output.
    if is_node_in_flattened_output(graph, node):
        return False
    # 8 is a power of two, so (offset & 7) != 0 detects misalignment.
    alignment_mask = 8 - 1
    offsets = get_relative_offsets_of_cat_tensors(cat_tensors)
    return any(
        tensor.op != "placeholder" and (offset & alignment_mask) != 0
        for tensor, offset in zip(cat_tensors, offsets)
    )

# If A = cat(B, C), and the concatenation is along the outermost dimension, then
# we can optimize away this cat operation if (1) B and C are placed contiguously,
# and (2) the absolute memory location of tensor A is the same as B. This function
Expand Down Expand Up @@ -486,21 +525,17 @@ def is_removable_cat_op(
return False
# If the same tensor appears multiple times in the cat inputs,
# we cannot place it at multiple different offsets relative to the output.
if len(cat_tensors) != len(set(cat_tensors)):
# Also check resolved sources: two different alias nodes may resolve to
# the same underlying tensor, which can't be at two offsets.
if self._has_duplicate_resolved_sources(cat_tensors):
return False

# Many ops in HiFi require the input to be aligned to 8-byte boundary.
# If the cat is not the graph's output, then ensure that the relative
# offset of any concatenated non-placeholder tensor is a multiple of
# 8 bytes.
if not is_node_in_flattened_output(graph_module.graph, node):
expected_alignment = 8
relative_offsets = get_relative_offsets_of_cat_tensors(cat_tensors)
for idx, arg in enumerate(cat_tensors):
if not (arg.op == "placeholder") and (
relative_offsets[idx] & (expected_alignment - 1) != 0
):
return False
if self._has_unaligned_cat_tensors(graph_module.graph, node, cat_tensors):
return False

return True

Expand Down
Loading