diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index f924288a86..4ec347a830 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -1,24 +1,26 @@ """Define new Ops from existing Ops""" +from __future__ import annotations + import warnings from collections.abc import Callable, Sequence from copy import copy from functools import partial from itertools import chain -from typing import Union, cast +from typing import cast from pytensor.compile.function import function from pytensor.compile.function.pfunc import rebuild_collect_shared from pytensor.compile.sharedvalue import SharedVariable from pytensor.configdefaults import config -from pytensor.gradient import DisconnectedType, Rop, grad +from pytensor.gradient import DisconnectedType, Rop, disconnected_type, grad from pytensor.graph.basic import ( Apply, Constant, NominalVariable, Variable, ) -from pytensor.graph.fg import FunctionGraph +from pytensor.graph.fg import FrozenFunctionGraph, FunctionGraph from pytensor.graph.null_type import NullType from pytensor.graph.op import HasInnerGraph, Op, io_connection_pattern from pytensor.graph.replace import clone_replace @@ -156,41 +158,37 @@ def construct_nominal_fgraph( class OpFromGraph(Op, HasInnerGraph): - r""" - This creates an `Op` from inputs and outputs lists of variables. - The signature is similar to :func:`pytensor.function ` - and the resulting `Op`'s perform will do the same operation as:: + r"""Create an Op from inputs and outputs lists of variables. - orig_function(inputs, outputs, **kwargs) + The signature is similar to :func:`pytensor.function` and the resulting Op's perform will do + the same operation as ``orig_function(inputs, outputs, **kwargs)``. - Currently does not support ``updates`` or ``givens`` argument. + Does not support ``updates`` or ``givens``. - .. TODO: - - Allow / test merging of OpFromGraph nodes + .. TODO:: - Add support for NullType and DisconnectedType when R_op supports them - - Add support to pickle this Op. - Add optimization to removing unused inputs/outputs - Add optimization to work inplace on inputs when not inline Notes ----- - - We support shared variables in the inner graph. This is automatic - and invisible to the user. They can be as input to the node or in - the inner graph. - - We support unused inputs. This is needed for the grad. - - We support nested OpFromGraph. - - ``inline=True`` will cause better runtime optimization at the cost - of compilation time. Currently only works with ``fast_compile`` or - ``fast_run`` mode. - - For overriding, it's recommended to provide pure functions (no side - effects like setting global variable) as callable(s). The callable(s) - supplied for overriding gradient/rop will be called only once at the - first call to L_op/R_op, and will be converted to OpFromGraph instances. + - Shared variables in the inner graph are supported. They are detected automatically and added + as implicit inputs. + - Unused inputs are supported (needed for gradient overrides). + - Nested OpFromGraph is supported. + - ``inline=True`` causes the Op's inner graph to be inlined during compilation, which gives + better runtime optimization at the cost of compilation time. Currently only works with + ``fast_compile`` or ``fast_run`` mode. + - Override callables should be pure functions (no side effects). They are called once at the + first call to L_op/R_op and converted to OpFromGraph instances. They are also called once at + construction time with dummy inputs to build a frozen representation for equality comparison. + - Two OpFromGraph instances with the same inner graph, overrides, shared variables, and settings + are considered equal. This allows the MergeOptimizer to deduplicate identical OpFromGraph + nodes. Examples -------- - - Example 1: + Basic usage: .. code-block:: python @@ -204,7 +202,7 @@ class OpFromGraph(Op, HasInnerGraph): e2 = op(x, y, z) + op(z, y, x) fn = function([x, y, z], [e2]) - Example 2 with shared variable: + With a shared variable: .. code-block:: python @@ -217,11 +215,10 @@ class OpFromGraph(Op, HasInnerGraph): s = pytensor.shared(np.random.random((2, 2)).astype(config.floatX)) e = x + y * z + s op = OpFromGraph([x, y, z], [e]) - # op behaves like a normal pytensor op e2 = op(x, y, z) + op(z, y, x) fn = function([x, y, z], [e2]) - Example 3 override second output of L_op + Per-input L_op override: .. code-block:: python @@ -238,17 +235,12 @@ def rescale_dy(inps, outputs, out_grads): return z * 2 - op = OpFromGraph( - [x, y, z], - [e], - lop_overrides=[None, rescale_dy, None], - ) + op = OpFromGraph([x, y, z], [e], lop_overrides=[None, rescale_dy, None]) e2 = op(x, y, z) dx, dy, dz = grad(e2, [x, y, z]) fn = function([x, y, z], [dx, dy, dz]) # the gradient wrt y is now doubled fn(2.0, 3.0, 4.0) # [1., 8., 3.] - """ def __init__( @@ -257,9 +249,9 @@ def __init__( outputs: list[Variable], *, inline: bool = False, - lop_overrides: Union[Callable, "OpFromGraph", None] = None, - grad_overrides: Union[Callable, "OpFromGraph", None] = None, - rop_overrides: Union[Callable, "OpFromGraph", None] = None, + lop_overrides: Callable | list | OpFromGraph | None = None, + grad_overrides: Callable | list | OpFromGraph | None = None, + rop_overrides: Callable | list | OpFromGraph | None = None, connection_pattern: list[list[bool]] | None = None, strict: bool = False, name: str | None = None, @@ -269,98 +261,54 @@ def __init__( """ Parameters ---------- - inputs + inputs : list of Variable The inputs to the graph. - - outputs + outputs : list of Variable The outputs to the graph. - - inline - Defaults to ``False`` - - ``True`` : Cause the :class:`Op`'s original graph being used during - compilation, the :class:`Op` will not be visible in the compiled - graph but rather its internal graph. - - ``False`` : will use a pre-compiled function inside. - - grad_overrides - Defaults to ``None``. - This argument is mutually exclusive with ``lop_overrides``. - - ``None`` : Do not override, use default grad() result - - `OpFromGraph`: Override with another `OpFromGraph`, should - accept inputs as the same order and types of ``inputs`` and ``output_grads`` - arguments as one would specify in :meth:`Op.grad`() method. - - `callable`: Should take two args: ``inputs`` and ``output_grads``. - Each argument is expected to be a list of :class:`Variable `. - Must return list of :class:`Variable `. - - lop_overrides - Defaults to ``None``. - - This argument is mutually exclusive with ``grad_overrides``. - - These options are similar to the ``grad_overrides`` above, but for - the :meth:`Op.L_op` method. - - ``None``: Do not override, use the default :meth:`Op.L_op` result - - `OpFromGraph`: Override with another `OpFromGraph`, should - accept inputs as the same order and types of ``inputs``, - ``outputs`` and ``output_grads`` arguments as one would specify in - :meth:`Op.grad` method. - - `callable`: Should take three args: ``inputs``, ``outputs`` and ``output_grads``. - Each argument is expected to be a list of :class:`Variable`. - Must return list of :class:`Variable`. - - ``list``: Each `OpFromGraph`/callable must return a single - :class:`Variable`. Each list element corresponds to gradient of - a specific input, length of list must be equal to number of inputs. - - rop_overrides - One of ``{None, OpFromGraph, callable, Variable}``. - - Defaults to ``None``. - - ``None``: Do not override, use the default :meth:`Op.R_op` result - - `OpFromGraph`: Override with another `OpFromGraph`, should - accept inputs as the same order and types of ``inputs`` and ``eval_points`` - arguments as one would specify in :meth:`Op.R_op` method. - - `callable`: Should take two args: ``inputs`` and ``eval_points``. - Each argument is expected to be a list of :class:`Variable`. Must - return list of :class:`Variable`. - - ``list``: - Each :class:`OpFromGraph`/callable must return a single - :class:`Variable `. Each list element - corresponds to a specific output of :meth:`Op.R_op`, length of list - must be equal to number of outputs. connection_pattern If not - ``None``, this will be used as the connection_pattern for this - :class:`Op`. - - .. warning:: - - rop overrides is ignored when `pytensor.gradient.Rop` is called with - `use_op_rop_implementation=False` (default). In this case the Lop - is used twice to obtain a mathematically equivalent Rop. - - strict: bool, default False - If true, it raises when any variables needed to compute the inner graph - are not provided as explici inputs. This can only happen for graphs with - shared variables. - - name + inline : bool, optional + If True, the Op's inner graph is inlined during compilation. If False (default), a + pre-compiled function is used instead. + lop_overrides : callable or OpFromGraph or list or None, optional + Override for the L_op method. Mutually exclusive with ``grad_overrides``. + + - None: use the default L_op result. + - OpFromGraph: should accept ``(inputs, outputs, output_grads)`` with the same types + as the inner graph. + - callable: should take three args ``(inputs, outputs, output_grads)``, each a list of + Variable, and return a list of Variable. + - list: one entry per input. Each entry is None (use default), a DisconnectedType or + NullType Variable, or a callable returning a single Variable. + grad_overrides : callable or OpFromGraph or list or None, optional + Deprecated in favor of ``lop_overrides``. Same as ``lop_overrides`` but the callable + signature is ``(inputs, output_grads)`` (no ``outputs`` argument). Mutually exclusive + with ``lop_overrides``. + rop_overrides : callable or OpFromGraph or list or None, optional + Override for the R_op method. + + - None: use the default R_op result. + - OpFromGraph: should accept ``(inputs, eval_points)`` with the same types as the + inner graph inputs. + - callable: should take two args ``(inputs, eval_points)``, each a list of Variable, + and return a list of Variable. + - list: one entry per output. Each entry is None (use default), a DisconnectedType or + NullType Variable, or a callable returning a single Variable. + + .. warning:: + + R_op overrides are ignored when ``pytensor.gradient.Rop`` is called with + ``use_op_rop_implementation=False`` (the default). In that case the L_op is used + twice to obtain a mathematically equivalent R_op. + connection_pattern : list of list of bool, optional + If provided, used as the connection pattern for this Op. Each inner list has one bool + per output, and the outer list has one entry per input. + strict : bool, optional + If True, raises when any variables needed to compute the inner graph are not provided + as explicit inputs. Only relevant for graphs with shared variables. Default False. + name : str, optional A name for debugging purposes. - - kwargs - Check :func:`pytensor.function` for more arguments, only works when not - inline. + **kwargs + Additional arguments passed to :func:`pytensor.function`. Only used when + ``inline=False``. """ ignore_unused_inputs = kwargs.get("on_unused_input", False) == "ignore" if not ignore_unused_inputs and len(inputs) != len(set(inputs)): @@ -389,6 +337,7 @@ def __init__( self.fgraph, self.shared_inputs, _, _ = construct_nominal_fgraph( inputs, outputs ) + self._frozen_fgraph = self.fgraph.freeze() if strict and self.shared_inputs: raise ValueError( @@ -438,13 +387,117 @@ def __init__( self.name = name self.destroy_map = destroy_map if destroy_map is not None else {} + self._frozen_lop = self._freeze_lop_override() + self._frozen_rop = self._freeze_rop_override() + + @staticmethod + def _freeze_override_to_fgraph( + all_inputs: list[Variable], results: list[Variable] + ) -> tuple[tuple[bool, ...], FrozenFunctionGraph | None]: + """Build a FrozenFunctionGraph from override results, filtering out disconnected/null types. + + A structural connection pattern is also constructed from the results, indicating where filtering occured. + This pattern, together with the FrozenFunctionGraph, participate in the equality comparison of OpFromGraph + instances to distinguish different override implementations. + """ + pattern = tuple( + isinstance(r.type, DisconnectedType | NullType) for r in results + ) + connected = [ + r for r, is_disc in zip(results, pattern, strict=True) if not is_disc + ] + if not connected: + return pattern, None + return pattern, FunctionGraph(all_inputs, connected).freeze() + + def _freeze_lop_override(self): + """Freeze lop/grad override into a FrozenFunctionGraph for equality comparison.""" + lop = self.lop_overrides if self._lop_op_interface else self.grad_overrides + if lop is None: + return None + + if isinstance(lop, OpFromGraph): + return lop._frozen_fgraph + + dummy_inputs = [t() for t in self.input_types] + dummy_outputs = [t() for t in self.output_types] + dummy_output_grads = [t() for t in self.output_types] + + if self._lop_op_interface: + all_inputs = dummy_inputs + dummy_outputs + dummy_output_grads + callable_args = (dummy_inputs, dummy_outputs, dummy_output_grads) + else: + all_inputs = dummy_inputs + dummy_output_grads + callable_args = (dummy_inputs, dummy_output_grads) + + if isinstance(lop, list): + results = [] + for entry in lop: + if entry is None: + results.append(disconnected_type()) + elif isinstance(entry, Variable): + results.append(entry) + elif callable(entry): + results.append(entry(*callable_args)) + return self._freeze_override_to_fgraph(all_inputs, results) + + result = lop(*callable_args) + return self._freeze_override_to_fgraph(all_inputs, result) + + def _freeze_rop_override(self): + """Freeze rop override into a FrozenFunctionGraph for equality comparison.""" + rop = self.rop_overrides + if rop is None: + return None + + if isinstance(rop, OpFromGraph): + return rop._frozen_fgraph + + dummy_inputs = [t() for t in self.input_types] + dummy_eval_points = [t() for t in self.input_types] + all_inputs = dummy_inputs + dummy_eval_points + callable_args = (dummy_inputs, dummy_eval_points) + + if isinstance(rop, list): + results = [] + for entry in rop: + if entry is None: + results.append(disconnected_type()) + elif isinstance(entry, Variable): + results.append(entry) + elif callable(entry): + results.append(entry(*callable_args)) + return self._freeze_override_to_fgraph(all_inputs, results) + + result = rop(*callable_args) + return self._freeze_override_to_fgraph(all_inputs, result) + def __eq__(self, other): - # TODO: recognize a copy - return self is other + if self is other: + return True + if type(self) is not type(other): + return False + if self._frozen_fgraph != other._frozen_fgraph: + return False + if self.is_inline != other.is_inline: + return False + if self.destroy_map != other.destroy_map: + return False + if len(self.shared_inputs) != len(other.shared_inputs): + return False + if any( + a is not b + for a, b in zip(self.shared_inputs, other.shared_inputs, strict=True) + ): + return False + if self._frozen_lop != other._frozen_lop: + return False + if self._frozen_rop != other._frozen_rop: + return False + return True def __hash__(self): - # TODO: use internal variables in hash - return hash(type(self)) + return hash((type(self), self._frozen_fgraph, self.is_inline)) def __str__(self): name = self.__class__.__name__ if self.name is None else self.name diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 1ee46f3449..f4aeed5be5 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -2,6 +2,7 @@ import abc import warnings +import weakref from collections.abc import ( Hashable, Iterable, @@ -824,6 +825,16 @@ def __repr__(self): def clone(self, **kwargs): return self + def equals(self, other): + if not isinstance(other, type(self)): + return False + if self.type != other.type: + return False + try: + return np.array_equal(self.data, other.data, equal_nan=True) + except (TypeError, ValueError): + return self.data == other.data + @property def owner(self) -> None: return None @@ -838,6 +849,119 @@ def value(self): return self.data +class FrozenConstant(Constant): + """A globally-interned Constant for use in frozen graphs. + + Two ``FrozenConstant`` instances with the same type and data are the same object. + """ + + _cache: weakref.WeakValueDictionary = weakref.WeakValueDictionary() + _filtered: Any + + # FrozenConstant doesn't inherit the scalar mixin that provides .dtype, + # but scalar C code generation expects it on all variables. + @property + def dtype(self): + return self.type.dtype + + def __new__(cls, type: _TypeType, data: Any, name: str | None = None): + filtered = type.filter(data) + cache_key = cls._make_key(type, filtered) + if cache_key is not None: + cached = cls._cache.get(cache_key) + if cached is not None: + return cached + instance = object.__new__(cls) + # Store filtered data now so __init__ can skip re-filtering + instance._filtered = filtered + if cache_key is not None: + cls._cache[cache_key] = instance + return instance + + def __init__(self, type: _TypeType, data: Any, name: str | None = None): + if hasattr(self, "data"): + return + # Use pre-filtered data from __new__ to avoid a second type.filter() call + AtomicVariable.__init__(self, type, name=name) + self.data = self._filtered + del self._filtered + add_tag_trace(self) + + @staticmethod + def _make_key(type, filtered): + if isinstance(filtered, np.ndarray): + from pytensor.tensor.utils import hash_from_ndarray + + return type, hash_from_ndarray(filtered) + if isinstance(filtered, np.generic): + from pytensor.tensor.utils import hash_from_ndarray + + return type, hash_from_ndarray(np.asarray(filtered)) + try: + return type, hash(filtered) + except TypeError: + return None + + def __reduce__(self): + return (type(self), (self.type, self.data, self.name)) + + +class FrozenApply(Apply): + """An immutable, globally-interned Apply node for frozen graphs. + + Uses tuples for ``inputs`` and ``outputs`` so mutation raises ``TypeError`` + at the language level. Interned by ``(op, inputs)`` — constructing a + ``FrozenApply`` with an ``op`` and ``inputs`` that match an existing live + instance returns that instance. + """ + + _cache: weakref.WeakValueDictionary = weakref.WeakValueDictionary() + + def __new__( + cls, op: "Op", inputs: tuple[Variable, ...], output_types: tuple["Type", ...] + ): + # Canonicalize inputs through their owner's outputs to ensure cache hits after unpickling. + inputs = tuple( + inp.owner.outputs[inp.index] + if inp.owner is not None and isinstance(inp.owner, FrozenApply) + else inp + for inp in inputs + ) + key = (op, inputs) + cached = cls._cache.get(key) + if cached is not None: + return cached + + instance = object.__new__(cls) + instance.op = op + instance.inputs = inputs # type: ignore[assignment] + instance.outputs = tuple( # type: ignore[assignment] + t.variable_type(type=t, owner=instance, index=i) + for i, t in enumerate(output_types) + ) + instance.tag = Scratchpad() + cls._cache[key] = instance + return instance + + def __init__(self, op, inputs, output_types): + # All initialization is done in __new__ + pass + + def clone(self, clone_inner_graph: bool = False) -> "Apply": + """Clone into a mutable Apply node.""" + from pytensor.graph.op import HasInnerGraph + + new_op = self.op + if isinstance(new_op, HasInnerGraph) and clone_inner_graph: + new_op = new_op.clone() + + return Apply(new_op, list(self.inputs), [o.clone() for o in self.outputs]) + + def __reduce__(self): + output_types = tuple(o.type for o in self.outputs) + return (type(self), (self.op, self.inputs, output_types)) + + def clone( inputs: Sequence[Variable], outputs: Sequence[Variable], @@ -1154,14 +1278,14 @@ def equal_computations( for x, y in zip(xs, ys, strict=True): if not isinstance(x, Variable) and not isinstance(y, Variable): - return np.array_equal(x, y) + return np.array_equal(x, y, equal_nan=True) if not isinstance(x, Variable): if isinstance(y, Constant): - return np.array_equal(y.data, x) + return np.array_equal(y.data, x, equal_nan=True) return False if not isinstance(y, Variable): if isinstance(x, Constant): - return np.array_equal(x.data, y) + return np.array_equal(x.data, y, equal_nan=True) return False x_is_owned, y_is_owned = (x.owner is not None, y.owner is not None) if x_is_owned != y_is_owned: diff --git a/pytensor/graph/fg.py b/pytensor/graph/fg.py index a0a71fab4f..76b2bd4573 100644 --- a/pytensor/graph/fg.py +++ b/pytensor/graph/fg.py @@ -10,6 +10,8 @@ from pytensor.graph.basic import ( Apply, AtomicVariable, + Constant, + NominalVariable, Variable, clone_get_equiv, ) @@ -928,3 +930,154 @@ def dprint(self, **kwargs): from pytensor.printing import debugprint return debugprint(self, **kwargs) + + def freeze(self) -> "FrozenFunctionGraph": + """Return a frozen, hashable version of this FunctionGraph.""" + return FrozenFunctionGraph(self.inputs, self.outputs) + + +class FrozenFunctionGraph: + """Immutable, hashable function graph for inner graphs of Ops. + + All internal nodes are globally interned via ``FrozenApply`` and ``FrozenConstant``. Two ``FrozenFunctionGraph`` + instances built from structurally identical source graphs share the same internal objects, so equality reduces to + an identity check on the output tuples. + + .. code-block:: python + + from pytensor.scalar.basic import float64, add + from pytensor.graph.fg import FunctionGraph + + x, y = float64("x"), float64("y") + fg = FunctionGraph([x, y], [add(x, y)]) + frozen = fg.freeze() + frozen2 = FunctionGraph([x, y], [add(x, y)]).freeze() + + assert frozen == frozen2 + assert {frozen: "value"}[frozen2] == "value" + """ + + def __init__( + self, + inputs: Sequence[Variable], + outputs: Sequence[Variable], + ): + from pytensor.graph.basic import ( + FrozenApply, + FrozenConstant, + ) + + nominal_inputs = tuple( + NominalVariable(i, inp.type, name=inp.name) for i, inp in enumerate(inputs) + ) + + memo: dict[Variable, Variable] = dict(zip(inputs, nominal_inputs, strict=True)) + + var_hash: dict[Variable, int] = {} + for i, nm in enumerate(nominal_inputs): + var_hash[nm] = hash(("input", i, nm.type)) + + for node in toposort(outputs, blockers=inputs): + for inp in node.inputs: + if inp not in memo: + if isinstance(inp, Constant): + fc = FrozenConstant(inp.type, inp.data) + memo[inp] = fc + if fc not in var_hash: + var_hash[fc] = hash(fc) + elif isinstance(inp, AtomicVariable): + # AtomicVariables (e.g. NominalVariables from outer + # scopes) are already interned and hashable. + memo[inp] = inp + if inp not in var_hash: + var_hash[inp] = hash(inp) + else: + raise ValueError( + f"Non-Constant, non-AtomicVariable orphan {inp} found " + "in the graph. All variables must be graph inputs, " + "Constants, AtomicVariables, or produced by Apply " + "nodes reachable from the inputs." + ) + + new_inputs = tuple(memo[i] for i in node.inputs) + output_types = tuple(out.type for out in node.outputs) + new_node = FrozenApply(node.op, new_inputs, output_types) + + input_hashes = tuple(var_hash[i] for i in new_inputs) + node_hash = hash((node.op, input_hashes)) + for old_out, new_out in zip(node.outputs, new_node.outputs, strict=True): + memo[old_out] = new_out + var_hash[new_out] = hash((node_hash, new_out.index)) + + self.inputs: tuple[Variable, ...] = nominal_inputs + + resolved_outputs = [] + for o in outputs: + mapped = memo.get(o) + # After unpickling, o may be a fresh object whose owner is the (correctly interned) FrozenApply. + # We thus resolve it through its owner to get back the original variable. + if mapped is None and o.owner is not None: + mapped = memo.get(o.owner.outputs[o.index]) + if mapped is None: + raise ValueError( + f"Output variable {o} could not be mapped to a frozen graph variable. " + "All outputs must be graph inputs, constants, or produced by Apply nodes " + "reachable from the inputs." + ) + resolved_outputs.append(mapped) + self.outputs: tuple[Variable, ...] = tuple(resolved_outputs) + + self._structural_hash: int = hash(tuple(var_hash[o] for o in self.outputs)) + + def __hash__(self): + return self._structural_hash + + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, FrozenFunctionGraph): + return False + if self._structural_hash != other._structural_hash: + return False + if self.outputs == other.outputs: + return True + # Hash match but output identity mismatch — likely a hash collision + # or interning bug. Fall back to structural comparison. + import warnings + + from pytensor.graph.basic import equal_computations + + if ( + len(self.outputs) == len(other.outputs) + and len(self.inputs) == len(other.inputs) + and equal_computations( + list(self.outputs), + list(other.outputs), + in_xs=list(self.inputs), + in_ys=list(other.inputs), + ) + ): + warnings.warn( + "FrozenFunctionGraph: structurally equal graphs did not share " + "interned objects. This may indicate an interning bug.", + stacklevel=2, + ) + return True + return False + + def __repr__(self): + return f"FrozenFunctionGraph(inputs={list(self.inputs)}, outputs={list(self.outputs)})" + + def __reduce__(self): + return (type(self), (list(self.inputs), list(self.outputs))) + + @property + def apply_nodes(self) -> set[Apply]: + return set(applys_between(self.inputs, self.outputs)) + + def toposort(self) -> list[Apply]: + return list(toposort(self.outputs, blockers=self.inputs)) + + @property + def variables(self) -> set[Variable]: + return set(vars_between(self.inputs, self.outputs)) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index b59cc9992f..d06f5d1f4b 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -23,10 +23,9 @@ from pytensor import printing from pytensor.configdefaults import config from pytensor.gradient import disconnected_type, grad_undefined -from pytensor.graph.basic import Apply, Constant, Variable, clone +from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import HasInnerGraph -from pytensor.graph.rewriting.basic import MergeOptimizer from pytensor.graph.traversal import applys_between from pytensor.graph.type import HasDataType, HasShape from pytensor.graph.utils import MetaObject, MethodNotDefined @@ -984,6 +983,8 @@ def as_scalar(x: Any, name: str | None = None) -> ScalarVariable: if isinstance(x.type, TensorType) and x.type.ndim == 0: return scalar_from_tensor(x) + elif isinstance(x, Constant) and isinstance(x.type, ScalarType): + return ScalarConstant(x.type, x.data, name=x.name) else: raise TypeError(f"Cannot convert {x} to a scalar type") @@ -4004,14 +4005,8 @@ def __init__(self, *args, **kwargs): self.prepare_node_called = set() super().__init__(*args, **kwargs) - def _cleanup_graph(self, inputs, outputs, clone: builtins.bool = True): - # TODO: We could convert to TensorVariable, optimize graph, - # and then convert back to ScalarVariable. - # This would introduce rewrites like `log(1 + x) -> log1p`. - - fgraph = FunctionGraph(inputs, outputs, clone=clone) - - # Validate node types + def _validate_inner_graph(self, fgraph): + """Validate that all ops in the inner graph are ScalarOps.""" for node in fgraph.apply_nodes: if not isinstance(node.op, ScalarOp): raise TypeError( @@ -4019,27 +4014,6 @@ def _cleanup_graph(self, inputs, outputs, clone: builtins.bool = True): "composed of scalar operations." ) - # Run MergeOptimization to avoid duplicated nodes - MergeOptimizer().rewrite(fgraph) - - inputs, outputs = fgraph.inputs, fgraph.outputs - - # Clone identical outputs that may have been merged - # If fgraph.outputs = [out_A, out_B, out_A], then final outputs = [out_A, out_B, clone(out_A)] - if len(set(fgraph.outputs)) != len(outputs): - old_outputs = outputs - outputs = [] - for old_output in old_outputs: - if old_output not in outputs: - outputs.append(old_output) - else: - node = old_output.owner - output_idx = node.outputs.index(old_output) - output = node.clone().outputs[output_idx] - outputs.append(output) - - return inputs, outputs - @property def fn(self): return None @@ -4116,6 +4090,8 @@ def c_support_code(self, **kwargs): return "\n".join(sorted(rval)) def c_support_code_apply(self, node, name): + # Ensure nodenames is populated (side effect of c_code_template) + _ = self.c_code_template rval = [] for subnode, subnodename in zip( self.fgraph.toposort(), self.nodenames, strict=True @@ -4140,38 +4116,17 @@ def prepare_node(self, node, storage_map, compute_map, impl): def __eq__(self, other): if self is other: return True - if ( - type(self) is not type(other) - or self.nin != other.nin - or self.nout != other.nout - ): + if type(self) is not type(other): return False - - # TODO FIXME: Why this? Shouldn't we expect equivalent inputs to this - # object to generate the same `_c_code`? - return self.c_code_template == other.c_code_template + return self.fgraph == other.fgraph def __hash__(self): - # Note that in general, the configparser settings at the time - # of code generation (__init__) affect the semantics of this Op. - # This function assumes that all relevant info about the configparser - # is embodied in _c_code. So the _c_code, rather than self.fgraph, - # is the signature of the semantics of this Op. - # _c_code is preserved through unpickling, so the Op will not change - # semantics when it is reloaded with different configparser - # settings. - # - # TODO FIXME: Doesn't the above just mean that we should be including - # the relevant "configparser settings" here? Also, why should we even - # care about the exact form of the generated C code when comparing - # `Op`s? All this smells of leaky concerns and interfaces. - return hash((type(self), self.nin, self.nout, self.c_code_template)) + return hash((type(self), self.fgraph)) def __getstate__(self): rval = dict(self.__dict__) rval.pop("_c_code", None) rval.pop("_py_perform_fn", None) - rval.pop("_fgraph", None) rval.pop("prepare_node_called", None) return rval @@ -4193,40 +4148,26 @@ class Composite(ScalarInnerGraphOp): init_param: tuple[str, ...] = ("inputs", "outputs") def __init__( - self, inputs, outputs, name="Composite", clone_graph: builtins.bool = True + self, + inputs, + outputs, + name="Composite", + clone_graph: builtins.bool = True, ): self.name = name self._name = None - # We need to clone the graph as sometimes its nodes already - # contain a reference to an fgraph. As we want the Composite - # to be pickable, we can't have reference to fgraph. - - # Also, if there is Composite in the inner graph, we want to - # remove them. In that case, we do a more complicated clone - # that will flatten Composite. We don't need to do this - # recursively, as the way the fusion optimizer work, we have - # only 1 new Composite each time at the output. + for i in inputs: assert i not in outputs # This isn't supported, use identity - if len(outputs) > 1 or not any( - isinstance(var.owner.op, Composite) for var in outputs + # Flatten nested Composites in single-output case + if len(outputs) == 1 and any( + var.owner is not None and isinstance(var.owner.op, Composite) + for var in outputs ): - if clone_graph: - inputs, outputs = clone(inputs, outputs) - - else: - # Inner Composite that we need to flatten - # FIXME: There could be a composite in the middle of the graph, why is this here? - # If anything it should be an optimization, but I suspect lower-level compilation can handle this anyway. - assert len(outputs) == 1 - # 1. Create a new graph from inputs up to the - # Composite res = pytensor.compile.rebuild_collect_shared( inputs=inputs, outputs=outputs[0].owner.inputs, copy_inputs_over=False - ) # Clone also the inputs - # 2. We continue this partial clone with the graph in - # the inner Composite + ) res2 = pytensor.compile.rebuild_collect_shared( inputs=outputs[0].owner.op.inputs, outputs=outputs[0].owner.op.outputs, @@ -4234,35 +4175,30 @@ def __init__( ) assert len(res2[1]) == len(outputs) assert len(res[0]) == len(inputs) - assert res[0] != inputs inputs, outputs = res[0], res2[1] - # We already cloned the graph, or the user told us there was no need for it - self.inputs, self.outputs = self._cleanup_graph(inputs, outputs, clone=False) - self.inputs_type = tuple(input.type for input in self.inputs) - self.outputs_type = tuple(output.type for output in self.outputs) - self.nin = len(inputs) - self.nout = len(outputs) + fgraph = FunctionGraph(inputs, outputs, clone=clone_graph) + self._validate_inner_graph(fgraph) + self._fgraph = fgraph.freeze() + + self.inputs_type = tuple(inp.type for inp in self._fgraph.inputs) + self.outputs_type = tuple(out.type for out in self._fgraph.outputs) + self.nin = len(self._fgraph.inputs) + self.nout = len(self._fgraph.outputs) super().__init__() + @property + def inputs(self): + return self._fgraph.inputs + + @property + def outputs(self): + return self._fgraph.outputs + def __str__(self): if self._name is not None: return self._name - # Rename internal variables - for i, r in enumerate(self.fgraph.inputs): - r.name = f"i{i}" - for i, r in enumerate(self.fgraph.outputs): - r.name = f"o{i}" - io = set(self.fgraph.inputs + self.fgraph.outputs) - for i, r in enumerate(self.fgraph.variables): - if ( - not isinstance(r, Constant) - and r not in io - and len(self.fgraph.clients[r]) > 1 - ): - r.name = f"t{i}" - if len(self.fgraph.outputs) > 1 or len(self.fgraph.apply_nodes) > 10: self._name = "Composite{...}" else: @@ -4273,12 +4209,6 @@ def __str__(self): @property def fgraph(self): - if hasattr(self, "_fgraph"): - return self._fgraph - # fgraph cannot be a property of the base class because it messes up with C caching. - # We also need a `FunctionGraph(clone=True)` (default) according to an old comment - fgraph = FunctionGraph(self.inputs, self.outputs) - self._fgraph = fgraph return self._fgraph def clone(self): diff --git a/pytensor/scalar/loop.py b/pytensor/scalar/loop.py index 888db3e52b..52408501ef 100644 --- a/pytensor/scalar/loop.py +++ b/pytensor/scalar/loop.py @@ -67,13 +67,18 @@ def __init__( inputs, outputs = clone([*init, *constant], update) self.is_while = until is not None - self.inputs, self.outputs = self._cleanup_graph(inputs, outputs) - self._validate_updates(self.inputs, self.outputs) - self.inputs_type = tuple(input.type for input in self.inputs) - self.outputs_type = tuple(output.type for output in self.outputs) - self.nin = len(self.inputs) + 1 # n_steps is not part of the inner graph - self.nout = len(self.outputs) + fgraph = FunctionGraph(inputs, outputs) + self._validate_inner_graph(fgraph) + self._fgraph = fgraph.freeze() + self._validate_updates(self._fgraph.inputs, self._fgraph.outputs) + + self.inputs_type = tuple(inp.type for inp in self._fgraph.inputs) + self.outputs_type = tuple(out.type for out in self._fgraph.outputs) + self.nin = ( + len(self._fgraph.inputs) + 1 + ) # n_steps is not part of the inner graph + self.nout = len(self._fgraph.outputs) self.name = name super().__init__(**kwargs) @@ -106,14 +111,16 @@ def _validate_updates( "If you want to return an output as a lagged input, wrap it in an identity Op." ) + @property + def inputs(self): + return self._fgraph.inputs + + @property + def outputs(self): + return self._fgraph.outputs + @property def fgraph(self): - if hasattr(self, "_fgraph"): - return self._fgraph - # fgraph cannot be a property of the base class because it messes up with C caching. - # We also need a `FunctionGraph(clone=True)` (default) according to an old comment - fgraph = FunctionGraph(self.inputs, self.outputs) - self._fgraph = fgraph return self._fgraph def clone(self, name=None, **kwargs): diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py index 6c6906b31f..3a17b5a6a1 100644 --- a/tests/compile/test_builders.py +++ b/tests/compile/test_builders.py @@ -4,6 +4,7 @@ import pytest import pytensor.tensor as pt +from pytensor import Mode from pytensor.compile import shared from pytensor.compile.builders import OpFromGraph from pytensor.compile.function import function @@ -18,6 +19,7 @@ from pytensor.graph.basic import equal_computations from pytensor.graph.fg import FunctionGraph from pytensor.graph.null_type import NullType, null_type +from pytensor.graph.rewriting.basic import MergeOptimizer from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.graph.utils import MissingInputError from pytensor.printing import debugprint @@ -712,6 +714,181 @@ def test_repeated_inputs(self): f = g(x, x, y) assert f.eval({x: 5, y: 5}) == 10 + def test_equality_and_hashing(self): + x, y = dscalars("x", "y") + e = x + y * x + + op1 = OpFromGraph([x, y], [e]) + op2 = OpFromGraph([x, y], [e]) + + # Same output with same inputs are equal with consistent hash + assert op1 == op2 + assert hash(op1) == hash(op2) + assert {op1: "v"}[op2] == "v" + + # Different graphs are not equal + op_different = OpFromGraph([x, y], [x * y + x]) + assert op1 != op_different + + # inline flag participates in equality + op_inline = OpFromGraph([x, y], [e], inline=True) + assert op1 != op_inline + + # destroy_map participates in equality + op_destroy = OpFromGraph([x, y], [e], destroy_map={0: (0,)}) + assert op1 != op_destroy + + # Multi-output OFGs are also hashed and compared based on their inner graph structure + op_multi1 = OpFromGraph([x, y], [x + y, x * y]) + op_multi2 = OpFromGraph([x, y], [x + y, x * y]) + assert op_multi1 == op_multi2 + + # OFG is hashable, and different OFGs have different hashes + assert hash(op1) != hash(op_inline) + + def test_equality_shared_variables(self): + x = scalar("x") + s = shared(np.array(1.0, dtype=config.floatX)) + + op1 = OpFromGraph([x], [x + s]) + op2 = OpFromGraph([x], [x + s]) + assert op1 == op2 + + # Same value, different shared object -> not equal + s2 = shared(np.array(1.0, dtype=config.floatX)) + op3 = OpFromGraph([x], [x + s2]) + assert op1 != op3 + + def test_equality_callable_overrides(self): + x, y = dscalars("x", "y") + e = x + y + + op_plain = OpFromGraph([x, y], [e]) + + # lop override present vs absent + op_with_lop = OpFromGraph( + [x, y], + [e], + lop_overrides=lambda inps, outs, grads: [grads[0], grads[0]], + ) + assert op_plain != op_with_lop + + # Structurally identical callable overrides are equal + op_with_lop2 = OpFromGraph( + [x, y], + [e], + lop_overrides=lambda inps, outs, grads: [grads[0], grads[0]], + ) + assert op_with_lop == op_with_lop2 + + # Structurally different callable override are not equal + op_with_lop3 = OpFromGraph( + [x, y], + [e], + lop_overrides=lambda inps, outs, grads: [grads[0] * 2, grads[0]], + ) + assert op_with_lop != op_with_lop3 + + # Overrides returning disconnected_type for different inputs are not equal + op_disc_y = OpFromGraph( + [x, y], + [e], + lop_overrides=lambda inps, outs, grads: [grads[0], disconnected_type()], + ) + op_disc_x = OpFromGraph( + [x, y], + [e], + lop_overrides=lambda inps, outs, grads: [disconnected_type(), grads[0]], + ) + assert op_disc_y != op_disc_x + + # Same disconnected pattern is equal + op_disc_y2 = OpFromGraph( + [x, y], + [e], + lop_overrides=lambda inps, outs, grads: [grads[0], disconnected_type()], + ) + assert op_disc_y == op_disc_y2 + + # All disconnected is still an override — not equal to no override + op_all_disc = OpFromGraph( + [x, y], + [e], + lop_overrides=lambda inps, outs, grads: [ + disconnected_type(), + disconnected_type(), + ], + ) + assert op_all_disc != op_plain + assert op_all_disc != op_disc_y + + # rop override follows the same logic + op_with_rop = OpFromGraph( + [x, y], + [e], + rop_overrides=lambda inps, epts: [epts[0] + epts[1]], + ) + op_with_rop2 = OpFromGraph( + [x, y], + [e], + rop_overrides=lambda inps, epts: [epts[0] + epts[1]], + ) + assert op_with_rop == op_with_rop2 + assert op_with_rop != op_plain + + def test_equality_list_overrides(self): + x, y = dscalars("x", "y") + e = x + y + + def scale_grad(inps, outs, grads): + return grads[0] * 2 + + op1 = OpFromGraph([x, y], [e], lop_overrides=[scale_grad, None]) + op2 = OpFromGraph([x, y], [e], lop_overrides=[scale_grad, None]) + assert op1 == op2 + + def scale_grad_3x(inps, outs, grads): + return grads[0] * 3 + + op3 = OpFromGraph([x, y], [e], lop_overrides=[scale_grad_3x, None]) + assert op1 != op3 + + # Position of None vs callable matters + op4 = OpFromGraph([x, y], [e], lop_overrides=[None, scale_grad]) + assert op1 != op4 + + def test_merge_identical_ofgs(self): + x, y = dscalars("x", "y") + e = x + y * x + + op1 = OpFromGraph([x, y], [e]) + op2 = OpFromGraph([x, y], [e]) + + a, b = dscalars("a", "b") + + # Two OFG with the same inputs are collapsed to one node by MergeOptimizer + fg = FunctionGraph([a, b], [op1(a, b), op2(a, b)]) + MergeOptimizer().rewrite(fg) + ofg_nodes = [n for n in fg.toposort() if isinstance(n.op, OpFromGraph)] + assert len(ofg_nodes) == 1 + + # Different inputs are different graphs, so both nodes survive + c, d = dscalars("c", "d") + fg = FunctionGraph([a, b, c, d], [op1(a, b), op2(c, d)]) + MergeOptimizer().rewrite(fg) + ofg_nodes = [n for n in fg.toposort() if isinstance(n.op, OpFromGraph)] + assert len(ofg_nodes) == 2 + + # Check numerics to make sure the merged OFG is correct + fn = function( + [a, b, c, d], + [op1(a, b), op2(c, d)], + mode=Mode(optimizer="merge", linker="py"), + ) + r1, r2 = fn(2.0, 3.0, 4.0, 5.0) + np.testing.assert_allclose(r1, 2.0 + 3.0 * 2.0) + np.testing.assert_allclose(r2, 4.0 + 5.0 * 4.0) + @config.change_flags(floatX="float64") def test_debugprint(): diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index 455b36f757..e9d8aebd74 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -471,3 +471,107 @@ def test_dprint(): o1 = MyOp(r1, r2) assert o1.dprint(file="str") == debugprint(o1, file="str") assert o1.owner.dprint(file="str") == debugprint(o1.owner, file="str") + + +class TestFrozenConstant: + def test_interning(self): + from pytensor.graph.basic import Constant, FrozenConstant + from pytensor.scalar.basic import float64 + + c1 = FrozenConstant(float64, 2.0) + c2 = FrozenConstant(float64, 2.0) + c3 = FrozenConstant(float64, 3.0) + + assert c1 is c2 + assert c1 is not c3 + assert isinstance(c1, Constant) + assert c1.data == 2.0 + assert c3.data == 3.0 + + # Usable as dict key / in sets via identity + assert len({c1, c2, c3}) == 2 + + def test_nan_interning(self): + from pytensor.graph.basic import FrozenConstant + from pytensor.scalar.basic import float64 + + c1 = FrozenConstant(float64, float("nan")) + c2 = FrozenConstant(float64, float("nan")) + # NaN hashes the same, so these should be the same object + assert c1 is c2 + + def test_array_interning(self): + from pytensor.graph.basic import FrozenConstant + + t = TensorType("float64", shape=(3,)) + arr = np.array([1.0, 2.0, np.nan]) + + c1 = FrozenConstant(t, arr) + c2 = FrozenConstant(t, arr.copy()) + c3 = FrozenConstant(t, np.array([9.0, 9.0, 9.0])) + + assert c1 is c2 + assert c1 is not c3 + np.testing.assert_array_equal(c1.data, arr) + + def test_different_types_not_shared(self): + from pytensor.graph.basic import FrozenConstant + from pytensor.scalar.basic import float32, float64 + + c1 = FrozenConstant(float64, 1.0) + c2 = FrozenConstant(float32, 1.0) + assert c1 is not c2 + + +class TestFrozenApply: + def test_interning_and_immutability(self): + from pytensor.graph.basic import FrozenApply + from pytensor.scalar.basic import add, float64, mul + + x = NominalVariable(0, float64) + y = NominalVariable(1, float64) + + fa1 = FrozenApply(add, (x, y), (float64,)) + fa2 = FrozenApply(add, (x, y), (float64,)) + fa_diff_op = FrozenApply(mul, (x, y), (float64,)) + fa_diff_order = FrozenApply(add, (y, x), (float64,)) + + # Same (op, inputs) implies the same object + assert fa1 is fa2 + + # Different op or input order implies a different object + assert fa1 is not fa_diff_op + assert fa1 is not fa_diff_order + + assert isinstance(fa1, Apply) + + assert fa1.outputs[0].owner is fa1 + assert fa1.outputs[0].index == 0 + + def test_cross_graph_identity(self): + """Two independently-built identical graphs share all FrozenApply nodes.""" + from pytensor.graph.basic import FrozenApply + from pytensor.scalar.basic import float64, mul, sin, sqr + + def build_graph(): + a = NominalVariable(0, float64) + b = NominalVariable(1, float64) + n_sin = FrozenApply(sin, (a,), (float64,)) + n_sqr = FrozenApply(sqr, (b,), (float64,)) + n_mul = FrozenApply(mul, (n_sin.outputs[0], n_sqr.outputs[0]), (float64,)) + return n_mul.outputs[0] + + out1 = build_graph() + out2 = build_graph() + assert out1 is out2 + + def test_frozen_constant_in_key_chain(self): + from pytensor.graph.basic import FrozenApply, FrozenConstant + from pytensor.scalar.basic import add, float64 + + x = NominalVariable(0, float64) + c1 = FrozenConstant(float64, 3.14) + c2 = FrozenConstant(float64, 3.14) + fa1 = FrozenApply(add, (x, c1), (float64,)) + fa2 = FrozenApply(add, (x, c2), (float64,)) + assert fa1 is fa2 diff --git a/tests/graph/test_fg.py b/tests/graph/test_fg.py index c14ad2dce8..c218ae77ff 100644 --- a/tests/graph/test_fg.py +++ b/tests/graph/test_fg.py @@ -771,3 +771,84 @@ def test_optimizer_verbose(self, capsys): cap_out = capsys.readouterr().out assert "y->z" not in cap_out assert "z->y" not in cap_out + + +class TestFrozenFunctionGraph: + def test_hashability_and_comparison(self): + var1, var2 = MyVariable("x"), MyVariable("y") + + ffg1 = FunctionGraph([var1, var2], [op1(var1, var2)]).freeze() + ffg2 = FunctionGraph([var1, var2], [op1(var1, var2)]).freeze() + ffg_different = FunctionGraph([var1, var2], [op2(var1, var2)]).freeze() + + assert ffg1 == ffg2 + assert hash(ffg1) == hash(ffg2) + assert ffg1 != ffg_different + + assert {ffg1: "value"}[ffg2] == "value" + assert len({ffg1, ffg2}) == 1 + + def test_nominal_inputs_renumbered(self): + """Inputs are always renumbered 0..n regardless of original ids.""" + t = MyType() + nm5, nm10 = NominalVariable(5, t), NominalVariable(10, t) + + ffg = FunctionGraph([nm5, nm10], [op1(nm5, nm10)]).freeze() + assert [inp.id for inp in ffg.inputs] == [0, 1] + + def test_deduplication(self): + var1 = MyVariable("x") + + dup1, dup2 = op1(var1), op1(var1) + frozen = FunctionGraph([var1], [op2(dup1, dup2)]).freeze() + assert {n.op for n in frozen.apply_nodes} == {op1, op2} + + c1 = MyConstant("c", data=42) + c2 = MyConstant("c", data=42) + frozen_const = FunctionGraph( + [var1], [op2(op1(var1, c1), op1(var1, c2))] + ).freeze() + assert {n.op for n in frozen_const.apply_nodes} == {op1, op2} + + def test_input_passed_directly_to_output(self): + var1 = MyVariable("x") + frozen = FunctionGraph([var1], [var1]).freeze() + + assert frozen.apply_nodes == set() + assert isinstance(frozen.outputs[0], NominalVariable) + + def test_cross_graph_output_identity(self): + var1, var2 = MyVariable("x"), MyVariable("y") + ffg1 = FunctionGraph([var1, var2], [op1(var1, var2)]).freeze() + ffg2 = FunctionGraph([var1, var2], [op1(var1, var2)]).freeze() + + assert all(a is b for a, b in zip(ffg1.outputs, ffg2.outputs)) + + def test_pickle_round_trip(self): + from pytensor.scalar.basic import add, float64, mul + + x, y = float64("x"), float64("y") + ffg = FunctionGraph([x, y], [mul(add(x, y), y)]).freeze() + + ffg2 = pickle.loads(pickle.dumps(ffg)) + assert ffg == ffg2 + assert hash(ffg) == hash(ffg2) + # Interned objects survive pickle + assert all(o1 is o2 for o1, o2 in zip(ffg.outputs, ffg2.outputs)) + + def test_orphan_non_constant_raises(self): + from pytensor.graph.fg import FrozenFunctionGraph + + var1 = MyVariable("x") + orphan = MyVariable("orphan") + out = op1(var1, orphan) + with pytest.raises(ValueError, match=r"Non-Constant.*orphan"): + FrozenFunctionGraph([var1], [out]) + + def test_unmapped_output_raises(self): + from pytensor.graph.fg import FrozenFunctionGraph + + var1 = MyVariable("x") + disconnected = MyVariable("disconnected") + with pytest.raises(ValueError, match="could not be mapped"): + FrozenFunctionGraph([var1], [disconnected]) diff --git a/tests/scalar/test_basic.py b/tests/scalar/test_basic.py index 3167a20149..96fb5044b9 100644 --- a/tests/scalar/test_basic.py +++ b/tests/scalar/test_basic.py @@ -1,3 +1,5 @@ +import pickle + import numpy as np import pytest @@ -92,6 +94,24 @@ def test_flatten(self): # We don't flatten that case. assert isinstance(CC.outputs[0].owner.op, Composite) + def test_shared_identity(self): + x, y = floats("xy") + c1 = Composite([x, y], [x + y]) + c2 = Composite([x, y], [x + y]) + assert c1 == c2 + assert hash(c1) == hash(c2) + assert {c1: 1}[c2] == 1 + + c3 = Composite([x, y], [x * y]) + assert c1 != c3 + + def test_pickle_roundtrip(self): + x, y = floats("xy") + c = Composite([x, y], [x + y]) + c2 = pickle.loads(pickle.dumps(c)) + assert c == c2 + assert hash(c) == hash(c2) + @pytest.mark.parametrize("literal_value", (70.0, -np.inf, np.float32("nan"))) def test_with_constants(self, literal_value): x, y, _z = floats("xyz") diff --git a/tests/scalar/test_loop.py b/tests/scalar/test_loop.py index d4f0f5b021..9fabfda6b8 100644 --- a/tests/scalar/test_loop.py +++ b/tests/scalar/test_loop.py @@ -297,7 +297,34 @@ def test_elemwise_inplace(mutate_arg_idx): cv_test = np.array([0, 0, 0], dtype="int64") xv_res, yv_res = fn(n_test, x0v_test, y0v_test, cv_test) - # Check the outputs are the destroyed inputs assert xv_res is (n_test, x0v_test, y0v_test, cv_test)[mutate_arg_idx] np.testing.assert_allclose(xv_res, [-1, -8, -128]) np.testing.assert_allclose(yv_res, [1, 8, 128]) + + +def test_identical_loops_share_inner_graph(): + x0 = float64("x0") + c = float64("c") + + op1 = ScalarLoop(init=[x0], constant=[c], update=[x0 + c]) + op2 = ScalarLoop(init=[x0], constant=[c], update=[x0 + c]) + + assert op1 == op2 + assert hash(op1) == hash(op2) + assert op1.fgraph == op2.fgraph + + # Two loops with the same structure but different outer inputs. + # MergeOptimizer can't collapse the Apply nodes (different inputs), + # but both should reference the same inner Op after merging. + n = int64("n") + a, b, c_val, d = float64("a"), float64("b"), float64("c_val"), float64("d") + y1 = op1(n, a, b) + y2 = op2(n, c_val, d) + + fn = function( + [n, a, b, c_val, d], [y1, y2], mode=Mode(optimizer="merge", linker="py") + ) + nodes = fn.maker.fgraph.toposort() + loop_nodes = [nd for nd in nodes if isinstance(nd.op, ScalarLoop)] + assert len(loop_nodes) == 2 + assert loop_nodes[0].op is loop_nodes[1].op