diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 32601cdb3e..5aac46a06e 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -978,8 +978,7 @@ def clone_node_and_cache( clone_d[node] = new_node - if new_node.op is not node.op: - clone_d.setdefault(node.op, new_node.op) + clone_d.setdefault(node.op, new_node.op) for old_o, new_o in zip(node.outputs, new_node.outputs, strict=True): clone_d.setdefault(old_o, new_o) diff --git a/pytensor/graph/fg.py b/pytensor/graph/fg.py index 1b77a6403f..93fc6aff0c 100644 --- a/pytensor/graph/fg.py +++ b/pytensor/graph/fg.py @@ -1074,17 +1074,19 @@ def clients(self) -> dict[Variable, list[ClientType]]: # type: ignore[override] self._clients = clients return self._clients - def unfreeze(self) -> "FunctionGraph": - """Return a mutable FunctionGraph with fresh mutable Apply nodes.""" - memo: dict[Variable, Variable] = {inp: inp.type() for inp in self.inputs} + def bind(self, replace: dict[Variable, Variable]) -> list[Variable]: + """Return fresh outputs with root inputs substituted per *replace*. + Constants are reused; any non-Constant input not in *replace* raises KeyError. + """ + memo = replace.copy() for node in self.toposort(): for i in node.inputs: if i not in memo: - if isinstance(i, AtomicVariable): + if isinstance(i, Constant): memo[i] = i else: - memo[i] = i.clone() + raise KeyError(f"Missing replacement for input {i}") new_node = Apply( node.op, @@ -1092,9 +1094,13 @@ def unfreeze(self) -> "FunctionGraph": [o.type() for o in node.outputs], ) memo.update(zip(node.outputs, new_node.outputs)) + return [memo[out] for out in self.outputs] + def unfreeze(self) -> "FunctionGraph": + """Return a mutable FunctionGraph with fresh mutable Apply nodes.""" + fresh_inputs = [inp.type() for inp in self.inputs] return FunctionGraph( - [memo[i] for i in self.inputs], - [memo[o] for o in self.outputs], + fresh_inputs, + self.bind(dict(zip(self.inputs, fresh_inputs))), clone=False, ) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index d3a7f2f650..90fde303e6 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -4239,25 +4239,6 @@ def __init__( for i in inputs: assert i not in outputs # This isn't supported, use identity - # Flatten nested Composites in single-output case - if len(outputs) == 1 and any( - var.owner is not None and isinstance(var.owner.op, Composite) - for var in outputs - ): - inner_op = outputs[0].owner.op - inner_fgraph = inner_op.fgraph.unfreeze() - res = pytensor.compile.rebuild_collect_shared( - inputs=inputs, outputs=outputs[0].owner.inputs, copy_inputs_over=False - ) - res2 = pytensor.compile.rebuild_collect_shared( - inputs=inner_fgraph.inputs, - outputs=inner_fgraph.outputs, - replace=dict(zip(inner_fgraph.inputs, res[1], strict=True)), - ) - assert len(res2[1]) == len(outputs) - assert len(res[0]) == len(inputs) - inputs, outputs = res[0], res2[1] - self.fgraph = FrozenFunctionGraph(inputs, outputs) self._validate_inner_graph(self.fgraph) @@ -4282,8 +4263,7 @@ def __str__(self): return self._name def clone(self): - mutable_fg = self.fgraph.unfreeze() - return self.__class__(mutable_fg.inputs, mutable_fg.outputs) + return self # Op is immutable def output_types(self, input_types): if tuple(input_types) != self.inputs_type: @@ -4297,12 +4277,11 @@ def make_node(self, *inputs): return super().make_node(*inputs) else: # Make a new op with the right input types. - # Unfreeze the frozen inner graph for rebuild_collect_shared. assert len(inputs) == self.nin - mutable_fg = self.fgraph.unfreeze() + fg = self.fgraph res = pytensor.compile.rebuild_collect_shared( - mutable_fg.outputs, - replace=dict(zip(mutable_fg.inputs, inputs, strict=True)), + fg.outputs, + replace=dict(zip(fg.inputs, inputs, strict=True)), rebuild_strict=False, ) # After rebuild_collect_shared, the Variable in inputs diff --git a/pytensor/scalar/loop.py b/pytensor/scalar/loop.py index c80aa7f83c..1a4b30a008 100644 --- a/pytensor/scalar/loop.py +++ b/pytensor/scalar/loop.py @@ -104,23 +104,7 @@ def _validate_updates( ) def clone(self, name=None, **kwargs): - mutable_fg = self.fgraph.unfreeze() - inputs = mutable_fg.inputs - outputs = mutable_fg.outputs - if self.is_while: - *update, until = outputs - else: - update, until = outputs, None - init = inputs[: len(update)] - constant = inputs[len(update) :] - return self.__class__( - init=init, - update=update, - constant=constant, - until=until, - name=self.name if name is None else name, - **kwargs, - ) + return self # Op is immutable @property def fn(self): @@ -140,10 +124,10 @@ def make_node(self, n_steps, *inputs): return super().make_node(n_steps, *inputs) else: # Make a new op with the right input types. - mutable_fg = self.fgraph.unfreeze() + fg = self.fgraph res = rebuild_collect_shared( - mutable_fg.outputs, - replace=dict(zip(mutable_fg.inputs, inputs, strict=True)), + fg.outputs, + replace=dict(zip(fg.inputs, inputs, strict=True)), rebuild_strict=False, ) if self.is_while: diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py index bdd4f77777..771135b1c6 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -752,7 +752,7 @@ def infer_shape(self, fgraph, node, shapes): class SparseFromDense(Op): """Convert a dense matrix to a sparse matrix.""" - __props__ = () + __props__ = ("format",) def __init__(self, format): self.format = format diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 27f7f94034..c04f3a72a8 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -16,7 +16,6 @@ from pytensor.graph.features import ReplaceValidate from pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.op import Op -from pytensor.graph.replace import clone_replace from pytensor.graph.rewriting.basic import ( GraphRewriter, copy_stack_trace, @@ -1069,13 +1068,13 @@ def local_inline_composite_constants(fgraph, node): if not inlineable: return None - mutable_fg = composite_op.fgraph.unfreeze() + composite_fg = composite_op.fgraph inlineable_indices = {i for i, _ in inlineable} new_outer_inputs = [] new_inner_inputs = [] - inner_replacements = {} + inner_replacements = {i: i for i in composite_fg.inputs} for i, (outer_inp, inner_inp) in enumerate( - zip(node.inputs, mutable_fg.inputs, strict=True) + zip(node.inputs, composite_fg.inputs, strict=True) ): if i in inlineable_indices: inner_replacements[inner_inp] = scalar_constant( @@ -1085,7 +1084,7 @@ def local_inline_composite_constants(fgraph, node): new_outer_inputs.append(outer_inp) new_inner_inputs.append(inner_inp) - new_inner_outs = clone_replace(mutable_fg.outputs, replace=inner_replacements) + new_inner_outs = composite_fg.bind(inner_replacements) new_composite_op = Composite(new_inner_inputs, new_inner_outs) new_outputs = Elemwise(new_composite_op).make_node(*new_outer_inputs).outputs diff --git a/pytensor/tensor/rewriting/ofg.py b/pytensor/tensor/rewriting/ofg.py index e21a01c5b8..0e99be81a2 100644 --- a/pytensor/tensor/rewriting/ofg.py +++ b/pytensor/tensor/rewriting/ofg.py @@ -1,8 +1,6 @@ -from typing import cast - from pytensor.compile import optdb from pytensor.compile.builders import OpFromGraph -from pytensor.graph import Apply, Variable, clone_replace, node_rewriter +from pytensor.graph import Apply, Variable, node_rewriter from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter from pytensor.tensor.basic import AllocDiag from pytensor.tensor.rewriting.basic import register_specialize @@ -10,13 +8,11 @@ def inline_ofg_node(node: Apply) -> list[Variable]: - op = node.op - assert isinstance(op, OpFromGraph) - inlined_outs = clone_replace( - op.inner_outputs, dict(zip(op.inner_inputs, node.inputs, strict=True)) - ) - copy_stack_trace(op.inner_outputs, inlined_outs) - return cast(list[Variable], inlined_outs) + frozen_fg = node.op._frozen_fgraph + replacements = dict(zip(frozen_fg.inputs, node.inputs)) + inlined_outs = frozen_fg.bind(replacements) + copy_stack_trace(frozen_fg.outputs, inlined_outs) + return inlined_outs @node_rewriter([OpFromGraph]) diff --git a/tests/scalar/test_basic.py b/tests/scalar/test_basic.py index 94444c327b..73a074e73e 100644 --- a/tests/scalar/test_basic.py +++ b/tests/scalar/test_basic.py @@ -84,18 +84,6 @@ def test_straightforward(self): fn = make_function(DualLinker().accept(g)) assert fn(1.0, 2.0) == 1.5 - def test_flatten(self): - # Test that we flatten multiple Composite. - x, y, z = floats("xyz") - C = Composite([x, y], [x + y]) - CC = Composite([x, y], [C(x * y, y)]) - assert not isinstance(CC.outputs[0].owner.op, Composite) - - # Test with multiple outputs - CC = Composite([x, y, z], [C(x * y, y), C(x * z, y)]) - # We don't flatten that case. - assert isinstance(CC.outputs[0].owner.op, Composite) - def test_shared_identity(self): x, y = floats("xy") c1 = Composite([x, y], [x + y]) diff --git a/tests/tensor/rewriting/test_ofg.py b/tests/tensor/rewriting/test_ofg.py index 6304939562..ac0acfa73d 100644 --- a/tests/tensor/rewriting/test_ofg.py +++ b/tests/tensor/rewriting/test_ofg.py @@ -4,6 +4,9 @@ import pytensor.tensor as pt from pytensor import config from pytensor.compile.builders import OpFromGraph +from pytensor.graph.fg import FunctionGraph +from pytensor.graph.rewriting.basic import dfs_rewriter +from pytensor.tensor.rewriting.ofg import inline_ofg_expansion @pytest.mark.skipif( @@ -20,3 +23,18 @@ def test_alloc_diag_inlined(): nodes = f.maker.fgraph.apply_nodes assert not any(isinstance(node.op, OpFromGraph) for node in nodes) + + +def test_expansion_no_cloning(): + x = pt.scalar("x") + y = pt.exp(x) + + inner_y = y.type() + ofg = OpFromGraph([inner_y], [pt.cos(inner_y)], inline=True) + z = ofg(y) + + fg = FunctionGraph(outputs=[y, z]) + assert len(fg.toposort()) == 2 + + dfs_rewriter(inline_ofg_expansion).rewrite(fg) + assert len(fg.toposort()) == 2, len(fg.toposort())