From 263d7f800cff747bbbb63906beb7da07c7bf5b07 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Fri, 26 Jun 2026 20:37:56 +0000 Subject: [PATCH 1/3] feat(bigframes): Experimental local sample cache --- .../bigframes/_config/compute_options.py | 13 + .../bigframes/session/bq_caching_executor.py | 52 ++++ .../bigframes/session/execution_spec.py | 4 + .../bigframes/bigframes/session/peek_cache.py | 96 +++++++ .../tests/unit/session/test_peek_cache.py | 240 ++++++++++++++++++ 5 files changed, 405 insertions(+) create mode 100644 packages/bigframes/bigframes/session/peek_cache.py create mode 100644 packages/bigframes/tests/unit/session/test_peek_cache.py diff --git a/packages/bigframes/bigframes/_config/compute_options.py b/packages/bigframes/bigframes/_config/compute_options.py index 027566ae075f..66b500e1d366 100644 --- a/packages/bigframes/bigframes/_config/compute_options.py +++ b/packages/bigframes/bigframes/_config/compute_options.py @@ -168,6 +168,19 @@ class ComputeOptions: int | None: Number of rows, if set. """ + enable_peek_cache: bool = False + """ + If enabled, peeking at a relation will pull a larger local sample (e.g. 10k rows) + and cache it locally. Subsequent compatible operations on the relation will run + locally on the cached sample, enabling fast interactive iteration. + """ + + peek_cache_size: int = 10000 + """ + The size of the local sample to pull and cache when peeking at a relation. + Defaults to 10000. + """ + semantic_ops_confirmation_threshold: Optional[int] = 0 """ Deprecated. 
diff --git a/packages/bigframes/bigframes/session/bq_caching_executor.py b/packages/bigframes/bigframes/session/bq_caching_executor.py index dede318d8132..973a6be37239 100644 --- a/packages/bigframes/bigframes/session/bq_caching_executor.py +++ b/packages/bigframes/bigframes/session/bq_caching_executor.py @@ -167,6 +167,8 @@ def __init__( labels=dict(labels), ) self._function_manager = function_manager + from bigframes.session.peek_cache import PeekCache + self._peek_cache = PeekCache() def to_sql( self, @@ -209,6 +211,56 @@ async def _execute_async( execution_spec: ex_spec.ExecutionSpec, ) -> executor.ExecuteResult: await self._publisher.publish_async(bigframes.core.events.ExecutionStarted()) + + enable_peek_cache = ( + execution_spec.bigquery_config.enable_peek_cache + if execution_spec.bigquery_config + else False + ) + + if execution_spec.peek is not None and enable_peek_cache: + from bigframes.session.peek_cache import substitute_peek_cached_subplans + rewritten_node = substitute_peek_cached_subplans(array_value.node, self._peek_cache) + if rewritten_node != array_value.node: + rewritten_array_value = bigframes.core.ArrayValue(rewritten_node) + maybe_result = await self._try_execute_semi_executors( + rewritten_array_value, execution_spec + ) + if maybe_result is not None: + return maybe_result + + sample_size = ( + execution_spec.bigquery_config.peek_cache_size + if execution_spec.bigquery_config + else 10000 + ) + actual_sample_size = max(execution_spec.peek, sample_size) + cache_execution_spec = dataclasses.replace(execution_spec, peek=actual_sample_size) + + bq_result = await self._execute_bigquery( + array_value, + cache_execution_spec, + ) + + arrow_table = await asyncio.to_thread(bq_result.batches().to_arrow_table) + managed_table = local_data.ManagedArrowTable.from_pyarrow(arrow_table, bq_result.schema) + self._peek_cache.put(array_value.node, managed_table) + + sliced_table = arrow_table.slice(0, execution_spec.peek) + result: executor.ExecuteResult = 
executor.LocalExecuteResult( + sliced_table, + bq_result.schema, + execution_metadata=bq_result.execution_metadata, + ) + + await self._publisher.publish_async( + bigframes.core.events.EventEnvelope( + event=bigframes.core.events.ExecutionFinished(result=result), + cell_execution_count=execution_spec.cell_execution_count, + ) + ) + return result + maybe_result = await self._try_execute_semi_executors( array_value, execution_spec ) diff --git a/packages/bigframes/bigframes/session/execution_spec.py b/packages/bigframes/bigframes/session/execution_spec.py index 89de6eec9021..00e12b4ad84e 100644 --- a/packages/bigframes/bigframes/session/execution_spec.py +++ b/packages/bigframes/bigframes/session/execution_spec.py @@ -27,6 +27,8 @@ class BqComputeOptions: enable_multi_query_execution: bool = True maximum_bytes_billed: Optional[int] = None extra_query_labels: tuple[tuple[str, str], ...] = () + enable_peek_cache: bool = False + peek_cache_size: int = 10000 @classmethod def from_compute_options(cls, compute_options: ComputeOptions) -> BqComputeOptions: @@ -34,6 +36,8 @@ def from_compute_options(cls, compute_options: ComputeOptions) -> BqComputeOptio enable_multi_query_execution=compute_options.enable_multi_query_execution, maximum_bytes_billed=compute_options.maximum_bytes_billed, extra_query_labels=tuple(compute_options.extra_query_labels.items()), + enable_peek_cache=compute_options.enable_peek_cache, + peek_cache_size=compute_options.peek_cache_size, ) def push_labels(self, labels: Mapping[str, str]) -> BqComputeOptions: diff --git a/packages/bigframes/bigframes/session/peek_cache.py b/packages/bigframes/bigframes/session/peek_cache.py new file mode 100644 index 000000000000..f3f7ca1157bf --- /dev/null +++ b/packages/bigframes/bigframes/session/peek_cache.py @@ -0,0 +1,96 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections import OrderedDict +import threading +from typing import Optional + +from bigframes.core import local_data, nodes + + +class PeekCache: + """ + Thread-safe LRU cache for storing local samples of query relations. + This enables fast iteration on subsequent compatible operations. + """ + + def __init__(self, capacity: int = 100): + self.capacity = capacity + self._cache: OrderedDict[nodes.BigFrameNode, local_data.ManagedArrowTable] = OrderedDict() + self._lock = threading.Lock() + + def get(self, key: nodes.BigFrameNode) -> Optional[local_data.ManagedArrowTable]: + with self._lock: + if key not in self._cache: + return None + # Move to end (most recently used) + self._cache.move_to_end(key) + return self._cache[key] + + def put(self, key: nodes.BigFrameNode, value: local_data.ManagedArrowTable) -> None: + with self._lock: + if key in self._cache: + self._cache.move_to_end(key) + self._cache[key] = value + if len(self._cache) > self.capacity: + self._cache.popitem(last=False) + + def clear(self) -> None: + with self._lock: + self._cache.clear() + + +def substitute_peek_cached_subplans( + root: nodes.BigFrameNode, + peek_cache: PeekCache, +) -> nodes.BigFrameNode: + """ + Recursively replaces subplans in the tree that have a cached local sample + in the peek cache with a ReadLocalNode, provided that all ancestors + of the subplan are compatible with running on a sample. + """ + # Intermediate nodes that preserve the semantic validity of a sample. 
+ # WindowOpNode, AggregateNode, OrderByNode, JoinNode, etc. are excluded + # because evaluating them on a sample breaks semantic contracts. + _COMPATIBLE_ANCESTOR_CLASSES = ( + nodes.SelectionNode, + nodes.ProjectionNode, + nodes.FilterNode, + nodes.PromoteOffsetsNode, + ) + + def traverse(node: nodes.BigFrameNode, ancestors_compatible: bool) -> nodes.BigFrameNode: + if ancestors_compatible: + cached_sample = peek_cache.get(node) + if cached_sample is not None: + # Replace the node with a ReadLocalNode containing the cached sample + scan_list = nodes.ScanList( + tuple(nodes.ScanItem(field.id, field.id.sql) for field in node.fields) + ) + session = node.session if node.session is not None else root.session + return nodes.ReadLocalNode( + local_data_source=cached_sample, + scan_list=scan_list, + session=session, + ) + + # If we didn't replace, recursively transform children + is_current_compatible = isinstance(node, _COMPATIBLE_ANCESTOR_CLASSES) + next_ancestors_compatible = ancestors_compatible and is_current_compatible + + return node.transform_children(lambda child: traverse(child, next_ancestors_compatible)) + + return traverse(root, True) diff --git a/packages/bigframes/tests/unit/session/test_peek_cache.py b/packages/bigframes/tests/unit/session/test_peek_cache.py new file mode 100644 index 000000000000..a92cd03b198c --- /dev/null +++ b/packages/bigframes/tests/unit/session/test_peek_cache.py @@ -0,0 +1,240 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from unittest import mock + +import google.cloud.bigquery +import pyarrow as pa +import pytest + +import bigframes +from bigframes.core import identifiers, local_data, nodes +from bigframes.session import bq_caching_executor, execution_spec, executor +from bigframes.session.peek_cache import PeekCache, substitute_peek_cached_subplans +from bigframes.testing import mocks + + +def test_peek_cache_lru(): + cache = PeekCache(capacity=2) + session = mocks.create_bigquery_session() + + # Create some mock nodes and data sources + table1 = pa.Table.from_pydict({"a": [1, 2]}) + table2 = pa.Table.from_pydict({"b": [3, 4]}) + table3 = pa.Table.from_pydict({"c": [5, 6]}) + + ds1 = local_data.ManagedArrowTable.from_pyarrow(table1) + ds2 = local_data.ManagedArrowTable.from_pyarrow(table2) + ds3 = local_data.ManagedArrowTable.from_pyarrow(table3) + + node1 = nodes.ReadLocalNode(ds1, nodes.ScanList(()), session) + node2 = nodes.ReadLocalNode(ds2, nodes.ScanList(()), session) + node3 = nodes.ReadLocalNode(ds3, nodes.ScanList(()), session) + + cache.put(node1, ds1) + cache.put(node2, ds2) + + # Access node1 to make it most recently used, leaving node2 as least recently used (LRU) + assert cache.get(node1) == ds1 + + # Put node3, which should evict node2 + cache.put(node3, ds3) + + assert cache.get(node2) is None + assert cache.get(node1) == ds1 + assert cache.get(node3) == ds3 + + +def test_substitute_peek_cached_subplans(): + session = mocks.create_bigquery_session() + table = pa.Table.from_pydict({"a": [1, 2]}) + ds = local_data.ManagedArrowTable.from_pyarrow(table) + + # Create a simple leaf node + leaf = nodes.ReadLocalNode( + local_data_source=ds, + scan_list=nodes.ScanList((nodes.ScanItem(identifiers.ColumnId("col_a"), "a"),)), + session=session, + ) + + # Cache the leaf node + cache = PeekCache() + cached_table = 
pa.Table.from_pydict({"col_a": [100, 200]}) + cached_ds = local_data.ManagedArrowTable.from_pyarrow(cached_table) + cache.put(leaf, cached_ds) + + # Now perform the tree substitution + rewritten = substitute_peek_cached_subplans(leaf, cache) + + # The leaf should be replaced by a new ReadLocalNode containing cached_ds + assert isinstance(rewritten, nodes.ReadLocalNode) + assert rewritten.local_data_source == cached_ds + assert rewritten.session == session + assert len(rewritten.scan_list.items) == 1 + assert rewritten.scan_list.items[0].id == identifiers.ColumnId("col_a") + assert rewritten.scan_list.items[0].source_id == "col_a" + + +def test_executor_peek_cache_integration(): + # Mock all arguments to BigQueryCachingExecutor + bqclient = mock.create_autospec(google.cloud.bigquery.Client, instance=True) + bqclient.project = "test-project" + storage_manager = mock.Mock() + bqstoragereadclient = mock.Mock() + loader = mock.Mock() + publisher = mock.AsyncMock() + function_manager = mock.Mock() + + executor_obj = bq_caching_executor.BigQueryCachingExecutor( + bqclient=bqclient, + storage_manager=storage_manager, + bqstoragereadclient=bqstoragereadclient, + loader=loader, + publisher=publisher, + function_manager=function_manager, + ) + + table = pa.Table.from_pydict({"col": [1, 2, 3, 4, 5]}) + ds = local_data.ManagedArrowTable.from_pyarrow(table) + session = mocks.create_bigquery_session() + + node = nodes.ReadLocalNode( + local_data_source=ds, + scan_list=nodes.ScanList((nodes.ScanItem(identifiers.ColumnId("col"), "col"),)), + session=session, + ) + arr_value = bigframes.core.ArrayValue(node) + + # Mock _execute_bigquery of the executor to return a mock 3-row table + mock_bq_table = pa.Table.from_pydict({"col": [10, 20, 30]}) + mock_bq_result = executor.LocalExecuteResult(mock_bq_table, arr_value.schema) + + execute_bq_mock = mock.AsyncMock(return_value=mock_bq_result) + executor_obj._execute_bigquery = execute_bq_mock + + # Enable peek cache options + 
compute_options = bigframes.options.compute + compute_options.enable_peek_cache = True + compute_options.peek_cache_size = 3 + + # Call execute with peek=1 (cache miss path) + spec = execution_spec.ExecutionSpec(peek=1).with_compute_options(compute_options) + result = asyncio.run(executor_obj._execute_async(arr_value, spec)) + + # Verify BQ was called with peek=3 (cache size) + assert execute_bq_mock.call_count == 1 + called_spec = execute_bq_mock.call_args[0][1] + assert called_spec.peek == 3 + + # Verify returned result has exactly 1 row + result_table = pa.Table.from_batches(result.batches().arrow_batches) + assert result_table.num_rows == 1 + assert result_table["col"].to_pylist() == [10] + + # Verify peek cache has been populated with the 3-row table + cached_entry = executor_obj._peek_cache.get(node) + assert cached_entry is not None + assert cached_entry.to_pyarrow_table()["col"].to_pylist() == [10, 20, 30] + + # Call execute again with peek=2 (cache hit path) + execute_bq_mock.reset_mock() + spec2 = execution_spec.ExecutionSpec(peek=2).with_compute_options(compute_options) + result2 = asyncio.run(executor_obj._execute_async(arr_value, spec2)) + + # Verify BQ was NOT called + assert execute_bq_mock.call_count == 0 + + # Verify returned result has exactly 2 rows + result_table2 = pa.Table.from_batches(result2.batches().arrow_batches) + assert result_table2.num_rows == 2 + assert result_table2["col"].to_pylist() == [10, 20] + + +def test_peek_cache_thread_safety(): + import threading + + cache = PeekCache(capacity=100) + session = mocks.create_bigquery_session() + + # Create dummy nodes and data sources + num_items = 50 + num_threads = 10 + nodes_list = [] + for i in range(num_items): + table = pa.Table.from_pydict({"col": [i]}) + ds = local_data.ManagedArrowTable.from_pyarrow(table) + node = nodes.ReadLocalNode(ds, nodes.ScanList(()), session) + nodes_list.append((node, ds)) + + def worker(worker_id): + for i in range(100): + node, ds = nodes_list[(worker_id 
+ i) % num_items] + cache.put(node, ds) + cache.get(node) + + threads = [] + for i in range(num_threads): + t = threading.Thread(target=worker, args=(i,)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + # The cache should be in a consistent state and not exceed capacity + assert len(cache._cache) <= 100 + + +def test_substitute_peek_cached_subplans_incompatible_ancestors(): + session = mocks.create_bigquery_session() + table = pa.Table.from_pydict({"a": [1, 2]}) + ds = local_data.ManagedArrowTable.from_pyarrow(table) + + # Leaf node (cached) + leaf = nodes.ReadLocalNode( + local_data_source=ds, + scan_list=nodes.ScanList((nodes.ScanItem(identifiers.ColumnId("col_a"), "a"),)), + session=session, + ) + + cache = PeekCache() + cached_table = pa.Table.from_pydict({"col_a": [100, 200]}) + cached_ds = local_data.ManagedArrowTable.from_pyarrow(cached_table) + cache.put(leaf, cached_ds) + + # Scenario A: Path has only compatible nodes: FilterNode -> Leaf + # FilterNode is a compatible ancestor. + plan_compatible = nodes.FilterNode( + child=leaf, + predicate=bigframes.core.expression.ScalarConstantExpression(True), # Dummy expression + ) + + rewritten_compatible = substitute_peek_cached_subplans(plan_compatible, cache) + # The leaf child of FilterNode should be replaced by ReadLocalNode with cached_ds + assert isinstance(rewritten_compatible, nodes.FilterNode) + assert isinstance(rewritten_compatible.child, nodes.ReadLocalNode) + assert rewritten_compatible.child.local_data_source == cached_ds + + # Scenario B: Path has an incompatible node: ReversedNode -> Leaf + # ReversedNode is an incompatible ancestor. 
+ plan_incompatible = nodes.ReversedNode(child=leaf) + + rewritten_incompatible = substitute_peek_cached_subplans(plan_incompatible, cache) + # The leaf child should NOT be replaced by ReadLocalNode + assert isinstance(rewritten_incompatible, nodes.ReversedNode) + assert rewritten_incompatible.child == leaf + assert rewritten_incompatible.child.local_data_source == ds From 2703ffe8f8c384a8896e859bc15f3a750e57d849 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Fri, 26 Jun 2026 20:46:02 +0000 Subject: [PATCH 2/3] fixes --- .../bigframes/session/bq_caching_executor.py | 16 ++++++-- .../bigframes/bigframes/session/peek_cache.py | 27 +++++++++---- .../tests/unit/session/test_peek_cache.py | 40 ++++++++++++++++--- 3 files changed, 68 insertions(+), 15 deletions(-) diff --git a/packages/bigframes/bigframes/session/bq_caching_executor.py b/packages/bigframes/bigframes/session/bq_caching_executor.py index 973a6be37239..247fb35b910c 100644 --- a/packages/bigframes/bigframes/session/bq_caching_executor.py +++ b/packages/bigframes/bigframes/session/bq_caching_executor.py @@ -168,6 +168,7 @@ def __init__( ) self._function_manager = function_manager from bigframes.session.peek_cache import PeekCache + self._peek_cache = PeekCache() def to_sql( @@ -220,7 +221,12 @@ async def _execute_async( if execution_spec.peek is not None and enable_peek_cache: from bigframes.session.peek_cache import substitute_peek_cached_subplans - rewritten_node = substitute_peek_cached_subplans(array_value.node, self._peek_cache) + + rewritten_node = substitute_peek_cached_subplans( + array_value.node, + self._peek_cache, + min_rows_required=execution_spec.peek, + ) if rewritten_node != array_value.node: rewritten_array_value = bigframes.core.ArrayValue(rewritten_node) maybe_result = await self._try_execute_semi_executors( @@ -235,7 +241,9 @@ async def _execute_async( else 10000 ) actual_sample_size = max(execution_spec.peek, sample_size) - cache_execution_spec = dataclasses.replace(execution_spec, 
peek=actual_sample_size) + cache_execution_spec = dataclasses.replace( + execution_spec, peek=actual_sample_size + ) bq_result = await self._execute_bigquery( array_value, @@ -243,7 +251,9 @@ async def _execute_async( ) arrow_table = await asyncio.to_thread(bq_result.batches().to_arrow_table) - managed_table = local_data.ManagedArrowTable.from_pyarrow(arrow_table, bq_result.schema) + managed_table = local_data.ManagedArrowTable.from_pyarrow( + arrow_table, bq_result.schema + ) self._peek_cache.put(array_value.node, managed_table) sliced_table = arrow_table.slice(0, execution_spec.peek) diff --git a/packages/bigframes/bigframes/session/peek_cache.py b/packages/bigframes/bigframes/session/peek_cache.py index f3f7ca1157bf..3d57e6b87298 100644 --- a/packages/bigframes/bigframes/session/peek_cache.py +++ b/packages/bigframes/bigframes/session/peek_cache.py @@ -14,8 +14,8 @@ from __future__ import annotations -from collections import OrderedDict import threading +from collections import OrderedDict from typing import Optional from bigframes.core import local_data, nodes @@ -29,7 +29,9 @@ class PeekCache: def __init__(self, capacity: int = 100): self.capacity = capacity - self._cache: OrderedDict[nodes.BigFrameNode, local_data.ManagedArrowTable] = OrderedDict() + self._cache: OrderedDict[nodes.BigFrameNode, local_data.ManagedArrowTable] = ( + OrderedDict() + ) self._lock = threading.Lock() def get(self, key: nodes.BigFrameNode) -> Optional[local_data.ManagedArrowTable]: @@ -56,11 +58,13 @@ def clear(self) -> None: def substitute_peek_cached_subplans( root: nodes.BigFrameNode, peek_cache: PeekCache, + min_rows_required: int, ) -> nodes.BigFrameNode: """ Recursively replaces subplans in the tree that have a cached local sample in the peek cache with a ReadLocalNode, provided that all ancestors - of the subplan are compatible with running on a sample. 
+ of the subplan are compatible with running on a sample, and the cached + sample contains at least the required number of rows. """ # Intermediate nodes that preserve the semantic validity of a sample. # WindowOpNode, AggregateNode, OrderByNode, JoinNode, etc. are excluded @@ -72,13 +76,20 @@ def substitute_peek_cached_subplans( nodes.PromoteOffsetsNode, ) - def traverse(node: nodes.BigFrameNode, ancestors_compatible: bool) -> nodes.BigFrameNode: + def traverse( + node: nodes.BigFrameNode, ancestors_compatible: bool + ) -> nodes.BigFrameNode: if ancestors_compatible: cached_sample = peek_cache.get(node) - if cached_sample is not None: + if ( + cached_sample is not None + and cached_sample.data.num_rows >= min_rows_required + ): # Replace the node with a ReadLocalNode containing the cached sample scan_list = nodes.ScanList( - tuple(nodes.ScanItem(field.id, field.id.sql) for field in node.fields) + tuple( + nodes.ScanItem(field.id, field.id.name) for field in node.fields + ) ) session = node.session if node.session is not None else root.session return nodes.ReadLocalNode( @@ -91,6 +102,8 @@ def traverse(node: nodes.BigFrameNode, ancestors_compatible: bool) -> nodes.BigF is_current_compatible = isinstance(node, _COMPATIBLE_ANCESTOR_CLASSES) next_ancestors_compatible = ancestors_compatible and is_current_compatible - return node.transform_children(lambda child: traverse(child, next_ancestors_compatible)) + return node.transform_children( + lambda child: traverse(child, next_ancestors_compatible) + ) return traverse(root, True) diff --git a/packages/bigframes/tests/unit/session/test_peek_cache.py b/packages/bigframes/tests/unit/session/test_peek_cache.py index a92cd03b198c..5d283e68b05f 100644 --- a/packages/bigframes/tests/unit/session/test_peek_cache.py +++ b/packages/bigframes/tests/unit/session/test_peek_cache.py @@ -19,7 +19,6 @@ import google.cloud.bigquery import pyarrow as pa -import pytest import bigframes from bigframes.core import identifiers, local_data, 
nodes @@ -78,7 +77,7 @@ def test_substitute_peek_cached_subplans(): cache.put(leaf, cached_ds) # Now perform the tree substitution - rewritten = substitute_peek_cached_subplans(leaf, cache) + rewritten = substitute_peek_cached_subplans(leaf, cache, min_rows_required=1) # The leaf should be replaced by a new ReadLocalNode containing cached_ds assert isinstance(rewritten, nodes.ReadLocalNode) @@ -220,10 +219,14 @@ def test_substitute_peek_cached_subplans_incompatible_ancestors(): # FilterNode is a compatible ancestor. plan_compatible = nodes.FilterNode( child=leaf, - predicate=bigframes.core.expression.ScalarConstantExpression(True), # Dummy expression + predicate=bigframes.core.expression.ScalarConstantExpression( + True + ), # Dummy expression ) - rewritten_compatible = substitute_peek_cached_subplans(plan_compatible, cache) + rewritten_compatible = substitute_peek_cached_subplans( + plan_compatible, cache, min_rows_required=1 + ) # The leaf child of FilterNode should be replaced by ReadLocalNode with cached_ds assert isinstance(rewritten_compatible, nodes.FilterNode) assert isinstance(rewritten_compatible.child, nodes.ReadLocalNode) @@ -233,8 +236,35 @@ def test_substitute_peek_cached_subplans_incompatible_ancestors(): # ReversedNode is an incompatible ancestor. 
plan_incompatible = nodes.ReversedNode(child=leaf) - rewritten_incompatible = substitute_peek_cached_subplans(plan_incompatible, cache) + rewritten_incompatible = substitute_peek_cached_subplans( + plan_incompatible, cache, min_rows_required=1 + ) # The leaf child should NOT be replaced by ReadLocalNode assert isinstance(rewritten_incompatible, nodes.ReversedNode) assert rewritten_incompatible.child == leaf assert rewritten_incompatible.child.local_data_source == ds + + +def test_substitute_peek_cached_subplans_insufficient_rows(): + session = mocks.create_bigquery_session() + table = pa.Table.from_pydict({"a": [1, 2]}) + ds = local_data.ManagedArrowTable.from_pyarrow(table) + + # Leaf node (cached with a 2-row sample) + leaf = nodes.ReadLocalNode( + local_data_source=ds, + scan_list=nodes.ScanList((nodes.ScanItem(identifiers.ColumnId("col_a"), "a"),)), + session=session, + ) + + cache = PeekCache() + cache.put(leaf, ds) + + # Request min_rows_required = 2 -> Should substitute + rewritten_ok = substitute_peek_cached_subplans(leaf, cache, min_rows_required=2) + assert isinstance(rewritten_ok, nodes.ReadLocalNode) + assert rewritten_ok.local_data_source == ds + + # Request min_rows_required = 3 -> Should NOT substitute (insufficient rows) + rewritten_ng = substitute_peek_cached_subplans(leaf, cache, min_rows_required=3) + assert rewritten_ng == leaf From c8e775dafd72e5589bc46466645921da52210468 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Mon, 29 Jun 2026 23:32:53 +0000 Subject: [PATCH 3/3] more general hybrid execution --- .../bigframes/bigframes/session/__init__.py | 3 + .../bigframes/bigframes/session/_async.py | 59 ++++ .../bigframes/session/bq_caching_executor.py | 104 +------ .../bigframes/bigframes/session/peek_cache.py | 49 ++-- .../bigframes/session/peek_cache_executor.py | 214 ++++++++++++++ .../tests/unit/session/test_peek_cache.py | 277 +++++++++++++++--- 6 files changed, 553 insertions(+), 153 deletions(-) create mode 100644 
packages/bigframes/bigframes/session/_async.py create mode 100644 packages/bigframes/bigframes/session/peek_cache_executor.py diff --git a/packages/bigframes/bigframes/session/__init__.py b/packages/bigframes/bigframes/session/__init__.py index e20f61901f9a..b75e2af3b89f 100644 --- a/packages/bigframes/bigframes/session/__init__.py +++ b/packages/bigframes/bigframes/session/__init__.py @@ -371,6 +371,9 @@ def __init__( labels=tuple(labels.items()), function_manager=self._function_session, ) + from bigframes.session.peek_cache_executor import PeekCacheExecutor + + self._executor = PeekCacheExecutor(self._executor, publisher=self._publisher) def __del__(self): """Automatic cleanup of internal resources.""" diff --git a/packages/bigframes/bigframes/session/_async.py b/packages/bigframes/bigframes/session/_async.py new file mode 100644 index 000000000000..72253c3a20f3 --- /dev/null +++ b/packages/bigframes/bigframes/session/_async.py @@ -0,0 +1,59 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import asyncio +import threading +from typing import Optional + +_bg_loop: Optional[asyncio.AbstractEventLoop] = None +_bg_thread: Optional[threading.Thread] = None +_bg_lock = threading.Lock() + + +def _get_bg_loop() -> asyncio.AbstractEventLoop: + global _bg_loop, _bg_thread + with _bg_lock: + if _bg_loop is None: + loop = asyncio.new_event_loop() + _bg_loop = loop + + def run(): + asyncio.set_event_loop(loop) + loop.run_forever() + + _bg_thread = threading.Thread( + target=run, daemon=True, name="bigframes-bg-loop" + ) + _bg_thread.start() + return _bg_loop + + +def run_sync(coro): + """ + Runs a coroutine synchronously, either in the current thread's event loop + if none is running, or by scheduling it on a background thread's event loop. + """ + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is None: + return asyncio.run(coro) + else: + bg_loop = _get_bg_loop() + future = asyncio.run_coroutine_threadsafe(coro, bg_loop) + return future.result() diff --git a/packages/bigframes/bigframes/session/bq_caching_executor.py b/packages/bigframes/bigframes/session/bq_caching_executor.py index 247fb35b910c..8ebcf6bcbf4b 100644 --- a/packages/bigframes/bigframes/session/bq_caching_executor.py +++ b/packages/bigframes/bigframes/session/bq_caching_executor.py @@ -18,7 +18,6 @@ import concurrent.futures import dataclasses import math -import threading from typing import Literal, Optional, Sequence, Tuple import google.api_core.exceptions @@ -71,41 +70,7 @@ MAX_SMALL_RESULT_BYTES = 10 * 1024 * 1024 * 1024 # 10G -_bg_loop = None -_bg_thread = None -_bg_lock = threading.Lock() - - -def _get_bg_loop(): - global _bg_loop, _bg_thread - with _bg_lock: - if _bg_loop is None: - loop = asyncio.new_event_loop() - _bg_loop = loop - - def run(): - asyncio.set_event_loop(loop) - loop.run_forever() - - _bg_thread = threading.Thread( - target=run, daemon=True, name="bigframes-bg-loop" - ) - _bg_thread.start() 
- return _bg_loop - - -def _run_sync(coro): - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = None - - if loop is None: - return asyncio.run(coro) - else: - bg_loop = _get_bg_loop() - future = asyncio.run_coroutine_threadsafe(coro, bg_loop) - return future.result() +from bigframes.session._async import run_sync class BigQueryCachingExecutor(executor.Executor): @@ -167,9 +132,6 @@ def __init__( labels=dict(labels), ) self._function_manager = function_manager - from bigframes.session.peek_cache import PeekCache - - self._peek_cache = PeekCache() def to_sql( self, @@ -185,7 +147,7 @@ def to_sql( if enable_cache else array_value.node ) - node = _run_sync(self._substitute_large_local_sources(node)) + node = run_sync(self._substitute_large_local_sources(node)) compiled = compile.compile_sql( compile.CompileRequest(node, sort_rows=ordered), compiler_name=self._compiler_name, @@ -199,7 +161,7 @@ def execute( ) -> executor.ExecuteResult: # Need to grab thread local before starting async execution. 
execution_spec = execution_spec.with_compute_options(bigframes.options.compute) - return _run_sync( + return run_sync( self._execute_async( array_value, execution_spec, @@ -213,64 +175,6 @@ async def _execute_async( ) -> executor.ExecuteResult: await self._publisher.publish_async(bigframes.core.events.ExecutionStarted()) - enable_peek_cache = ( - execution_spec.bigquery_config.enable_peek_cache - if execution_spec.bigquery_config - else False - ) - - if execution_spec.peek is not None and enable_peek_cache: - from bigframes.session.peek_cache import substitute_peek_cached_subplans - - rewritten_node = substitute_peek_cached_subplans( - array_value.node, - self._peek_cache, - min_rows_required=execution_spec.peek, - ) - if rewritten_node != array_value.node: - rewritten_array_value = bigframes.core.ArrayValue(rewritten_node) - maybe_result = await self._try_execute_semi_executors( - rewritten_array_value, execution_spec - ) - if maybe_result is not None: - return maybe_result - - sample_size = ( - execution_spec.bigquery_config.peek_cache_size - if execution_spec.bigquery_config - else 10000 - ) - actual_sample_size = max(execution_spec.peek, sample_size) - cache_execution_spec = dataclasses.replace( - execution_spec, peek=actual_sample_size - ) - - bq_result = await self._execute_bigquery( - array_value, - cache_execution_spec, - ) - - arrow_table = await asyncio.to_thread(bq_result.batches().to_arrow_table) - managed_table = local_data.ManagedArrowTable.from_pyarrow( - arrow_table, bq_result.schema - ) - self._peek_cache.put(array_value.node, managed_table) - - sliced_table = arrow_table.slice(0, execution_spec.peek) - result: executor.ExecuteResult = executor.LocalExecuteResult( - sliced_table, - bq_result.schema, - execution_metadata=bq_result.execution_metadata, - ) - - await self._publisher.publish_async( - bigframes.core.events.EventEnvelope( - event=bigframes.core.events.ExecutionFinished(result=result), - 
cell_execution_count=execution_spec.cell_execution_count, - ) - ) - return result - maybe_result = await self._try_execute_semi_executors( array_value, execution_spec ) @@ -480,7 +384,7 @@ def cached( bq_compute_options = ex_spec.BqComputeOptions.from_compute_options( bigframes.options.compute ) - return _run_sync( + return run_sync( self._cached_async( array_value, config=config, compute_options=bq_compute_options ) diff --git a/packages/bigframes/bigframes/session/peek_cache.py b/packages/bigframes/bigframes/session/peek_cache.py index 3d57e6b87298..608559227a1f 100644 --- a/packages/bigframes/bigframes/session/peek_cache.py +++ b/packages/bigframes/bigframes/session/peek_cache.py @@ -14,6 +14,7 @@ from __future__ import annotations +import dataclasses import threading from collections import OrderedDict from typing import Optional @@ -21,20 +22,24 @@ from bigframes.core import local_data, nodes +@dataclasses.dataclass(frozen=True) +class CachedRelation: + table: local_data.ManagedArrowTable + is_complete: bool = False + + class PeekCache: """ - Thread-safe LRU cache for storing local samples of query relations. + Thread-safe LRU cache for storing local samples or complete copies of query relations. This enables fast iteration on subsequent compatible operations. 
""" def __init__(self, capacity: int = 100): self.capacity = capacity - self._cache: OrderedDict[nodes.BigFrameNode, local_data.ManagedArrowTable] = ( - OrderedDict() - ) + self._cache: OrderedDict[nodes.BigFrameNode, CachedRelation] = OrderedDict() self._lock = threading.Lock() - def get(self, key: nodes.BigFrameNode) -> Optional[local_data.ManagedArrowTable]: + def get(self, key: nodes.BigFrameNode) -> Optional[CachedRelation]: with self._lock: if key not in self._cache: return None @@ -42,8 +47,14 @@ def get(self, key: nodes.BigFrameNode) -> Optional[local_data.ManagedArrowTable] self._cache.move_to_end(key) return self._cache[key] - def put(self, key: nodes.BigFrameNode, value: local_data.ManagedArrowTable) -> None: + def put( + self, + key: nodes.BigFrameNode, + table: local_data.ManagedArrowTable, + is_complete: bool = False, + ) -> None: with self._lock: + value = CachedRelation(table, is_complete) if key in self._cache: self._cache.move_to_end(key) self._cache[key] = value @@ -58,13 +69,14 @@ def clear(self) -> None: def substitute_peek_cached_subplans( root: nodes.BigFrameNode, peek_cache: PeekCache, - min_rows_required: int, + min_rows_required: Optional[int], ) -> nodes.BigFrameNode: """ - Recursively replaces subplans in the tree that have a cached local sample - in the peek cache with a ReadLocalNode, provided that all ancestors - of the subplan are compatible with running on a sample, and the cached - sample contains at least the required number of rows. + Recursively replaces subplans in the tree that have a cached local relation + in the peek cache with a ReadLocalNode, provided that: + 1. The cached relation is complete (contains the entire dataset). + 2. Or, all ancestors of the subplan are compatible with running on a sample, + and the cached sample contains at least the required number of rows. """ # Intermediate nodes that preserve the semantic validity of a sample. # WindowOpNode, AggregateNode, OrderByNode, JoinNode, etc. 
are excluded @@ -79,13 +91,14 @@ def substitute_peek_cached_subplans( def traverse( node: nodes.BigFrameNode, ancestors_compatible: bool ) -> nodes.BigFrameNode: - if ancestors_compatible: - cached_sample = peek_cache.get(node) - if ( - cached_sample is not None - and cached_sample.data.num_rows >= min_rows_required + cached_entry = peek_cache.get(node) + if cached_entry is not None: + if cached_entry.is_complete or ( + ancestors_compatible + and min_rows_required is not None + and cached_entry.table.data.num_rows >= min_rows_required ): - # Replace the node with a ReadLocalNode containing the cached sample + # Replace the node with a ReadLocalNode containing the cached relation scan_list = nodes.ScanList( tuple( nodes.ScanItem(field.id, field.id.name) for field in node.fields @@ -93,7 +106,7 @@ def traverse( ) session = node.session if node.session is not None else root.session return nodes.ReadLocalNode( - local_data_source=cached_sample, + local_data_source=cached_entry.table, scan_list=scan_list, session=session, ) diff --git a/packages/bigframes/bigframes/session/peek_cache_executor.py b/packages/bigframes/bigframes/session/peek_cache_executor.py new file mode 100644 index 000000000000..3d2231fca6e8 --- /dev/null +++ b/packages/bigframes/bigframes/session/peek_cache_executor.py @@ -0,0 +1,214 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import asyncio +import dataclasses +from typing import Optional + +import google.cloud.bigquery as bigquery + +import bigframes.core +import bigframes.core.events +from bigframes.core import local_data +from bigframes.session import execution_spec as ex_spec +from bigframes.session import executor +from bigframes.session._async import run_sync + + +class PeekCacheExecutor(executor.Executor): + """ + Decorator executor that implements a peek cache. + + If the execution spec requests a peek and the peek cache is enabled, + it attempts to rewrite the plan to use cached subplans and execute it + locally using the Polars executor. + + Otherwise, it delegates execution to the target executor and caches + the result. + """ + + def __init__( + self, + target: executor.Executor, + publisher: bigframes.core.events.Publisher, + ): + self._target = target + self._publisher = publisher + + from bigframes.session.peek_cache import PeekCache + + self._peek_cache = PeekCache() + + self._polars_executor = None + try: + from bigframes.session.polars_executor import PolarsExecutor + + self._polars_executor = PolarsExecutor() + except ImportError: + # Polars is not installed, so the peek cache shortcut cannot be used. 
+ pass + + def to_sql( + self, + array_value: bigframes.core.ArrayValue, + offset_column: Optional[str] = None, + ordered: bool = False, + enable_cache: bool = True, + ) -> str: + return self._target.to_sql( + array_value, + offset_column=offset_column, + ordered=ordered, + enable_cache=enable_cache, + ) + + def dry_run( + self, array_value: bigframes.core.ArrayValue, ordered: bool = True + ) -> bigquery.QueryJob: + return self._target.dry_run(array_value, ordered=ordered) + + def cached( + self, + array_value: bigframes.core.ArrayValue, + *, + config: executor.CacheConfig, + ) -> None: + return self._target.cached(array_value, config=config) + + def execute( + self, + array_value: bigframes.core.ArrayValue, + execution_spec: ex_spec.ExecutionSpec, + ) -> executor.ExecuteResult: + execution_spec = execution_spec.with_compute_options(bigframes.options.compute) + + enable_peek_cache = ( + execution_spec.bigquery_config.enable_peek_cache + if execution_spec.bigquery_config + else False + ) + + if not enable_peek_cache or self._polars_executor is None: + return self._target.execute(array_value, execution_spec) + + return run_sync(self._execute_async(array_value, execution_spec)) + + async def _execute_async( + self, + array_value: bigframes.core.ArrayValue, + execution_spec: ex_spec.ExecutionSpec, + ) -> executor.ExecuteResult: + await self._publisher.publish_async(bigframes.core.events.ExecutionStarted()) + + from bigframes.session.peek_cache import substitute_peek_cached_subplans + + # 1. Attempt to rewrite the plan using cached subplans from the peek cache. + rewritten_node = substitute_peek_cached_subplans( + array_value.node, + self._peek_cache, + min_rows_required=execution_spec.peek, + ) + if rewritten_node != array_value.node: + # The plan was rewritten! Try to execute the rewritten plan using only the Polars executor. 
+ assert self._polars_executor is not None + maybe_result = await self._polars_executor.execute( + rewritten_node, execution_spec + ) + if maybe_result is not None: + num_rows = maybe_result.batches().approx_total_rows + # If it's a full execution (peek is None), the result is complete because we only substituted complete entries. + # If it's a peek execution, we must ensure we got enough rows. + is_sufficient = execution_spec.peek is None or ( + num_rows is not None and num_rows >= execution_spec.peek + ) + if is_sufficient: + await self._publisher.publish_async( + bigframes.core.events.EventEnvelope( + event=bigframes.core.events.ExecutionFinished( + result=maybe_result, + ), + cell_execution_count=execution_spec.cell_execution_count, + ) + ) + return maybe_result + + # 2. If the shortcut wasn't used or failed, run the query on the target executor. + if execution_spec.peek is not None: + sample_size = ( + execution_spec.bigquery_config.peek_cache_size + if execution_spec.bigquery_config + else 10000 + ) + actual_sample_size = max(execution_spec.peek, sample_size) + cache_execution_spec = dataclasses.replace( + execution_spec, peek=actual_sample_size + ) + else: + cache_execution_spec = execution_spec + + bq_result = await asyncio.to_thread( + self._target.execute, + array_value, + cache_execution_spec, + ) + + # 3. Cache the result if appropriate. + if execution_spec.peek is not None: + # For peek executions, we always download and cache the sample. 
+ arrow_table = await asyncio.to_thread(bq_result.batches().to_arrow_table) + is_complete = arrow_table.num_rows < actual_sample_size + managed_table = local_data.ManagedArrowTable.from_pyarrow( + arrow_table, bq_result.schema + ) + self._peek_cache.put( + array_value.node, managed_table, is_complete=is_complete + ) + + sliced_table = arrow_table.slice(0, execution_spec.peek) + result: executor.ExecuteResult = executor.LocalExecuteResult( + sliced_table, + bq_result.schema, + execution_metadata=bq_result.execution_metadata, + ) + else: + # For full executions, we only cache if the target executor returned a local result. + if isinstance(bq_result, executor.LocalExecuteResult): + peek_cache_size = ( + execution_spec.bigquery_config.peek_cache_size + if execution_spec.bigquery_config + else 10000 + ) + if bq_result._data.data.num_rows > peek_cache_size: + sliced_data = bq_result._data.data.slice(0, peek_cache_size) + managed_table = local_data.ManagedArrowTable.from_pyarrow( + sliced_data, bq_result.schema + ) + is_complete = False + else: + managed_table = bq_result._data + is_complete = True + self._peek_cache.put( + array_value.node, managed_table, is_complete=is_complete + ) + result = bq_result + + await self._publisher.publish_async( + bigframes.core.events.EventEnvelope( + event=bigframes.core.events.ExecutionFinished(result=result), + cell_execution_count=execution_spec.cell_execution_count, + ) + ) + return result diff --git a/packages/bigframes/tests/unit/session/test_peek_cache.py b/packages/bigframes/tests/unit/session/test_peek_cache.py index 5d283e68b05f..222b1df9f485 100644 --- a/packages/bigframes/tests/unit/session/test_peek_cache.py +++ b/packages/bigframes/tests/unit/session/test_peek_cache.py @@ -14,15 +14,13 @@ from __future__ import annotations -import asyncio from unittest import mock -import google.cloud.bigquery import pyarrow as pa import bigframes from bigframes.core import identifiers, local_data, nodes -from bigframes.session import 
bq_caching_executor, execution_spec, executor +from bigframes.session import execution_spec, executor from bigframes.session.peek_cache import PeekCache, substitute_peek_cached_subplans from bigframes.testing import mocks @@ -48,14 +46,14 @@ def test_peek_cache_lru(): cache.put(node2, ds2) # Access node1 to make it most recently used, leaving node2 as least recently used (LRU) - assert cache.get(node1) == ds1 + assert cache.get(node1).table == ds1 # Put node3, which should evict node2 cache.put(node3, ds3) assert cache.get(node2) is None - assert cache.get(node1) == ds1 - assert cache.get(node3) == ds3 + assert cache.get(node1).table == ds1 + assert cache.get(node3).table == ds3 def test_substitute_peek_cached_subplans(): @@ -89,22 +87,16 @@ def test_substitute_peek_cached_subplans(): def test_executor_peek_cache_integration(): - # Mock all arguments to BigQueryCachingExecutor - bqclient = mock.create_autospec(google.cloud.bigquery.Client, instance=True) - bqclient.project = "test-project" - storage_manager = mock.Mock() - bqstoragereadclient = mock.Mock() - loader = mock.Mock() + from bigframes.session import semi_executor + from bigframes.session.peek_cache_executor import PeekCacheExecutor + + # Mock target executor + mock_target = mock.create_autospec(executor.Executor, instance=True) publisher = mock.AsyncMock() - function_manager = mock.Mock() - - executor_obj = bq_caching_executor.BigQueryCachingExecutor( - bqclient=bqclient, - storage_manager=storage_manager, - bqstoragereadclient=bqstoragereadclient, - loader=loader, - publisher=publisher, - function_manager=function_manager, + + # Mock PolarsExecutor + mock_polars_executor = mock.create_autospec( + semi_executor.SemiExecutor, instance=True ) table = pa.Table.from_pydict({"col": [1, 2, 3, 4, 5]}) @@ -118,25 +110,30 @@ def test_executor_peek_cache_integration(): ) arr_value = bigframes.core.ArrayValue(node) - # Mock _execute_bigquery of the executor to return a mock 3-row table + # Mock target.execute to 
return a mock 3-row table mock_bq_table = pa.Table.from_pydict({"col": [10, 20, 30]}) mock_bq_result = executor.LocalExecuteResult(mock_bq_table, arr_value.schema) - - execute_bq_mock = mock.AsyncMock(return_value=mock_bq_result) - executor_obj._execute_bigquery = execute_bq_mock + mock_target.execute.return_value = mock_bq_result # Enable peek cache options compute_options = bigframes.options.compute compute_options.enable_peek_cache = True compute_options.peek_cache_size = 3 + # Patch PolarsExecutor + with mock.patch( + "bigframes.session.polars_executor.PolarsExecutor", + return_value=mock_polars_executor, + ): + peek_cache_executor = PeekCacheExecutor(mock_target, publisher=publisher) + # Call execute with peek=1 (cache miss path) spec = execution_spec.ExecutionSpec(peek=1).with_compute_options(compute_options) - result = asyncio.run(executor_obj._execute_async(arr_value, spec)) + result = peek_cache_executor.execute(arr_value, spec) - # Verify BQ was called with peek=3 (cache size) - assert execute_bq_mock.call_count == 1 - called_spec = execute_bq_mock.call_args[0][1] + # Verify target.execute was called with peek=3 (cache size) + assert mock_target.execute.call_count == 1 + called_spec = mock_target.execute.call_args[0][1] assert called_spec.peek == 3 # Verify returned result has exactly 1 row @@ -145,17 +142,34 @@ def test_executor_peek_cache_integration(): assert result_table["col"].to_pylist() == [10] # Verify peek cache has been populated with the 3-row table - cached_entry = executor_obj._peek_cache.get(node) + cached_entry = peek_cache_executor._peek_cache.get(node) assert cached_entry is not None - assert cached_entry.to_pyarrow_table()["col"].to_pylist() == [10, 20, 30] + assert cached_entry.table.to_pyarrow_table()["col"].to_pylist() == [10, 20, 30] + + # Mock polars executor to return a 2-row table on cache hit + mock_polars_table = pa.Table.from_pydict({"col": [10, 20]}) + mock_polars_result = executor.LocalExecuteResult( + mock_polars_table, 
arr_value.schema + ) + mock_polars_executor.execute = mock.AsyncMock(return_value=mock_polars_result) # Call execute again with peek=2 (cache hit path) - execute_bq_mock.reset_mock() + mock_target.execute.reset_mock() spec2 = execution_spec.ExecutionSpec(peek=2).with_compute_options(compute_options) - result2 = asyncio.run(executor_obj._execute_async(arr_value, spec2)) + result2 = peek_cache_executor.execute(arr_value, spec2) - # Verify BQ was NOT called - assert execute_bq_mock.call_count == 0 + # Verify target.execute was NOT called + assert mock_target.execute.call_count == 0 + + # Verify polars executor was called with the rewritten node + assert mock_polars_executor.execute.call_count == 1 + called_node = mock_polars_executor.execute.call_args[0][0] + assert isinstance(called_node, nodes.ReadLocalNode) + assert called_node.local_data_source.to_pyarrow_table()["col"].to_pylist() == [ + 10, + 20, + 30, + ] # Verify returned result has exactly 2 rows result_table2 = pa.Table.from_batches(result2.batches().arrow_batches) @@ -268,3 +282,196 @@ def test_substitute_peek_cached_subplans_insufficient_rows(): # Request min_rows_required = 3 -> Should NOT substitute (insufficient rows) rewritten_ng = substitute_peek_cached_subplans(leaf, cache, min_rows_required=3) assert rewritten_ng == leaf + + +def test_executor_peek_cache_populates_from_full_execution(): + from bigframes.session.peek_cache_executor import PeekCacheExecutor + + # Mock target executor + mock_target = mock.create_autospec(executor.Executor, instance=True) + publisher = mock.AsyncMock() + + table = pa.Table.from_pydict({"col": [1, 2, 3, 4, 5]}) + ds = local_data.ManagedArrowTable.from_pyarrow(table) + session = mocks.create_bigquery_session() + + node = nodes.ReadLocalNode( + local_data_source=ds, + scan_list=nodes.ScanList((nodes.ScanItem(identifiers.ColumnId("col"), "col"),)), + session=session, + ) + arr_value = bigframes.core.ArrayValue(node) + + # Mock target.execute to return a mock 5-row local 
result + mock_result = executor.LocalExecuteResult(table, arr_value.schema) + mock_target.execute.return_value = mock_result + + # Enable peek cache options with cache size = 3 (smaller than 5) + compute_options = bigframes.options.compute + compute_options.enable_peek_cache = True + compute_options.peek_cache_size = 3 + + # Mock PolarsExecutor to avoid ImportError if it's not installed in test env, + # though it is not called on full execution path, but PeekCacheExecutor.__init__ tries to instantiate it. + mock_polars_executor = mock.Mock() + with mock.patch( + "bigframes.session.polars_executor.PolarsExecutor", + return_value=mock_polars_executor, + ): + peek_cache_executor = PeekCacheExecutor(mock_target, publisher=publisher) + + # Call execute with peek=None (full execution) + spec = execution_spec.ExecutionSpec(peek=None).with_compute_options(compute_options) + _ = peek_cache_executor.execute(arr_value, spec) + + # Verify target.execute was called + assert mock_target.execute.call_count == 1 + + # Verify peek cache has been populated and sliced to peek_cache_size (3) + cached_entry = peek_cache_executor._peek_cache.get(node) + assert cached_entry is not None + assert not cached_entry.is_complete # Sliced, so not complete + assert cached_entry.table.to_pyarrow_table().num_rows == 3 + assert cached_entry.table.to_pyarrow_table()["col"].to_pylist() == [1, 2, 3] + + # Now test when full result is smaller than peek_cache_size (cache size = 10) + compute_options.peek_cache_size = 10 + with mock.patch( + "bigframes.session.polars_executor.PolarsExecutor", + return_value=mock_polars_executor, + ): + peek_cache_executor2 = PeekCacheExecutor(mock_target, publisher=publisher) + _ = peek_cache_executor2.execute(arr_value, spec) + + cached_entry2 = peek_cache_executor2._peek_cache.get(node) + assert cached_entry2 is not None + assert cached_entry2.is_complete # Fits entirely, so complete! 
+ assert cached_entry2.table.to_pyarrow_table().num_rows == 5 + assert cached_entry2.table.to_pyarrow_table()["col"].to_pylist() == [1, 2, 3, 4, 5] + + +def test_executor_peek_cache_falls_back_on_insufficient_rows(): + from bigframes.session import semi_executor + from bigframes.session.peek_cache_executor import PeekCacheExecutor + + # Mock target executor + mock_target = mock.create_autospec(executor.Executor, instance=True) + publisher = mock.AsyncMock() + + # Mock PolarsExecutor + mock_polars_executor = mock.create_autospec( + semi_executor.SemiExecutor, instance=True + ) + + table = pa.Table.from_pydict({"col": [1, 2, 3, 4, 5]}) + ds = local_data.ManagedArrowTable.from_pyarrow(table) + session = mocks.create_bigquery_session() + + node = nodes.ReadLocalNode( + local_data_source=ds, + scan_list=nodes.ScanList((nodes.ScanItem(identifiers.ColumnId("col"), "col"),)), + session=session, + ) + arr_value = bigframes.core.ArrayValue(node) + + # Mock target.execute to return a mock 5-row BQ result + mock_bq_table = pa.Table.from_pydict({"col": [10, 20, 30, 40, 50]}) + mock_bq_result = executor.LocalExecuteResult(mock_bq_table, arr_value.schema) + mock_target.execute.return_value = mock_bq_result + + # Enable peek cache options + compute_options = bigframes.options.compute + compute_options.enable_peek_cache = True + compute_options.peek_cache_size = 5 + + with mock.patch( + "bigframes.session.polars_executor.PolarsExecutor", + return_value=mock_polars_executor, + ): + peek_cache_executor = PeekCacheExecutor(mock_target, publisher=publisher) + + # Populate the cache first with a different data source object to trigger rewrite + cached_ds = local_data.ManagedArrowTable.from_pyarrow(table) + peek_cache_executor._peek_cache.put(node, cached_ds) + + # Mock polars executor to return a 2-row table (insufficient for peek=5) + mock_polars_table = pa.Table.from_pydict({"col": [1, 2]}) + mock_polars_result = executor.LocalExecuteResult( + mock_polars_table, arr_value.schema + ) + 
mock_polars_executor.execute = mock.AsyncMock(return_value=mock_polars_result) + + # Call execute with peek=5 (should trigger fallback because 2 < 5) + spec = execution_spec.ExecutionSpec(peek=5).with_compute_options(compute_options) + result = peek_cache_executor.execute(arr_value, spec) + + # Verify polars executor was called + assert mock_polars_executor.execute.call_count == 1 + + # Verify target.execute WAS called (fallback) + assert mock_target.execute.call_count == 1 + called_spec = mock_target.execute.call_args[0][1] + assert called_spec.peek == 5 + + # Verify returned result is the BQ result + result_table = pa.Table.from_batches(result.batches().arrow_batches) + assert result_table.num_rows == 5 + assert result_table["col"].to_pylist() == [10, 20, 30, 40, 50] + + +def test_executor_complete_cache_local_execution(): + from bigframes.session import semi_executor + from bigframes.session.peek_cache_executor import PeekCacheExecutor + + # Mock target executor + mock_target = mock.create_autospec(executor.Executor, instance=True) + publisher = mock.AsyncMock() + + # Mock PolarsExecutor + mock_polars_executor = mock.create_autospec( + semi_executor.SemiExecutor, instance=True + ) + + table = pa.Table.from_pydict({"col": [1, 2, 3, 4, 5]}) + ds = local_data.ManagedArrowTable.from_pyarrow(table) + session = mocks.create_bigquery_session() + + node = nodes.ReadLocalNode( + local_data_source=ds, + scan_list=nodes.ScanList((nodes.ScanItem(identifiers.ColumnId("col"), "col"),)), + session=session, + ) + arr_value = bigframes.core.ArrayValue(node) + + # Enable peek cache options + compute_options = bigframes.options.compute + compute_options.enable_peek_cache = True + + with mock.patch( + "bigframes.session.polars_executor.PolarsExecutor", + return_value=mock_polars_executor, + ): + peek_cache_executor = PeekCacheExecutor(mock_target, publisher=publisher) + + # Populate the cache first as COMPLETE + cached_ds = local_data.ManagedArrowTable.from_pyarrow(table) + 
peek_cache_executor._peek_cache.put(node, cached_ds, is_complete=True) + + # Mock polars executor to return the full 5-row table + mock_polars_result = executor.LocalExecuteResult(table, arr_value.schema) + mock_polars_executor.execute = mock.AsyncMock(return_value=mock_polars_result) + + # Call execute with peek=None (full execution!) + spec = execution_spec.ExecutionSpec(peek=None).with_compute_options(compute_options) + result = peek_cache_executor.execute(arr_value, spec) + + # Verify polars executor was called + assert mock_polars_executor.execute.call_count == 1 + + # Verify target.execute was NOT called (no BQ query!) + assert mock_target.execute.call_count == 0 + + # Verify returned result is the local polars result + result_table = pa.Table.from_batches(result.batches().arrow_batches) + assert result_table.num_rows == 5 + assert result_table["col"].to_pylist() == [1, 2, 3, 4, 5]