From ed2a746b4e788b685eff95898c8f7c426d155b80 Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Thu, 5 Mar 2026 08:42:25 -0800 Subject: [PATCH 1/3] Add render graph framework (#73) Lightweight DAG of render passes with: - RenderPass base class (inputs/outputs/requires declarations) - RenderGraph container with topological sort and cycle detection - Buffer lifetime analysis for GPU memory reuse - Capability-gated pass skipping with fallback wiring - Validation for missing inputs and dependency errors --- rtxpy/__init__.py | 8 + rtxpy/render_graph.py | 547 +++++++++++++++++++++++++ rtxpy/tests/test_render_graph.py | 659 +++++++++++++++++++++++++++++++ 3 files changed, 1214 insertions(+) create mode 100644 rtxpy/render_graph.py create mode 100644 rtxpy/tests/test_render_graph.py diff --git a/rtxpy/__init__.py b/rtxpy/__init__.py index ca3c137..f474b68 100644 --- a/rtxpy/__init__.py +++ b/rtxpy/__init__.py @@ -25,6 +25,14 @@ ) from .analysis import viewshed, hillshade, render, flyover, view from .engine import explore +from .render_graph import ( + BufferDesc, + RenderPass, + RenderGraph, + CompiledGraph, + AllocationPlan, + GraphValidationError, +) __version__ = "0.1.0" diff --git a/rtxpy/render_graph.py b/rtxpy/render_graph.py new file mode 100644 index 0000000..b75b5e7 --- /dev/null +++ b/rtxpy/render_graph.py @@ -0,0 +1,547 @@ +"""Render graph: a lightweight DAG of render passes with automatic +dependency resolution, buffer lifetime analysis, and capability gating. + +Usage:: + + graph = RenderGraph(width=1920, height=1080) + graph.add_pass(GBufferPass(...)) + graph.add_pass(ShadowPass(...)) + graph.add_pass(DenoisePass(...)) + graph.add_pass(TonemapPass(...)) + + compiled = graph.compile(capabilities=get_capabilities()) + result = compiled.execute() +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Any, Callable, Optional + + +@dataclass(frozen=True) +class BufferDesc: + """Describes a GPU buffer's shape and data type. + + Parameters + ---------- + dtype : str + NumPy-compatible dtype string (e.g. ``'float32'``, ``'int32'``). + channels : int + Number of channels per element. For a per-pixel RGB buffer this is 3; + for a scalar per-ray buffer this is 1. + per_pixel : bool + If True the buffer shape is ``(height, width, channels)`` (image-like). + If False the shape is ``(width * height, channels)`` (per-ray flat). + """ + + dtype: str = "float32" + channels: int = 3 + per_pixel: bool = True + + def shape(self, width: int, height: int) -> tuple[int, ...]: + if self.per_pixel: + if self.channels == 1: + return (height, width) + return (height, width, self.channels) + n = width * height + if self.channels == 1: + return (n,) + return (n, self.channels) + + +class RenderPass(ABC): + """Abstract base for a single render pass. + + Subclasses declare their buffer *inputs* and *outputs* so the graph can + resolve execution order and manage GPU memory. + + Parameters + ---------- + name : str + Unique name for this pass (e.g. ``"gbuffer"``, ``"shadow"``). + inputs : dict[str, BufferDesc] + Buffers this pass reads. Keys are buffer names (globally unique within + the graph). + outputs : dict[str, BufferDesc] + Buffers this pass writes. + enabled : bool + If False the pass is skipped during compilation. + requires : list[str] + Capability keys (from :func:`rtxpy.get_capabilities`) that must be + truthy for this pass to run. The graph disables the pass automatically + when requirements aren't met. + """ + + def __init__( + self, + name: str, + inputs: dict[str, BufferDesc] | None = None, + outputs: dict[str, BufferDesc] | None = None, + *, + enabled: bool = True, + requires: list[str] | None = None, + ): + self.name = name + self.inputs: dict[str, BufferDesc] = inputs or {} + self.outputs: dict[str, BufferDesc] = outputs or {} + self.enabled = enabled + self.requires: list[str] = requires or [] + + def setup(self, graph: RenderGraph) -> None: + """Called once when the graph is compiled. Override to allocate + one-time resources.""" + + @abstractmethod + def execute(self, buffers: dict[str, Any]) -> None: + """Run this pass. + + *buffers* maps buffer names (both inputs and outputs) to allocated GPU + arrays. The pass should read its declared inputs and write its + declared outputs. + """ + + def teardown(self) -> None: + """Called when the graph is torn down. Override to free resources.""" + + def __repr__(self) -> str: + return ( + f"{type(self).__name__}(name={self.name!r}, " + f"inputs={list(self.inputs)}, outputs={list(self.outputs)}, " + f"enabled={self.enabled})" + ) + + +class GraphValidationError(Exception): + """Raised when the render graph fails validation.""" + + +@dataclass +class _BufferLifetime: + """Tracks first-write and last-read pass indices for a buffer.""" + + first_write: int = -1 + last_read: int = -1 + + +@dataclass +class AllocationPlan: + """The result of buffer lifetime analysis. + + Maps each buffer name to a *pool slot* index. Buffers assigned to the same + slot have non-overlapping lifetimes and can share the same GPU allocation. + """ + + slots: dict[str, int] = field(default_factory=dict) + slot_descs: dict[int, BufferDesc] = field(default_factory=dict) + + @property + def num_slots(self) -> int: + if not self.slot_descs: + return 0 + return max(self.slot_descs) + 1 + + +@dataclass +class CompiledGraph: + """A compiled, ready-to-execute render graph. + + Produced by :meth:`RenderGraph.compile`. + """ + + ordered_passes: list[RenderPass] + allocation_plan: AllocationPlan + buffer_descs: dict[str, BufferDesc] + fallback_map: dict[str, str] + width: int + height: int + + def execute( + self, + external_buffers: dict[str, Any] | None = None, + allocator: Callable[[tuple[int, ...], str], Any] | None = None, + ) -> dict[str, Any]: + """Execute all passes in topological order. + + Parameters + ---------- + external_buffers : dict, optional + Pre-allocated buffers to inject (e.g. the OptiX scene handle). + These are available to passes but are not managed by the graph. + allocator : callable, optional + ``allocator(shape, dtype) -> array``. Defaults to + ``cupy.zeros`` if CuPy is available, else ``numpy.zeros``. + + Returns + ------- + dict[str, array] + All live buffers after execution completes. + """ + if allocator is None: + allocator = _default_allocator() + + buffers: dict[str, Any] = dict(external_buffers or {}) + + # Allocate managed buffers grouped by pool slot + pool: dict[int, Any] = {} + for buf_name, slot_idx in self.allocation_plan.slots.items(): + if buf_name in buffers: + continue # external + if slot_idx not in pool: + desc = self.allocation_plan.slot_descs[slot_idx] + pool[slot_idx] = allocator( + desc.shape(self.width, self.height), desc.dtype + ) + buffers[buf_name] = pool[slot_idx] + + # Also allocate any buffers not in the allocation plan (outputs of + # active passes that somehow escaped lifetime analysis — shouldn't + # happen, but defensive). + for buf_name, desc in self.buffer_descs.items(): + if buf_name not in buffers: + buffers[buf_name] = allocator( + desc.shape(self.width, self.height), desc.dtype + ) + + # Apply fallback wiring: if a buffer is absent (producer was disabled), + # point it at the fallback buffer. + for buf_name, fallback in self.fallback_map.items(): + if buf_name not in buffers and fallback in buffers: + buffers[buf_name] = buffers[fallback] + + for pass_ in self.ordered_passes: + pass_.execute(buffers) + + return buffers + + +class RenderGraph: + """A configurable DAG of render passes. + + Passes declare typed buffer inputs/outputs. The graph resolves execution + order via topological sort, detects cycles, performs buffer lifetime + analysis for memory reuse, and gates passes on runtime GPU capabilities. + """ + + def __init__(self, width: int = 1920, height: int = 1080): + self.width = width + self.height = height + self._passes: dict[str, RenderPass] = {} + self._insertion_order: list[str] = [] + self._fallbacks: dict[str, str] = {} + + # -- Pass management --------------------------------------------------- + + def add_pass(self, pass_: RenderPass) -> None: + """Add a render pass to the graph. + + Raises + ------ + ValueError + If a pass with the same name already exists. + """ + if pass_.name in self._passes: + raise ValueError(f"Pass {pass_.name!r} already exists in the graph") + self._passes[pass_.name] = pass_ + self._insertion_order.append(pass_.name) + + def remove_pass(self, name: str) -> None: + """Remove a render pass by name. + + Raises + ------ + KeyError + If no pass with that name exists. + """ + if name not in self._passes: + raise KeyError(f"No pass named {name!r}") + del self._passes[name] + self._insertion_order.remove(name) + + def get_pass(self, name: str) -> RenderPass: + return self._passes[name] + + @property + def passes(self) -> list[RenderPass]: + return [self._passes[n] for n in self._insertion_order] + + def set_fallback(self, buffer: str, fallback: str) -> None: + """Register a fallback: if *buffer*'s producer is disabled, read + *fallback* instead. + + Example: ``graph.set_fallback("denoised_color", "color")`` — if the + denoise pass is skipped, downstream passes reading ``denoised_color`` + transparently receive ``color``. + """ + self._fallbacks[buffer] = fallback + + # -- Compilation ------------------------------------------------------- + + def compile( + self, + capabilities: dict[str, Any] | None = None, + validate: bool = True, + ) -> CompiledGraph: + """Compile the graph into a ready-to-execute form. + + 1. Capability-gate passes (disable those whose requirements aren't met). + 2. Resolve fallback wiring for disabled pass outputs. + 3. Topological sort by buffer dependencies. + 4. Buffer lifetime analysis and allocation planning. + 5. Validation (missing inputs, cycles). + + Parameters + ---------- + capabilities : dict, optional + Runtime capabilities from :func:`rtxpy.get_capabilities`. + validate : bool + If True, raise :class:`GraphValidationError` on problems. + + Returns + ------- + CompiledGraph + """ + caps = capabilities or {} + active = self._gate_passes(caps) + active_names = {p.name for p in active} + + # Collect all buffer descriptors (outputs define the canonical desc) + buffer_descs: dict[str, BufferDesc] = {} + for p in active: + for buf_name, desc in p.outputs.items(): + buffer_descs[buf_name] = desc + + # Resolve fallbacks: find buffers consumed but not produced + fallback_map = self._resolve_fallbacks(active, buffer_descs) + + # Build dependency edges + producer: dict[str, str] = {} # buffer_name -> pass_name + for p in active: + for buf_name in p.outputs: + producer[buf_name] = p.name + + adj: dict[str, list[str]] = defaultdict(list) # pass -> [deps] + for p in active: + for buf_name in p.inputs: + resolved = fallback_map.get(buf_name, buf_name) + if resolved in producer: + dep_pass = producer[resolved] + if dep_pass != p.name: + adj[p.name].append(dep_pass) + + # Topological sort (Kahn's algorithm) + ordered = self._topological_sort(active, adj) + + if validate: + self._validate(ordered, buffer_descs, fallback_map, producer) + + # Buffer lifetime analysis + allocation_plan = self._lifetime_analysis(ordered, fallback_map) + + # Call setup on each pass + for p in ordered: + p.setup(self) + + return CompiledGraph( + ordered_passes=ordered, + allocation_plan=allocation_plan, + buffer_descs=buffer_descs, + fallback_map=fallback_map, + width=self.width, + height=self.height, + ) + + # -- Internal helpers -------------------------------------------------- + + def _gate_passes(self, capabilities: dict[str, Any]) -> list[RenderPass]: + """Return only passes that are enabled and whose capability + requirements are met.""" + active: list[RenderPass] = [] + for name in self._insertion_order: + p = self._passes[name] + if not p.enabled: + continue + if p.requires and not all(capabilities.get(r) for r in p.requires): + continue + active.append(p) + return active + + def _resolve_fallbacks( + self, + active: list[RenderPass], + buffer_descs: dict[str, BufferDesc], + ) -> dict[str, str]: + """Build the fallback map for buffers whose producers are inactive.""" + produced = set(buffer_descs) + consumed: set[str] = set() + for p in active: + consumed.update(p.inputs) + + fallback_map: dict[str, str] = {} + for buf in consumed - produced: + fb = buf + visited: set[str] = set() + while fb in self._fallbacks and fb not in produced: + if fb in visited: + break # avoid cycles in fallback chain + visited.add(fb) + fb = self._fallbacks[fb] + if fb != buf: + fallback_map[buf] = fb + return fallback_map + + def _topological_sort( + self, + active: list[RenderPass], + adj: dict[str, list[str]], + ) -> list[RenderPass]: + """Kahn's algorithm. Stable: breaks ties by insertion order.""" + in_degree: dict[str, int] = {p.name: 0 for p in active} + reverse_adj: dict[str, list[str]] = defaultdict(list) + + for node, deps in adj.items(): + for dep in deps: + if dep in in_degree: + reverse_adj[dep].append(node) + in_degree[node] = in_degree.get(node, 0) + 1 + + # Seed queue with zero-in-degree passes, ordered by insertion order + order_idx = {name: i for i, name in enumerate(self._insertion_order)} + queue = sorted( + [p.name for p in active if in_degree.get(p.name, 0) == 0], + key=lambda n: order_idx.get(n, 0), + ) + + result: list[str] = [] + while queue: + node = queue.pop(0) + result.append(node) + for neighbor in sorted( + reverse_adj.get(node, []), + key=lambda n: order_idx.get(n, 0), + ): + in_degree[neighbor] -= 1 + if in_degree[neighbor] == 0: + queue.append(neighbor) + + if len(result) != len(active): + in_cycle = {p.name for p in active} - set(result) + raise GraphValidationError( + f"Cycle detected among passes: {in_cycle}" + ) + + pass_map = {p.name: p for p in active} + return [pass_map[n] for n in result] + + def _lifetime_analysis( + self, + ordered: list[RenderPass], + fallback_map: dict[str, str], + ) -> AllocationPlan: + """Compute buffer lifetimes and assign pool slots for memory reuse.""" + lifetimes: dict[str, _BufferLifetime] = {} + descs: dict[str, BufferDesc] = {} + + for idx, p in enumerate(ordered): + for buf_name, desc in p.outputs.items(): + if buf_name not in lifetimes: + lifetimes[buf_name] = _BufferLifetime() + lifetimes[buf_name].first_write = idx + descs[buf_name] = desc + + for buf_name in p.inputs: + resolved = fallback_map.get(buf_name, buf_name) + if resolved not in lifetimes: + lifetimes[resolved] = _BufferLifetime() + lifetimes[resolved].last_read = max( + lifetimes[resolved].last_read, idx + ) + + # Extend last_read for buffers never explicitly read (keep alive + # until end so they appear in the final output dict). + num_passes = len(ordered) + for buf, lt in lifetimes.items(): + if lt.last_read < lt.first_write: + lt.last_read = num_passes + + # Greedy interval-colouring: assign buffers to pool slots such that + # overlapping lifetimes get different slots. Buffers can only share + # a slot if they have the same shape & dtype. + slots: dict[str, int] = {} + slot_descs: dict[int, BufferDesc] = {} + # slot_end[slot_idx] = last_read of current occupant + slot_end: dict[int, int] = {} + next_slot = 0 + + # Sort buffers by first_write for deterministic allocation + sorted_bufs = sorted(lifetimes, key=lambda b: lifetimes[b].first_write) + + for buf in sorted_bufs: + lt = lifetimes[buf] + desc = descs.get(buf) + if desc is None: + continue # external buffer, skip + + # Try to reuse an existing slot whose occupant has expired + reused = False + for sid in sorted(slot_end): + if ( + slot_end[sid] < lt.first_write + and slot_descs[sid] == desc + ): + slots[buf] = sid + slot_end[sid] = lt.last_read + reused = True + break + + if not reused: + slots[buf] = next_slot + slot_descs[next_slot] = desc + slot_end[next_slot] = lt.last_read + next_slot += 1 + + return AllocationPlan(slots=slots, slot_descs=slot_descs) + + def _validate( + self, + ordered: list[RenderPass], + buffer_descs: dict[str, BufferDesc], + fallback_map: dict[str, str], + producer: dict[str, str], + ) -> None: + """Check for missing inputs and warn on unused outputs.""" + produced = set(buffer_descs) + produced.update(fallback_map.values()) + + errors: list[str] = [] + for p in ordered: + for buf_name in p.inputs: + resolved = fallback_map.get(buf_name, buf_name) + if resolved not in produced: + errors.append( + f"Pass {p.name!r} requires input {buf_name!r} " + f"but no active pass produces it" + ) + + if errors: + raise GraphValidationError("\n".join(errors)) + + +def _default_allocator(): + """Return a zero-allocator: tries CuPy first, falls back to NumPy.""" + try: + import cupy + + def _alloc(shape, dtype): + return cupy.zeros(shape, dtype=dtype) + + return _alloc + except ImportError: + import numpy as np + + def _alloc(shape, dtype): + return np.zeros(shape, dtype=dtype) + + return _alloc diff --git a/rtxpy/tests/test_render_graph.py b/rtxpy/tests/test_render_graph.py new file mode 100644 index 0000000..fbedd85 --- /dev/null +++ b/rtxpy/tests/test_render_graph.py @@ -0,0 +1,659 @@ +"""Tests for the render graph framework (issue #73).""" + +import numpy as np +import pytest + +from rtxpy.render_graph import ( + AllocationPlan, + BufferDesc, + CompiledGraph, + GraphValidationError, + RenderGraph, + RenderPass, +) + + +# --------------------------------------------------------------------------- +# Helpers — concrete pass implementations for testing +# --------------------------------------------------------------------------- + + +class StubPass(RenderPass): + """Minimal concrete pass that records execute calls.""" + + def __init__(self, name, inputs=None, outputs=None, **kwargs): + super().__init__(name, inputs, outputs, **kwargs) + self.executed = False + self.exec_order = -1 + + def execute(self, buffers): + self.executed = True + + +class WritingPass(StubPass): + """Writes a constant into its output buffer for verification.""" + + def __init__(self, name, inputs=None, outputs=None, value=1.0, **kwargs): + super().__init__(name, inputs, outputs, **kwargs) + self.value = value + + def execute(self, buffers): + super().execute(buffers) + for buf_name in self.outputs: + buffers[buf_name][:] = self.value + + +class SummingPass(StubPass): + """Sums all input buffers into the output for verification.""" + + def __init__(self, name, inputs=None, outputs=None, **kwargs): + super().__init__(name, inputs, outputs, **kwargs) + + def execute(self, buffers): + super().execute(buffers) + out_name = list(self.outputs)[0] + buffers[out_name][:] = 0 + for buf_name in self.inputs: + buffers[out_name] += buffers[buf_name] + + +# --------------------------------------------------------------------------- +# BufferDesc tests +# --------------------------------------------------------------------------- + + +class TestBufferDesc: + def test_per_pixel_rgb(self): + desc = BufferDesc(dtype="float32", channels=3, per_pixel=True) + assert desc.shape(1920, 1080) == (1080, 1920, 3) + + def test_per_pixel_scalar(self): + desc = BufferDesc(dtype="float32", channels=1, per_pixel=True) + assert desc.shape(1920, 1080) == (1080, 1920) + + def test_per_ray_multi_channel(self): + desc = BufferDesc(dtype="float32", channels=4, per_pixel=False) + assert desc.shape(1920, 1080) == (1920 * 1080, 4) + + def test_per_ray_scalar(self): + desc = BufferDesc(dtype="int32", channels=1, per_pixel=False) + assert desc.shape(100, 50) == (5000,) + + def test_frozen(self): + desc = BufferDesc() + with pytest.raises(AttributeError): + desc.dtype = "int32" + + def test_equality(self): + a = BufferDesc(dtype="float32", channels=3) + b = BufferDesc(dtype="float32", channels=3) + assert a == b + + def test_inequality(self): + a = BufferDesc(dtype="float32", channels=3) + b = BufferDesc(dtype="float32", channels=4) + assert a != b + + +# --------------------------------------------------------------------------- +# Pass management +# --------------------------------------------------------------------------- + + +class TestPassManagement: + def test_add_and_list(self): + g = RenderGraph(width=64, height=64) + g.add_pass(StubPass("a")) + g.add_pass(StubPass("b")) + assert [p.name for p in g.passes] == ["a", "b"] + + def test_duplicate_name_raises(self): + g = RenderGraph() + g.add_pass(StubPass("a")) + with pytest.raises(ValueError, match="already exists"): + g.add_pass(StubPass("a")) + + def test_remove_pass(self): + g = RenderGraph() + g.add_pass(StubPass("a")) + g.add_pass(StubPass("b")) + g.remove_pass("a") + assert [p.name for p in g.passes] == ["b"] + + def test_remove_missing_raises(self): + g = RenderGraph() + with pytest.raises(KeyError): + g.remove_pass("nonexistent") + + def test_get_pass(self): + g = RenderGraph() + p = StubPass("a") + g.add_pass(p) + assert g.get_pass("a") is p + + +# --------------------------------------------------------------------------- +# Topological sort +# --------------------------------------------------------------------------- + + +class TestTopologicalSort: + def test_linear_chain(self): + """A -> B -> C should execute in that order.""" + rgb = BufferDesc(channels=3) + g = RenderGraph(width=8, height=8) + g.add_pass(StubPass("A", outputs={"color": rgb})) + g.add_pass(StubPass("B", inputs={"color": rgb}, outputs={"denoised": rgb})) + g.add_pass(StubPass("C", inputs={"denoised": rgb}, outputs={"final": rgb})) + compiled = g.compile() + names = [p.name for p in compiled.ordered_passes] + assert names == ["A", "B", "C"] + + def test_diamond_dependency(self): + """A -> B, A -> C, B+C -> D.""" + rgb = BufferDesc(channels=3) + g = RenderGraph(width=8, height=8) + g.add_pass(StubPass("A", outputs={"color": rgb, "normal": rgb})) + g.add_pass(StubPass("B", inputs={"color": rgb}, outputs={"ao": rgb})) + g.add_pass(StubPass("C", inputs={"normal": rgb}, outputs={"shadow": rgb})) + g.add_pass( + StubPass("D", inputs={"ao": rgb, "shadow": rgb}, outputs={"final": rgb}) + ) + compiled = g.compile() + names = [p.name for p in compiled.ordered_passes] + assert names.index("A") < names.index("B") + assert names.index("A") < names.index("C") + assert names.index("B") < names.index("D") + assert names.index("C") < names.index("D") + + def test_independent_passes_preserve_insertion_order(self): + """Passes with no dependencies keep their insertion order.""" + rgb = BufferDesc(channels=3) + g = RenderGraph(width=8, height=8) + g.add_pass(StubPass("X", outputs={"x": rgb})) + g.add_pass(StubPass("Y", outputs={"y": rgb})) + g.add_pass(StubPass("Z", outputs={"z": rgb})) + compiled = g.compile() + names = [p.name for p in compiled.ordered_passes] + assert names == ["X", "Y", "Z"] + + def test_cycle_detection(self): + rgb = BufferDesc(channels=3) + g = RenderGraph(width=8, height=8) + g.add_pass(StubPass("A", inputs={"b_out": rgb}, outputs={"a_out": rgb})) + g.add_pass(StubPass("B", inputs={"a_out": rgb}, outputs={"b_out": rgb})) + with pytest.raises(GraphValidationError, match="Cycle"): + g.compile() + + +# --------------------------------------------------------------------------- +# Capability gating +# --------------------------------------------------------------------------- + + +class TestCapabilityGating: + def test_pass_disabled_when_capability_missing(self): + rgb = BufferDesc(channels=3) + g = RenderGraph(width=8, height=8) + g.add_pass(StubPass("A", outputs={"color": rgb})) + g.add_pass( + StubPass( + "denoiser", + inputs={"color": rgb}, + outputs={"denoised": rgb}, + requires=["optix_denoiser"], + ) + ) + # No capabilities -> denoiser skipped + compiled = g.compile(capabilities={}) + names = [p.name for p in compiled.ordered_passes] + assert "denoiser" not in names + + def test_pass_enabled_when_capability_present(self): + rgb = BufferDesc(channels=3) + g = RenderGraph(width=8, height=8) + g.add_pass(StubPass("A", outputs={"color": rgb})) + g.add_pass( + StubPass( + "denoiser", + inputs={"color": rgb}, + outputs={"denoised": rgb}, + requires=["optix_denoiser"], + ) + ) + compiled = g.compile(capabilities={"optix_denoiser": True}) + names = [p.name for p in compiled.ordered_passes] + assert "denoiser" in names + + def test_manually_disabled_pass(self): + rgb = BufferDesc(channels=3) + g = RenderGraph(width=8, height=8) + g.add_pass(StubPass("A", outputs={"color": rgb}, enabled=False)) + compiled = g.compile() + assert len(compiled.ordered_passes) == 0 + + +# --------------------------------------------------------------------------- +# Fallback wiring +# --------------------------------------------------------------------------- + + +class TestFallbackWiring: + def test_fallback_when_producer_disabled(self): + """When denoiser is skipped, 'denoised' falls back to 'color'.""" + rgb = BufferDesc(channels=3) + g = RenderGraph(width=8, height=8) + g.add_pass(WritingPass("shade", outputs={"color": rgb}, value=42.0)) + g.add_pass( + StubPass( + "denoise", + inputs={"color": rgb}, + outputs={"denoised": rgb}, + requires=["optix_denoiser"], + ) + ) + g.add_pass(StubPass("tonemap", inputs={"denoised": rgb}, outputs={"final": rgb})) + g.set_fallback("denoised", "color") + + compiled = g.compile(capabilities={}) + names = [p.name for p in compiled.ordered_passes] + assert "denoise" not in names + assert "tonemap" in names + + result = compiled.execute( + allocator=lambda shape, dtype: np.zeros(shape, dtype=dtype) + ) + # 'denoised' should map to the same array as 'color' + assert result["denoised"] is result["color"] + + def test_chained_fallbacks(self): + """A -> B -> C fallback chain resolves correctly.""" + rgb = BufferDesc(channels=3) + g = RenderGraph(width=4, height=4) + g.add_pass(WritingPass("src", outputs={"raw": rgb}, value=1.0)) + g.add_pass( + StubPass("mid", inputs={"raw": rgb}, outputs={"enhanced": rgb}, enabled=False) + ) + g.add_pass( + StubPass( + "final_proc", + inputs={"enhanced": rgb}, + outputs={"polished": rgb}, + enabled=False, + ) + ) + g.add_pass(StubPass("out", inputs={"polished": rgb}, outputs={"display": rgb})) + g.set_fallback("polished", "enhanced") + g.set_fallback("enhanced", "raw") + + compiled = g.compile() + result = compiled.execute( + allocator=lambda shape, dtype: np.zeros(shape, dtype=dtype) + ) + assert result["polished"] is result["raw"] + + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- + + +class TestValidation: + def test_missing_input_raises(self): + rgb = BufferDesc(channels=3) + g = RenderGraph(width=8, height=8) + g.add_pass(StubPass("A", inputs={"nonexistent": rgb}, outputs={"out": rgb})) + with pytest.raises(GraphValidationError, match="nonexistent"): + g.compile() + + def test_missing_input_with_fallback_ok(self): + rgb = BufferDesc(channels=3) + g = RenderGraph(width=8, height=8) + g.add_pass(WritingPass("src", outputs={"raw": rgb})) + g.add_pass(StubPass("consumer", inputs={"missing": rgb}, outputs={"out": rgb})) + g.set_fallback("missing", "raw") + # Should not raise + compiled = g.compile() + assert len(compiled.ordered_passes) == 2 + + +# --------------------------------------------------------------------------- +# Buffer lifetime analysis & allocation +# --------------------------------------------------------------------------- + + +class TestLifetimeAnalysis: + def test_non_overlapping_buffers_share_slot(self): + """Two buffers with non-overlapping lifetimes and same desc share a slot.""" + rgb = BufferDesc(channels=3) + g = RenderGraph(width=8, height=8) + g.add_pass(WritingPass("A", outputs={"buf_a": rgb})) + g.add_pass( + WritingPass("B", inputs={"buf_a": rgb}, outputs={"buf_b": rgb}) + ) + # buf_a is last read by B (idx 1), buf_b is first written by B (idx 1). + # buf_a lifetime: [0, 1], buf_b lifetime: [1, end] + # They overlap at index 1, so they should NOT share. + compiled = g.compile() + plan = compiled.allocation_plan + assert plan.slots["buf_a"] != plan.slots["buf_b"] + + def test_truly_non_overlapping_share(self): + """A produces X, B consumes X and produces Y, C consumes Y and produces Z. + X and Z don't overlap -> should share a slot if same desc.""" + rgb = BufferDesc(channels=3) + g = RenderGraph(width=8, height=8) + g.add_pass(WritingPass("A", outputs={"X": rgb})) + g.add_pass(WritingPass("B", inputs={"X": rgb}, outputs={"Y": rgb})) + g.add_pass(WritingPass("C", inputs={"Y": rgb}, outputs={"Z": rgb})) + compiled = g.compile() + plan = compiled.allocation_plan + # X: written at 0, last read at 1 + # Y: written at 1, last read at 2 + # Z: written at 2, last read at end (3) + # X and Z don't overlap -> can share + assert plan.slots["X"] == plan.slots["Z"] + assert plan.slots["X"] != plan.slots["Y"] + + def test_different_descs_dont_share(self): + """Buffers with different descriptors never share a slot.""" + rgb = BufferDesc(dtype="float32", channels=3) + scalar = BufferDesc(dtype="int32", channels=1) + g = RenderGraph(width=8, height=8) + g.add_pass(WritingPass("A", outputs={"X": rgb})) + g.add_pass(WritingPass("B", inputs={"X": rgb}, outputs={"Y": scalar})) + g.add_pass(WritingPass("C", inputs={"Y": scalar}, outputs={"Z": rgb})) + compiled = g.compile() + plan = compiled.allocation_plan + # Even though X and Z don't overlap, Z is rgb and could share with X + # but Y is scalar, so it gets its own slot + assert plan.slots["Y"] != plan.slots["X"] + assert plan.slots["Y"] != plan.slots["Z"] + + def test_allocation_plan_num_slots(self): + rgb = BufferDesc(channels=3) + g = RenderGraph(width=8, height=8) + g.add_pass(WritingPass("A", outputs={"a": rgb})) + g.add_pass(WritingPass("B", outputs={"b": rgb})) + compiled = g.compile() + # Two independent buffers, both alive until end -> 2 slots + assert compiled.allocation_plan.num_slots == 2 + + +# --------------------------------------------------------------------------- +# Execution +# --------------------------------------------------------------------------- + + +class TestExecution: + def test_simple_pipeline(self): + rgb = BufferDesc(channels=3) + g = RenderGraph(width=4, height=4) + g.add_pass(WritingPass("producer", outputs={"color": rgb}, value=7.0)) + g.add_pass( + SummingPass("consumer", inputs={"color": rgb}, outputs={"result": rgb}) + ) + + compiled = g.compile() + result = compiled.execute( + allocator=lambda shape, dtype: np.zeros(shape, dtype=dtype) + ) + np.testing.assert_allclose(result["result"], 7.0) + + def test_external_buffers(self): + """External buffers are available to passes without allocation.""" + rgb = BufferDesc(channels=3) + g = RenderGraph(width=4, height=4) + + class ReaderPass(StubPass): + def execute(self, buffers): + super().execute(buffers) + for buf_name in self.outputs: + buffers[buf_name][:] = buffers["scene_handle"] * 2 + + g.add_pass(ReaderPass("shade", inputs={"scene_handle": rgb}, outputs={"color": rgb})) + + # Inject an external buffer (simulating an OptiX scene handle) + ext = np.full((4, 4, 3), 3.0, dtype=np.float32) + compiled = g.compile(validate=False) # scene_handle has no producer + result = compiled.execute( + external_buffers={"scene_handle": ext}, + allocator=lambda shape, dtype: np.zeros(shape, dtype=dtype), + ) + np.testing.assert_allclose(result["color"], 6.0) + + def test_all_passes_execute(self): + rgb = BufferDesc(channels=3) + g = RenderGraph(width=4, height=4) + p1 = StubPass("A", outputs={"x": rgb}) + p2 = StubPass("B", inputs={"x": rgb}, outputs={"y": rgb}) + g.add_pass(p1) + g.add_pass(p2) + compiled = g.compile() + compiled.execute( + allocator=lambda shape, dtype: np.zeros(shape, dtype=dtype) + ) + assert p1.executed + assert p2.executed + + +# --------------------------------------------------------------------------- +# RenderPass interface +# --------------------------------------------------------------------------- + + +class TestRenderPassInterface: + def test_repr(self): + p = StubPass("test", inputs={"a": BufferDesc()}, outputs={"b": BufferDesc()}) + r = repr(p) + assert "test" in r + assert "inputs" in r + + def test_setup_and_teardown_called(self): + class TrackedPass(StubPass): + def __init__(self, *a, **kw): + super().__init__(*a, **kw) + self.setup_called = False + self.teardown_called = False + + def setup(self, graph): + self.setup_called = True + + def teardown(self): + self.teardown_called = True + + rgb = BufferDesc(channels=3) + g = RenderGraph(width=4, height=4) + p = TrackedPass("t", outputs={"out": rgb}) + g.add_pass(p) + compiled = g.compile() + assert p.setup_called + # teardown is user-called + p.teardown() + assert p.teardown_called + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + def test_empty_graph(self): + g = RenderGraph(width=8, height=8) + compiled = g.compile() + assert compiled.ordered_passes == [] + result = compiled.execute( + allocator=lambda shape, dtype: np.zeros(shape, dtype=dtype) + ) + assert result == {} + + def test_single_pass_no_deps(self): + rgb = BufferDesc(channels=3) + g = RenderGraph(width=4, height=4) + g.add_pass(WritingPass("only", outputs={"out": rgb}, value=99.0)) + compiled = g.compile() + result = compiled.execute( + allocator=lambda shape, dtype: np.zeros(shape, dtype=dtype) + ) + np.testing.assert_allclose(result["out"], 99.0) + + def test_pass_self_loop_is_not_cycle(self): + """A pass that reads and writes the same buffer (in-place) is fine.""" + rgb = BufferDesc(channels=3) + g = RenderGraph(width=4, height=4) + g.add_pass(WritingPass("src", outputs={"buf": rgb})) + g.add_pass(StubPass("inplace", inputs={"buf": rgb}, outputs={"buf": rgb})) + compiled = g.compile() + assert len(compiled.ordered_passes) == 2 + + def test_wide_fan_out(self): + """One producer, many consumers.""" + rgb = BufferDesc(channels=3) + g = RenderGraph(width=4, height=4) + g.add_pass(WritingPass("src", outputs={"shared": rgb}, value=1.0)) + for i in range(10): + g.add_pass( + SummingPass(f"consumer_{i}", inputs={"shared": rgb}, outputs={f"out_{i}": rgb}) + ) + compiled = g.compile() + result = compiled.execute( + allocator=lambda shape, dtype: np.zeros(shape, dtype=dtype) + ) + for i in range(10): + np.testing.assert_allclose(result[f"out_{i}"], 1.0) + + def test_allocation_plan_empty_graph(self): + plan = AllocationPlan() + assert plan.num_slots == 0 + + +# --------------------------------------------------------------------------- +# Integration-style: realistic pipeline shape +# --------------------------------------------------------------------------- + + +class TestRealisticPipeline: + def test_full_pipeline_shape(self): + """Mirror the proposed pass structure from issue #73.""" + rgb = BufferDesc(dtype="float32", channels=3, per_pixel=True) + scalar = BufferDesc(dtype="float32", channels=1, per_pixel=True) + vec4 = BufferDesc(dtype="float32", channels=4, per_pixel=False) + mat_id = BufferDesc(dtype="int32", channels=1, per_pixel=False) + + g = RenderGraph(width=64, height=64) + + g.add_pass( + WritingPass( + "gbuffer", + outputs={ + "albedo": rgb, + "normal": rgb, + "depth": scalar, + "position": rgb, + "material_id": mat_id, + }, + value=0.5, + ) + ) + g.add_pass( + WritingPass( + "shadow", + inputs={"position": rgb, "normal": rgb}, + outputs={"shadow_mask": scalar}, + value=1.0, + ) + ) + g.add_pass( + WritingPass( + "ao", + inputs={"position": rgb, "normal": rgb, "depth": scalar}, + outputs={"ao_map": scalar}, + value=0.9, + ) + ) + g.add_pass( + WritingPass( + "gi", + inputs={ + "position": rgb, + "normal": rgb, + "albedo": rgb, + "shadow_mask": scalar, + }, + outputs={"indirect_light": rgb}, + value=0.2, + ) + ) + g.add_pass( + WritingPass( + "denoise", + inputs={"color": rgb, "albedo": rgb, "normal": rgb}, + outputs={"denoised_color": rgb}, + requires=["optix_denoiser"], + value=0.6, + ) + ) + g.add_pass( + WritingPass( + "tonemap", + inputs={"denoised_color": rgb}, + outputs={"ldr_color": rgb}, + value=0.8, + ) + ) + g.add_pass( + StubPass( + "composite", + inputs={"ldr_color": rgb}, + outputs={"final": rgb}, + ) + ) + + # The shade pass produces 'color' consumed by denoise + g.add_pass( + WritingPass( + "shade", + inputs={ + "albedo": rgb, + "shadow_mask": scalar, + "ao_map": scalar, + "indirect_light": rgb, + }, + outputs={"color": rgb}, + value=0.7, + ) + ) + + g.set_fallback("denoised_color", "color") + + # Compile without denoiser capability + compiled = g.compile(capabilities={}) + names = [p.name for p in compiled.ordered_passes] + assert "denoise" not in names + assert "gbuffer" in names + assert "tonemap" in names + + # gbuffer must come before shadow, ao + assert names.index("gbuffer") < names.index("shadow") + assert names.index("gbuffer") < names.index("ao") + # shade must come after shadow, ao, gi + assert names.index("shade") > names.index("shadow") + assert names.index("shade") > names.index("ao") + assert names.index("shade") > names.index("gi") + # tonemap after shade (via fallback denoised_color -> color) + assert names.index("tonemap") > names.index("shade") + + result = compiled.execute( + allocator=lambda shape, dtype: np.zeros(shape, dtype=dtype) + ) + # denoised_color falls back to color (which shade wrote as 0.7) + assert result["denoised_color"] is result["color"] + np.testing.assert_allclose(result["color"], 0.7) + + # Now compile WITH denoiser + compiled2 = g.compile(capabilities={"optix_denoiser": True}) + names2 = [p.name for p in compiled2.ordered_passes] + assert "denoise" in names2 + assert names2.index("shade") < names2.index("denoise") + assert names2.index("denoise") < names2.index("tonemap") From d42427203eca423d6d81457f707542a72f62b7da Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Thu, 5 Mar 2026 08:52:28 -0800 Subject: [PATCH 2/3] Add render graph documentation (#73) API reference for BufferDesc, RenderPass, RenderGraph, CompiledGraph. User guide section with worked example showing capability gating and fallbacks. --- docs/api-reference.md | 86 +++++++++++++++++++++++++++++++++++++++++++ docs/user-guide.md | 57 ++++++++++++++++++++++++++++ 2 files changed, 143 insertions(+) diff --git a/docs/api-reference.md b/docs/api-reference.md index a88dbc9..963c4e6 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -517,6 +517,92 @@ Find which zarr chunks overlap a pixel-coordinate window. --- +## Render Graph + +A configurable DAG of render passes. Declare inputs/outputs per pass, and the graph resolves execution order, manages GPU buffers, and gates passes on hardware capabilities. + +```python +from rtxpy import RenderGraph, RenderPass, BufferDesc +``` + +### `BufferDesc(dtype='float32', channels=3, per_pixel=True)` + +Describes a GPU buffer's shape and data type. + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `dtype` | str | `'float32'` | NumPy-compatible dtype | +| `channels` | int | `3` | Channels per element | +| `per_pixel` | bool | `True` | `True` = `(H, W, C)` shape; `False` = `(N, C)` flat | + +#### `shape(width, height)` + +**Returns:** `tuple[int, ...]` — concrete shape for the given resolution + +### `RenderPass` (abstract base class) + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `name` | str | required | Unique pass name | +| `inputs` | dict[str, BufferDesc] | `{}` | Buffers this pass reads | +| `outputs` | dict[str, BufferDesc] | `{}` | Buffers this pass writes | +| `enabled` | bool | `True` | Set `False` to skip | +| `requires` | list[str] | `[]` | Capability keys that must be truthy | + +Override `execute(buffers)` with your pass logic. Optionally override `setup(graph)` and `teardown()`. + +### `RenderGraph(width=1920, height=1080)` + +#### `add_pass(pass_)` + +Add a render pass. Raises `ValueError` on duplicate names. + +#### `remove_pass(name)` + +Remove a pass by name. Raises `KeyError` if not found. + +#### `get_pass(name)` + +**Returns:** `RenderPass` + +#### `set_fallback(buffer, fallback)` + +If `buffer`'s producer is disabled, downstream passes read `fallback` instead. + +```python +graph.set_fallback("denoised_color", "color") +``` + +#### `compile(capabilities=None, validate=True)` + +Compile the graph: capability-gate passes, topological sort, buffer lifetime analysis. + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `capabilities` | dict | `None` | From `get_capabilities()` | +| `validate` | bool | `True` | Raise `GraphValidationError` on problems | + +**Returns:** `CompiledGraph` + +### `CompiledGraph` + +#### `execute(external_buffers=None, allocator=None)` + +Run all passes in dependency order. + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `external_buffers` | dict | `None` | Pre-allocated buffers to inject | +| `allocator` | callable | `None` | `(shape, dtype) -> array`; defaults to `cupy.zeros` | + +**Returns:** `dict[str, array]` — all buffers after execution + +### `GraphValidationError` + +Raised on cycle detection, missing inputs, or other graph errors. + +--- + ### Device Utilities #### `get_device_count()` diff --git a/docs/user-guide.md b/docs/user-guide.md index 4159e34..46e738f 100644 --- a/docs/user-guide.md +++ b/docs/user-guide.md @@ -535,6 +535,63 @@ verts, indices = dem.rtx.triangulate() write_stl('terrain.stl', verts, indices) ``` +## Render Graph + +The render graph lets you define a pipeline of render passes as a DAG. Instead of editing one monolithic render function, you declare what each pass reads and writes. The graph handles execution order, buffer allocation, and capability gating. + +```python +from rtxpy import RenderGraph, RenderPass, BufferDesc + +rgb = BufferDesc(dtype="float32", channels=3, per_pixel=True) +scalar = BufferDesc(dtype="float32", channels=1, per_pixel=True) + +class MyShadePass(RenderPass): + def __init__(self): + super().__init__("shade", outputs={"color": rgb}) + + def execute(self, buffers): + buffers["color"][:] = 0.5 # your shading logic here + +class MyDenoisePass(RenderPass): + def __init__(self): + super().__init__( + "denoise", + inputs={"color": rgb}, + outputs={"denoised_color": rgb}, + requires=["optix_denoiser"], + ) + + def execute(self, buffers): + buffers["denoised_color"][:] = buffers["color"] # denoiser call here + +class MyTonemapPass(RenderPass): + def __init__(self): + super().__init__( + "tonemap", + inputs={"denoised_color": rgb}, + outputs={"final": rgb}, + ) + + def execute(self, buffers): + buffers["final"][:] = buffers["denoised_color"] ** (1.0 / 2.2) + +graph = RenderGraph(width=1920, height=1080) +graph.add_pass(MyShadePass()) +graph.add_pass(MyDenoisePass()) +graph.add_pass(MyTonemapPass()) + +# If denoiser unavailable, tonemap reads 'color' directly +graph.set_fallback("denoised_color", "color") + +compiled = graph.compile(capabilities={"optix_denoiser": False}) +result = compiled.execute() +# result["final"] contains the output image +``` + +The graph skips passes whose `requires` capabilities aren't present, wiring fallbacks so downstream passes still work. Buffer lifetime analysis reuses GPU memory when buffers don't overlap in time. + +See the [API Reference](api-reference.md#render-graph) for the full interface. + ## Performance Tips - **Subsample large DEMs**: `dem[::2, ::2]` or `explore(subsample=4)` — 4x subsample is 16x less geometry From 0781b3e9b03f2797f72b9faeeef176bd4eef5253 Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Thu, 5 Mar 2026 08:53:59 -0800 Subject: [PATCH 3/3] Add render graph example script (#73) CPU-only demo showing custom passes, capability gating, fallback wiring, and the full compile-then-execute workflow. --- examples/render_graph_demo.py | 196 ++++++++++++++++++++++++++++++++++ 1 file changed, 196 insertions(+) create mode 100644 examples/render_graph_demo.py diff --git a/examples/render_graph_demo.py b/examples/render_graph_demo.py new file mode 100644 index 0000000..aa52e7c --- /dev/null +++ b/examples/render_graph_demo.py @@ -0,0 +1,196 @@ +"""Render graph demo — build a multi-pass pipeline with capability gating. + +Shows how to define custom render passes, wire them into a graph, and let the +graph handle execution order and fallback wiring when optional passes are +unavailable. + +This example runs on CPU (numpy) and doesn't need a GPU. +""" + +import numpy as np + +from rtxpy import BufferDesc, RenderGraph, RenderPass + +# Buffer descriptors for the pipeline +RGB = BufferDesc(dtype="float32", channels=3, per_pixel=True) +SCALAR = BufferDesc(dtype="float32", channels=1, per_pixel=True) + + +# --- Pass definitions -------------------------------------------------------- + + +class GBufferPass(RenderPass): + """Simulate a GBuffer pass that produces albedo, normals, and depth.""" + + def __init__(self): + super().__init__( + "gbuffer", + outputs={"albedo": RGB, "normal": RGB, "depth": SCALAR}, + ) + + def execute(self, buffers): + h, w, _ = buffers["albedo"].shape + # Checkerboard albedo + yy, xx = np.mgrid[:h, :w] + checker = ((xx // 8 + yy // 8) % 2).astype(np.float32) + buffers["albedo"][:, :, 0] = 0.2 + 0.6 * checker + buffers["albedo"][:, :, 1] = 0.3 + 0.3 * checker + buffers["albedo"][:, :, 2] = 0.1 + 0.2 * (1 - checker) + + # Upward-facing normals + buffers["normal"][:] = [0.0, 0.0, 1.0] + + # Linear depth gradient + buffers["depth"][:, :] = np.linspace(0.0, 1.0, w, dtype=np.float32) + + +class ShadowPass(RenderPass): + """Compute a simple shadow mask from depth.""" + + def __init__(self): + super().__init__( + "shadow", + inputs={"depth": SCALAR}, + outputs={"shadow_mask": SCALAR}, + ) + + def execute(self, buffers): + # Fake shadow: darker where depth > 0.5 + buffers["shadow_mask"][:] = np.where( + buffers["depth"] > 0.5, 0.4, 1.0 + ).astype(np.float32) + + +class AOPass(RenderPass): + """Fake ambient occlusion from depth edges.""" + + def __init__(self): + super().__init__( + "ao", + inputs={"depth": SCALAR}, + outputs={"ao_map": SCALAR}, + ) + + def execute(self, buffers): + depth = buffers["depth"] + # Approximate AO by depth variance in a 3x3 window + padded = np.pad(depth, ((1, 1), (1, 1)), mode="edge") + ao = np.ones_like(depth) + for dy in (-1, 0, 1): + for dx in (-1, 0, 1): + ao -= 0.02 * np.abs( + padded[1 + dy : depth.shape[0] + 1 + dy, + 1 + dx : depth.shape[1] + 1 + dx] + - depth + ) + buffers["ao_map"][:] = np.clip(ao, 0.3, 1.0) + + +class ShadePass(RenderPass): + """Combine albedo, shadow, and AO into a lit color buffer.""" + + def __init__(self): + super().__init__( + "shade", + inputs={"albedo": RGB, "shadow_mask": SCALAR, "ao_map": SCALAR}, + outputs={"color": RGB}, + ) + + def execute(self, buffers): + albedo = buffers["albedo"] + shadow = buffers["shadow_mask"][:, :, np.newaxis] + ao = buffers["ao_map"][:, :, np.newaxis] + buffers["color"][:] = albedo * shadow * ao + + +class DenoisePass(RenderPass): + """Placeholder denoiser — requires 'optix_denoiser' capability.""" + + def __init__(self): + super().__init__( + "denoise", + inputs={"color": RGB, "albedo": RGB, "normal": RGB}, + outputs={"denoised_color": RGB}, + requires=["optix_denoiser"], + ) + + def execute(self, buffers): + # Real implementation would call OptiX denoiser + buffers["denoised_color"][:] = buffers["color"] + + +class TonemapPass(RenderPass): + """Simple Reinhard tone mapping.""" + + def __init__(self): + super().__init__( + "tonemap", + inputs={"denoised_color": RGB}, + outputs={"ldr_color": RGB}, + ) + + def execute(self, buffers): + hdr = buffers["denoised_color"] + buffers["ldr_color"][:] = hdr / (1.0 + hdr) + + +# --- Build and run the graph ------------------------------------------------ + + +def main(): + width, height = 128, 96 + + graph = RenderGraph(width=width, height=height) + graph.add_pass(GBufferPass()) + graph.add_pass(ShadowPass()) + graph.add_pass(AOPass()) + graph.add_pass(ShadePass()) + graph.add_pass(DenoisePass()) + graph.add_pass(TonemapPass()) + + # If denoiser is unavailable, tonemap reads 'color' directly + graph.set_fallback("denoised_color", "color") + + # --- Run without denoiser --- + print("Compiling graph WITHOUT denoiser capability...") + compiled = graph.compile(capabilities={}) + print(f" Active passes: {[p.name for p in compiled.ordered_passes]}") + print(f" Buffer pool slots: {compiled.allocation_plan.num_slots}") + + result = compiled.execute( + allocator=lambda shape, dtype: np.zeros(shape, dtype=dtype) + ) + ldr = result["ldr_color"] + print(f" Output shape: {ldr.shape}, range: [{ldr.min():.3f}, {ldr.max():.3f}]") + + # --- Run with denoiser --- + print("\nCompiling graph WITH denoiser capability...") + compiled2 = graph.compile(capabilities={"optix_denoiser": True}) + print(f" Active passes: {[p.name for p in compiled2.ordered_passes]}") + + result2 = compiled2.execute( + allocator=lambda shape, dtype: np.zeros(shape, dtype=dtype) + ) + ldr2 = result2["ldr_color"] + print(f" Output shape: {ldr2.shape}, range: [{ldr2.min():.3f}, {ldr2.max():.3f}]") + + # Save to PNG if matplotlib available + try: + import matplotlib.pyplot as plt + + fig, axes = plt.subplots(1, 2, figsize=(10, 4)) + axes[0].imshow(np.clip(result["ldr_color"], 0, 1)) + axes[0].set_title("Without denoiser") + axes[0].axis("off") + axes[1].imshow(np.clip(result2["ldr_color"], 0, 1)) + axes[1].set_title("With denoiser") + axes[1].axis("off") + plt.tight_layout() + plt.savefig("render_graph_demo.png", dpi=150) + print("\nSaved render_graph_demo.png") + except ImportError: + print("\nmatplotlib not available, skipping image save") + + +if __name__ == "__main__": + main()