From 786d3679ba680bce43eb95640e7c2e2613015e84 Mon Sep 17 00:00:00 2001 From: Ethan Ng Date: Sun, 12 Apr 2026 11:32:02 -0700 Subject: [PATCH] Fix Multiple constraints for allocation for two cat inputs of same underlying tensor (#18830) Summary: Added a resolved-source duplicate check in is_removable_cat_op. After the existing node-identity duplicate check, it follows each input's alias chain to its ultimate source and checks for duplicates among the resolved sources. If two inputs resolve to the same tensor, the cat is not optimizable. Differential Revision: D100494796 --- backends/cadence/aot/memory_constraints.py | 53 ++++++++++++++++++---- 1 file changed, 44 insertions(+), 9 deletions(-) diff --git a/backends/cadence/aot/memory_constraints.py b/backends/cadence/aot/memory_constraints.py index af85a62fc7e..2b6c9a049da 100644 --- a/backends/cadence/aot/memory_constraints.py +++ b/backends/cadence/aot/memory_constraints.py @@ -452,6 +452,45 @@ def is_cat_along_outermost_dim( return False return True + def _has_duplicate_resolved_sources( + self, cat_tensors: Sequence[torch.fx.Node] + ) -> bool: + """Return True if two cat inputs resolve to the same underlying tensor.""" + if len(cat_tensors) != len(set(cat_tensors)): + return True + resolved_sources = set() + for arg in cat_tensors: + resolved = arg + while ( + info := self.constraint.get_relative_placement_source(resolved) + ) is not None: + if self.constraint.is_alias_of(info.source, resolved): + resolved = info.source + else: + break + if id(resolved) in resolved_sources: + return True + resolved_sources.add(id(resolved)) + return False + + def _has_unaligned_cat_tensors( + self, + graph: torch.fx.Graph, + node: torch.fx.Node, + cat_tensors: Sequence[torch.fx.Node], + ) -> bool: + """Return True if any non-placeholder cat tensor has misaligned offset.""" + if is_node_in_flattened_output(graph, node): + return False + expected_alignment = 8 + relative_offsets = get_relative_offsets_of_cat_tensors(cat_tensors) + for idx, arg in enumerate(cat_tensors): + if not (arg.op == "placeholder") and ( + relative_offsets[idx] & (expected_alignment - 1) != 0 + ): + return True + return False + # If A = cat(B, C), and the concatenation is along the outermost dimension, then # we can optimize away this cat operation if (1) B and C are placed contiguously, # and (2) the absolute memory location of tensor A is the same as B. This function @@ -486,21 +525,17 @@ def is_removable_cat_op( return False # If the same tensor appears multiple times in the cat inputs, # we cannot place it at multiple different offsets relative to the output. - if len(cat_tensors) != len(set(cat_tensors)): + # Also check resolved sources: two different alias nodes may resolve to + # the same underlying tensor, which can't be at two offsets. + if self._has_duplicate_resolved_sources(cat_tensors): return False # Many ops in HiFi require the input to be aligned to 8-byte boundary. # If the cat is not the graph's output, then ensure that the relative # offset of any concatenated non-placeholder tensor is a multiple of # 8 bytes, - if not is_node_in_flattened_output(graph_module.graph, node): - expected_alignment = 8 - relative_offsets = get_relative_offsets_of_cat_tensors(cat_tensors) - for idx, arg in enumerate(cat_tensors): - if not (arg.op == "placeholder") and ( - relative_offsets[idx] & (expected_alignment - 1) != 0 - ): - return False + if self._has_unaligned_cat_tensors(graph_module.graph, node, cat_tensors): + return False return True