diff --git a/backends/cadence/aot/memory_constraints.py b/backends/cadence/aot/memory_constraints.py index af85a62fc7e..2b6c9a049da 100644 --- a/backends/cadence/aot/memory_constraints.py +++ b/backends/cadence/aot/memory_constraints.py @@ -452,6 +452,45 @@ def is_cat_along_outermost_dim( return False return True + def _has_duplicate_resolved_sources( + self, cat_tensors: Sequence[torch.fx.Node] + ) -> bool: + """Return True if two cat inputs resolve to the same underlying tensor.""" + if len(cat_tensors) != len(set(cat_tensors)): + return True + resolved_sources = set() + for arg in cat_tensors: + resolved = arg + while ( + info := self.constraint.get_relative_placement_source(resolved) + ) is not None: + if self.constraint.is_alias_of(info.source, resolved): + resolved = info.source + else: + break + if id(resolved) in resolved_sources: + return True + resolved_sources.add(id(resolved)) + return False + + def _has_unaligned_cat_tensors( + self, + graph: torch.fx.Graph, + node: torch.fx.Node, + cat_tensors: Sequence[torch.fx.Node], + ) -> bool: + """Return True if any non-placeholder cat tensor has misaligned offset.""" + if is_node_in_flattened_output(graph, node): + return False + expected_alignment = 8 + relative_offsets = get_relative_offsets_of_cat_tensors(cat_tensors) + for idx, arg in enumerate(cat_tensors): + if not (arg.op == "placeholder") and ( + relative_offsets[idx] & (expected_alignment - 1) != 0 + ): + return True + return False + # If A = cat(B, C), and the concatenation is along the outermost dimension, then # we can optimize away this cat operation if (1) B and C are placed contiguously, # and (2) the absolute memory location of tensor A is the same as B. This function @@ -486,21 +525,17 @@ def is_removable_cat_op( return False # If the same tensor appears multiple times in the cat inputs, # we cannot place it at multiple different offsets relative to the output. - if len(cat_tensors) != len(set(cat_tensors)): + # Also check resolved sources: two different alias nodes may resolve to + # the same underlying tensor, which can't be at two offsets. + if self._has_duplicate_resolved_sources(cat_tensors): return False # Many ops in HiFi require the input to be aligned to 8-byte boundary. # If the cat is not the graph's output, then ensure that the relative # offset of any concatenated non-placeholder tensor is a multiple of # 8 bytes, - if not is_node_in_flattened_output(graph_module.graph, node): - expected_alignment = 8 - relative_offsets = get_relative_offsets_of_cat_tensors(cat_tensors) - for idx, arg in enumerate(cat_tensors): - if not (arg.op == "placeholder") and ( - relative_offsets[idx] & (expected_alignment - 1) != 0 - ): - return False + if self._has_unaligned_cat_tensors(graph_module.graph, node, cat_tensors): + return False return True