Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pytensor/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,8 +978,7 @@ def clone_node_and_cache(

clone_d[node] = new_node

if new_node.op is not node.op:
clone_d.setdefault(node.op, new_node.op)
clone_d.setdefault(node.op, new_node.op)

for old_o, new_o in zip(node.outputs, new_node.outputs, strict=True):
clone_d.setdefault(old_o, new_o)
Expand Down
20 changes: 13 additions & 7 deletions pytensor/graph/fg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,27 +1074,33 @@ def clients(self) -> dict[Variable, list[ClientType]]: # type: ignore[override]
self._clients = clients
return self._clients

def unfreeze(self) -> "FunctionGraph":
"""Return a mutable FunctionGraph with fresh mutable Apply nodes."""
memo: dict[Variable, Variable] = {inp: inp.type() for inp in self.inputs}
def bind(self, replace: dict[Variable, Variable]) -> list[Variable]:
"""Return fresh outputs with root inputs substituted per *replace*.

Constants are reused; any non-Constant input not in *replace* raises KeyError.
"""
memo = replace.copy()
for node in self.toposort():
for i in node.inputs:
if i not in memo:
if isinstance(i, AtomicVariable):
if isinstance(i, Constant):
memo[i] = i
else:
memo[i] = i.clone()
raise KeyError(f"Missing replacement for input {i}")

new_node = Apply(
node.op,
[memo[i] for i in node.inputs],
[o.type() for o in node.outputs],
)
memo.update(zip(node.outputs, new_node.outputs))
return [memo[out] for out in self.outputs]

def unfreeze(self) -> "FunctionGraph":
"""Return a mutable FunctionGraph with fresh mutable Apply nodes."""
fresh_inputs = [inp.type() for inp in self.inputs]
return FunctionGraph(
[memo[i] for i in self.inputs],
[memo[o] for o in self.outputs],
fresh_inputs,
self.bind(dict(zip(self.inputs, fresh_inputs))),
clone=False,
)
29 changes: 4 additions & 25 deletions pytensor/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4239,25 +4239,6 @@ def __init__(
for i in inputs:
assert i not in outputs # This isn't supported, use identity

# Flatten nested Composites in single-output case
if len(outputs) == 1 and any(
var.owner is not None and isinstance(var.owner.op, Composite)
for var in outputs
):
inner_op = outputs[0].owner.op
inner_fgraph = inner_op.fgraph.unfreeze()
res = pytensor.compile.rebuild_collect_shared(
inputs=inputs, outputs=outputs[0].owner.inputs, copy_inputs_over=False
)
res2 = pytensor.compile.rebuild_collect_shared(
inputs=inner_fgraph.inputs,
outputs=inner_fgraph.outputs,
replace=dict(zip(inner_fgraph.inputs, res[1], strict=True)),
)
assert len(res2[1]) == len(outputs)
assert len(res[0]) == len(inputs)
inputs, outputs = res[0], res2[1]

self.fgraph = FrozenFunctionGraph(inputs, outputs)
self._validate_inner_graph(self.fgraph)

Expand All @@ -4282,8 +4263,7 @@ def __str__(self):
return self._name

def clone(self):
mutable_fg = self.fgraph.unfreeze()
return self.__class__(mutable_fg.inputs, mutable_fg.outputs)
return self # Op is immutable

def output_types(self, input_types):
if tuple(input_types) != self.inputs_type:
Expand All @@ -4297,12 +4277,11 @@ def make_node(self, *inputs):
return super().make_node(*inputs)
else:
# Make a new op with the right input types.
# Unfreeze the frozen inner graph for rebuild_collect_shared.
assert len(inputs) == self.nin
mutable_fg = self.fgraph.unfreeze()
fg = self.fgraph
res = pytensor.compile.rebuild_collect_shared(
mutable_fg.outputs,
replace=dict(zip(mutable_fg.inputs, inputs, strict=True)),
fg.outputs,
replace=dict(zip(fg.inputs, inputs, strict=True)),
rebuild_strict=False,
)
# After rebuild_collect_shared, the Variable in inputs
Expand Down
24 changes: 4 additions & 20 deletions pytensor/scalar/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,23 +104,7 @@ def _validate_updates(
)

def clone(self, name=None, **kwargs):
mutable_fg = self.fgraph.unfreeze()
inputs = mutable_fg.inputs
outputs = mutable_fg.outputs
if self.is_while:
*update, until = outputs
else:
update, until = outputs, None
init = inputs[: len(update)]
constant = inputs[len(update) :]
return self.__class__(
init=init,
update=update,
constant=constant,
until=until,
name=self.name if name is None else name,
**kwargs,
)
return self # Op is immutable

@property
def fn(self):
Expand All @@ -140,10 +124,10 @@ def make_node(self, n_steps, *inputs):
return super().make_node(n_steps, *inputs)
else:
# Make a new op with the right input types.
mutable_fg = self.fgraph.unfreeze()
fg = self.fgraph
res = rebuild_collect_shared(
mutable_fg.outputs,
replace=dict(zip(mutable_fg.inputs, inputs, strict=True)),
fg.outputs,
replace=dict(zip(fg.inputs, inputs, strict=True)),
rebuild_strict=False,
)
if self.is_while:
Expand Down
2 changes: 1 addition & 1 deletion pytensor/sparse/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,7 @@ def infer_shape(self, fgraph, node, shapes):
class SparseFromDense(Op):
"""Convert a dense matrix to a sparse matrix."""

__props__ = ()
__props__ = ("format",)
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This showed up during node.op deduplication in rebuild_collect_shared. We were mixing distinct SparseFromDense


def __init__(self, format):
self.format = format
Expand Down
9 changes: 4 additions & 5 deletions pytensor/tensor/rewriting/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from pytensor.graph.features import ReplaceValidate
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import (
GraphRewriter,
copy_stack_trace,
Expand Down Expand Up @@ -1069,13 +1068,13 @@ def local_inline_composite_constants(fgraph, node):
if not inlineable:
return None

mutable_fg = composite_op.fgraph.unfreeze()
composite_fg = composite_op.fgraph
inlineable_indices = {i for i, _ in inlineable}
new_outer_inputs = []
new_inner_inputs = []
inner_replacements = {}
inner_replacements = {i: i for i in composite_fg.inputs}
for i, (outer_inp, inner_inp) in enumerate(
zip(node.inputs, mutable_fg.inputs, strict=True)
zip(node.inputs, composite_fg.inputs, strict=True)
):
if i in inlineable_indices:
inner_replacements[inner_inp] = scalar_constant(
Expand All @@ -1085,7 +1084,7 @@ def local_inline_composite_constants(fgraph, node):
new_outer_inputs.append(outer_inp)
new_inner_inputs.append(inner_inp)

new_inner_outs = clone_replace(mutable_fg.outputs, replace=inner_replacements)
new_inner_outs = composite_fg.bind(inner_replacements)
new_composite_op = Composite(new_inner_inputs, new_inner_outs)
new_outputs = Elemwise(new_composite_op).make_node(*new_outer_inputs).outputs

Expand Down
16 changes: 6 additions & 10 deletions pytensor/tensor/rewriting/ofg.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
from typing import cast

from pytensor.compile import optdb
from pytensor.compile.builders import OpFromGraph
from pytensor.graph import Apply, Variable, clone_replace, node_rewriter
from pytensor.graph import Apply, Variable, node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter
from pytensor.tensor.basic import AllocDiag
from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.special import XLog1PY, XLogY


def inline_ofg_node(node: Apply) -> list[Variable]:
op = node.op
assert isinstance(op, OpFromGraph)
inlined_outs = clone_replace(
op.inner_outputs, dict(zip(op.inner_inputs, node.inputs, strict=True))
)
copy_stack_trace(op.inner_outputs, inlined_outs)
return cast(list[Variable], inlined_outs)
frozen_fg = node.op._frozen_fgraph
replacements = dict(zip(frozen_fg.inputs, node.inputs))
inlined_outs = frozen_fg.bind(replacements)
copy_stack_trace(frozen_fg.outputs, inlined_outs)
return inlined_outs


@node_rewriter([OpFromGraph])
Expand Down
12 changes: 0 additions & 12 deletions tests/scalar/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,18 +84,6 @@ def test_straightforward(self):
fn = make_function(DualLinker().accept(g))
assert fn(1.0, 2.0) == 1.5

def test_flatten(self):
# Test that we flatten multiple Composite.
x, y, z = floats("xyz")
C = Composite([x, y], [x + y])
CC = Composite([x, y], [C(x * y, y)])
assert not isinstance(CC.outputs[0].owner.op, Composite)

# Test with multiple outputs
CC = Composite([x, y, z], [C(x * y, y), C(x * z, y)])
# We don't flatten that case.
assert isinstance(CC.outputs[0].owner.op, Composite)

def test_shared_identity(self):
x, y = floats("xy")
c1 = Composite([x, y], [x + y])
Expand Down
18 changes: 18 additions & 0 deletions tests/tensor/rewriting/test_ofg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import pytensor.tensor as pt
from pytensor import config
from pytensor.compile.builders import OpFromGraph
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import dfs_rewriter
from pytensor.tensor.rewriting.ofg import inline_ofg_expansion


@pytest.mark.skipif(
Expand All @@ -20,3 +23,18 @@ def test_alloc_diag_inlined():
nodes = f.maker.fgraph.apply_nodes

assert not any(isinstance(node.op, OpFromGraph) for node in nodes)


def test_expansion_no_cloning():
x = pt.scalar("x")
y = pt.exp(x)

inner_y = y.type()
ofg = OpFromGraph([inner_y], [pt.cos(inner_y)], inline=True)
z = ofg(y)

fg = FunctionGraph(outputs=[y, z])
assert len(fg.toposort()) == 2

dfs_rewriter(inline_ofg_expansion).rewrite(fg)
assert len(fg.toposort()) == 2, len(fg.toposort())
Loading