diff --git a/paimon-python/pypaimon/globalindex/batch_vector_search.py b/paimon-python/pypaimon/globalindex/batch_vector_search.py
new file mode 100644
index 000000000000..c5aae32ce905
--- /dev/null
+++ b/paimon-python/pypaimon/globalindex/batch_vector_search.py
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""BatchVectorSearch for performing batch vector similarity search."""
+
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+from pypaimon.globalindex.vector_search import VectorSearch
+
+
+@dataclass
+class BatchVectorSearch:
+    """Batch vector search over multiple query vectors; result ``i`` maps to ``vectors[i]``."""
+
+    vectors: List[Union[List[float], np.ndarray]]
+    limit: int
+    field_name: str
+    include_row_ids: Optional['RoaringBitmap64'] = field(default=None)
+    options: Optional[Dict[str, str]] = field(default=None)
+
+    def __post_init__(self):
+        if not self.vectors:
+            raise ValueError("Search vectors cannot be empty")
+        if self.limit <= 0:
+            raise ValueError(f"Limit must be positive, got: {self.limit}")
+        if not self.field_name:
+            raise ValueError("Field name cannot be null or empty")
+        # Match VectorSearch: list vectors -> float32.
+        self.vectors = [
+            np.array(v, dtype=np.float32) if isinstance(v, list) else v
+            for v in self.vectors
+        ]
+        self.options = {} if self.options is None else dict(self.options)
+
+    @property
+    def vector_count(self) -> int:
+        return len(self.vectors)
+
+    def for_index(self, i: int) -> VectorSearch:
+        """Return the single VectorSearch for query vector ``i``."""
+        return VectorSearch(
+            vector=self.vectors[i],
+            limit=self.limit,
+            field_name=self.field_name,
+            include_row_ids=self.include_row_ids,
+            options=self.options,
+        )
+
+    def with_include_row_ids(self, include_row_ids: 'RoaringBitmap64') -> 'BatchVectorSearch':
+        return BatchVectorSearch(
+            vectors=self.vectors,
+            limit=self.limit,
+            field_name=self.field_name,
+            include_row_ids=include_row_ids,
+            options=self.options,
+        )
+
+    def offset_range(self, from_: int, to: int) -> 'BatchVectorSearch':
+        """Keep include_row_ids within the given range, shifted down by ``from_``; vectors are unchanged."""
+        if self.include_row_ids is None:
+            return self
+        from pypaimon.utils.roaring_bitmap import RoaringBitmap64
+
+        range_bitmap = RoaringBitmap64()
+        range_bitmap.add_range(from_, to)
+        and_result = RoaringBitmap64.and_(range_bitmap, self.include_row_ids)
+        offset_bitmap = RoaringBitmap64()
+        # Per-element shift (RoaringBitmap64 has no bulk translate yet).
+        for row_id in and_result:
+            offset_bitmap.add(row_id - from_)
+        return self.with_include_row_ids(offset_bitmap)
+
+    def visit(self, visitor: 'GlobalIndexReader') -> 'Future[List[Optional[GlobalIndexResult]]]':
+        return visitor.visit_batch_vector_search(self)
+
+    def __repr__(self) -> str:
+        return (f"BatchVectorSearch(field_name={self.field_name}, "
+                f"limit={self.limit}, vector_count={self.vector_count})")
diff --git a/paimon-python/pypaimon/globalindex/global_index_reader.py b/paimon-python/pypaimon/globalindex/global_index_reader.py
index 1ac0905bb66e..d3915583cabc 100644
--- a/paimon-python/pypaimon/globalindex/global_index_reader.py
+++ b/paimon-python/pypaimon/globalindex/global_index_reader.py
@@ -58,6 +58,18 @@ class GlobalIndexReader(ABC):
     def visit_vector_search(self, vector_search: 'VectorSearch') -> 'Future[Optional[GlobalIndexResult]]':
         raise NotImplementedError("Vector search not supported by this reader")
 
+    def visit_batch_vector_search(
+            self, batch_vector_search: 'BatchVectorSearch'
+    ) -> 'Future[List[Optional[GlobalIndexResult]]]':
+        """Default: fan out to single-vector search; result ``i`` maps to ``vectors[i]``.
+
+        Blocks per future (fine while readers return completed futures); an
+        async reader should override.
+        """
+        singles = [self.visit_vector_search(batch_vector_search.for_index(i))
+                   for i in range(batch_vector_search.vector_count)]
+        return _completed_future([f.result() for f in singles])
+
     def visit_full_text_search(self, full_text_search: 'FullTextSearch') -> 'Future[Optional[GlobalIndexResult]]':
         raise NotImplementedError("Full-text search not supported by this reader")
 
diff --git a/paimon-python/pypaimon/globalindex/lumina/lumina_vector_global_index_reader.py b/paimon-python/pypaimon/globalindex/lumina/lumina_vector_global_index_reader.py
index cabd4919111f..2abba80617ea 100644
--- a/paimon-python/pypaimon/globalindex/lumina/lumina_vector_global_index_reader.py
+++ b/paimon-python/pypaimon/globalindex/lumina/lumina_vector_global_index_reader.py
@@ -50,6 +50,21 @@ def _merge_options(base_options, index_options, query_options):
     return options
 
 
+def _collect_scored_result(distances, labels, base, k, index_metric):
+    """Convert one query's [base, base+k) slice of distances/labels into a result; sentinel labels are skipped."""
+    from lumina_data import MetricType
+
+    SENTINEL = 0xFFFFFFFFFFFFFFFF
+    id_to_scores = {}
+    for i in range(k):
+        row_id = labels[base + i]
+        if row_id == SENTINEL:
+            continue
+        id_to_scores[int(row_id)] = MetricType.convert_distance_to_score(
+            float(distances[base + i]), index_metric)
+    return DictBasedScoredIndexResult(id_to_scores)
+
+
 class LuminaVectorGlobalIndexReader(GlobalIndexReader):
     """Vector global index reader using Lumina."""
 
@@ -66,57 +81,69 @@ def __init__(self, file_io, index_path, io_metas, options=None):
         self._load_lock = threading.Lock()
 
     def visit_vector_search(self, vector_search):
+        # Single-vector search is just the n == 1 case of the batch path.
+        results = self._run_search(
+            [vector_search.vector],
+            vector_search.limit,
+            vector_search.include_row_ids,
+            vector_search.options,
+        )
+        return _completed_future(results[0])
+
+    def visit_batch_vector_search(self, batch_vector_search):
+        results = self._run_search(
+            batch_vector_search.vectors,
+            batch_vector_search.limit,
+            batch_vector_search.include_row_ids,
+            batch_vector_search.options,
+        )
+        return _completed_future(results)
+
+    def _run_search(self, vectors, limit, include_row_ids, query_options):
+        """Run one native batch search; result ``i`` maps to ``vectors[i]``. All entries are
+        ``None`` when the index or the row-id filter is empty. Shared by both visit paths.
+        """
         self._ensure_loaded()
-        from lumina_data import MetricType
 
-        query_flat = [float(v) for v in np.asarray(vector_search.vector).tolist()]
+        n = len(vectors)
         expected_dim = self._index_meta.dim
-        if len(query_flat) != expected_dim:
-            raise ValueError(
-                "Query vector dimension mismatch: expected %d, got %d"
-                % (expected_dim, len(query_flat)))
+        query_flat = []
+        for vector in vectors:
+            flat = [float(v) for v in np.asarray(vector).tolist()]
+            if len(flat) != expected_dim:
+                raise ValueError(
+                    "Query vector dimension mismatch: expected %d, got %d"
+                    % (expected_dim, len(flat)))
+            query_flat.extend(flat)
 
-        limit = vector_search.limit
         index_metric = self._index_meta.metric
-
         count = self._searcher.get_count()
         effective_k = min(limit, count)
         if effective_k <= 0:
-            return _completed_future(None)
-
-        include_row_ids = vector_search.include_row_ids
-        query_options = vector_search.options
+            return [None] * n
 
         if include_row_ids is not None:
             filter_id_list = list(include_row_ids)
             if len(filter_id_list) == 0:
-                return _completed_future(None)
+                return [None] * n
             effective_k = min(effective_k, len(filter_id_list))
-            search_opts = _merge_options(
-                self._options, {}, query_options)
+            search_opts = _merge_options(self._options, {}, query_options)
             search_opts["search.thread_safe_filter"] = "true"
             _ensure_search_list_size(search_opts, effective_k)
             distances, labels = self._searcher.search_with_filter_list(
-                query_flat, 1, effective_k, filter_id_list, search_opts)
+                query_flat, n, effective_k, filter_id_list, search_opts)
         else:
-            search_opts = _merge_options(
-                self._options, {}, query_options)
+            search_opts = _merge_options(self._options, {}, query_options)
             _ensure_search_list_size(search_opts, effective_k)
             distances, labels = self._searcher.search_list(
-                query_flat, 1, effective_k, search_opts)
-
-        # Collect results with score conversion (same as Java collectResults)
-        SENTINEL = 0xFFFFFFFFFFFFFFFF
-        id_to_scores = {}
-        for i in range(effective_k):
-            row_id = labels[i]
-            if row_id == SENTINEL:
-                continue
-            score = MetricType.convert_distance_to_score(
-                float(distances[i]), index_metric)
-            id_to_scores[int(row_id)] = score
-
-        return _completed_future(DictBasedScoredIndexResult(id_to_scores))
+                query_flat, n, effective_k, search_opts)
+
+        # Each query's results occupy a contiguous [q * k, q * k + k) slice.
+        return [
+            _collect_scored_result(
+                distances, labels, q * effective_k, effective_k, index_metric)
+            for q in range(n)
+        ]
 
     def _ensure_loaded(self):
         if self._searcher is not None:
diff --git a/paimon-python/pypaimon/globalindex/offset_global_index_reader.py b/paimon-python/pypaimon/globalindex/offset_global_index_reader.py
index a0965ed20558..f13145ff5079 100644
--- a/paimon-python/pypaimon/globalindex/offset_global_index_reader.py
+++ b/paimon-python/pypaimon/globalindex/offset_global_index_reader.py
@@ -48,6 +48,17 @@ def visit_vector_search(self, vector_search) -> 'Future[Optional[GlobalIndexResu
             self._wrapped.visit_vector_search(
                 vector_search.offset_range(self._offset, self._to)))
 
+    def visit_batch_vector_search(
+            self, batch_vector_search) -> 'Future[List[Optional[GlobalIndexResult]]]':
+        source = self._wrapped.visit_batch_vector_search(
+            batch_vector_search.offset_range(self._offset, self._to))
+
+        def transform(results):
+            return [r.offset(self._offset) if r is not None else None
+                    for r in results]
+
+        return _map_future(source, transform)
+
     def visit_full_text_search(self, full_text_search) -> 'Future[Optional[GlobalIndexResult]]':
         return self._apply_offset_future(
             self._wrapped.visit_full_text_search(full_text_search))
diff --git a/paimon-python/pypaimon/table/source/vector_search_read.py b/paimon-python/pypaimon/table/source/vector_search_read.py
index 965630807172..faca1b71b991 100644
--- a/paimon-python/pypaimon/table/source/vector_search_read.py
+++ b/paimon-python/pypaimon/table/source/vector_search_read.py
@@ -20,6 +20,7 @@
 from abc import ABC, abstractmethod
 from concurrent.futures import wait
 
+from pypaimon.globalindex.batch_vector_search import BatchVectorSearch
 from pypaimon.globalindex.global_index_meta import GlobalIndexIOMeta
 from pypaimon.globalindex.global_index_result import GlobalIndexResult
 from pypaimon.globalindex.offset_global_index_reader import OffsetGlobalIndexReader
@@ -168,12 +169,11 @@ def _raw_pre_filter(self, splits):
         finally:
             scanner.close()
 
-    def _eval(self, row_range_start, row_range_end, vector_index_files,
-              query_vector, include_row_ids):
-        from pypaimon.globalindex.global_index_reader import _completed_future
+    def _open_offset_reader(self, vector_index_files, row_range_start, row_range_end):
+        """Open the split's vector index reader; return ``(reader, offset-wrapped reader)``.
 
-        if not vector_index_files:
-            return _completed_future(None)
+        The caller must close the returned reader once its future completes.
+        """
         index_io_meta_list = []
         for index_file in vector_index_files:
             meta = index_file.global_index_meta
@@ -187,10 +187,21 @@ def _eval(self, row_range_start, row_range_end, vector_index_files,
                 )
             )
 
-        index_type = vector_index_files[0].index_type
-        index_path = self._table.path_factory().global_index_path_factory().index_path()
-        file_io = self._table.file_io
-        options = self._table.table_schema.options
+        reader = _create_vector_reader(
+            vector_index_files[0].index_type,
+            self._table.file_io,
+            self._table.path_factory().global_index_path_factory().index_path(),
+            index_io_meta_list,
+            self._table.table_schema.options,
+        )
+        return reader, OffsetGlobalIndexReader(reader, row_range_start, row_range_end)
+
+    def _eval(self, row_range_start, row_range_end, vector_index_files,
+              query_vector, include_row_ids):
+        from pypaimon.globalindex.global_index_reader import _completed_future
+
+        if not vector_index_files:
+            return _completed_future(None)
 
         vector_search = VectorSearch(
             vector=query_vector,
@@ -201,11 +212,8 @@ def _eval(self, row_range_start, row_range_end, vector_index_files,
         if include_row_ids is not None:
             vector_search = vector_search.with_include_row_ids(include_row_ids)
 
-        reader = _create_vector_reader(
-            index_type, file_io, index_path,
-            index_io_meta_list, options
-        )
-        offset_reader = OffsetGlobalIndexReader(reader, row_range_start, row_range_end)
+        reader, offset_reader = self._open_offset_reader(
+            vector_index_files, row_range_start, row_range_end)
         future = offset_reader.visit_vector_search(vector_search)
         future.add_done_callback(lambda _: reader.close())
         return future
@@ -253,6 +261,28 @@ def _read_raw_search(self, raw_row_ranges, pre_filter, query_vector, index_type=
             scores[row_id] = _compute_score(query_vector, stored_vector, metric)
         return DictBasedScoredIndexResult(scores).top_k(self._limit)
 
+    def _eval_batch(self, row_range_start, row_range_end, vector_index_files,
+                    query_vectors, include_row_ids):
+        from pypaimon.globalindex.global_index_reader import _completed_future
+
+        if not vector_index_files:
+            return _completed_future([None] * len(query_vectors))
+
+        batch_vector_search = BatchVectorSearch(
+            vectors=query_vectors,
+            limit=self._limit,
+            field_name=self._vector_column.name,
+            options=self._options,
+        )
+        if include_row_ids is not None:
+            batch_vector_search = batch_vector_search.with_include_row_ids(include_row_ids)
+
+        reader, offset_reader = self._open_offset_reader(
+            vector_index_files, row_range_start, row_range_end)
+        future = offset_reader.visit_batch_vector_search(batch_vector_search)
+        future.add_done_callback(lambda _: reader.close())
+        return future
+
 
 class VectorSearchReadImpl(AbstractVectorSearchReadImpl, VectorSearchRead):
     """Implementation for VectorSearchRead."""
@@ -329,40 +359,42 @@ def read_batch(self, splits):
         if not index_splits and not raw_splits:
             return [GlobalIndexResult.create_empty() for _ in range(n)]
 
+        # One native batch call per INDEX split (all query vectors at once),
+        # passing that split's pre-filter. Each future returns n per-query results.
         pre_filters = self._pre_filters(index_splits)
-        futures_by_vector = [
-            [
-                self._eval(
-                    split.row_range_start, split.row_range_end,
-                    split.vector_index_files,
-                    vector,
-                    None if not pre_filters else pre_filters[i]
-                )
-                for i, split in enumerate(index_splits)
-            ]
-            for vector in self._query_vectors
+        futures = [
+            self._eval_batch(
+                split.row_range_start, split.row_range_end,
+                split.vector_index_files, self._query_vectors,
+                None if not pre_filters else pre_filters[i],
+            )
+            for i, split in enumerate(index_splits)
         ]
-        for futures in futures_by_vector:
-            wait(futures)
+        wait(futures)
 
-        results = []
+        # Merge each query vector's indexed results across index splits.
+        merged_scores = [{} for _ in range(n)]
+        for future in futures:
+            split_results = future.result()
+            for i in range(n):
+                split_result = split_results[i]
+                if split_result is None:
+                    continue
+                score_getter = split_result.score_getter()
+                for row_id in split_result.results():
+                    if row_id not in merged_scores[i]:
+                        merged_scores[i][row_id] = score_getter(row_id)
+
+        # Each query: merge indexed results with the raw (brute-force) fallback.
         raw_pre_filter = self._raw_pre_filter(raw_splits)
         raw_ranges = _raw_row_ranges(raw_splits)
         raw_index_type = _raw_search_index_type(raw_splits)
 
-        for futures in futures_by_vector:
-            merged_scores = {}
-            for future in futures:
-                split_result = future.result()
-                if split_result is not None:
-                    score_getter = split_result.score_getter()
-                    for row_id in split_result.results():
-                        if row_id not in merged_scores:
-                            merged_scores[row_id] = score_getter(row_id)
-            indexed = DictBasedScoredIndexResult(merged_scores)
-            vector = self._query_vectors[len(results)]
+        results = []
+        for i in range(n):
+            indexed = DictBasedScoredIndexResult(merged_scores[i])
             raw = self._read_raw_search(
-                raw_ranges, raw_pre_filter, vector, raw_index_type)
+                raw_ranges, raw_pre_filter, self._query_vectors[i], raw_index_type)
             results.append(indexed.or_(raw).top_k(self._limit))
 
         return results
diff --git a/paimon-python/pypaimon/tests/lumina_vector_index_test.py b/paimon-python/pypaimon/tests/lumina_vector_index_test.py
index 5a27af20c80a..3c2ae7eec2c1 100644
--- a/paimon-python/pypaimon/tests/lumina_vector_index_test.py
+++ b/paimon-python/pypaimon/tests/lumina_vector_index_test.py
@@ -27,6 +27,7 @@
 lumina_data = pytest.importorskip("lumina_data")
 
 from lumina_data import LuminaBuilder
+from pypaimon.globalindex.batch_vector_search import BatchVectorSearch
 from pypaimon.globalindex.global_index_meta import GlobalIndexIOMeta
 from pypaimon.globalindex.lumina.lumina_index_meta import LuminaIndexMeta
 from pypaimon.globalindex.lumina.lumina_vector_global_index_reader import (
@@ -165,3 +166,91 @@ def test_filtered_search(self):
             reader.close()
         finally:
             shutil.rmtree(tmp_dir, ignore_errors=True)
+
+    def test_batch_matches_single(self):
+        """Native batch search must return, per query, the same result as a single search."""
+        dim, n = 8, 200
+
+        paimon_options = {
+            "lumina.index.dimension": str(dim),
+            "lumina.index.type": "diskann",
+            "lumina.distance.metric": "l2",
+            "lumina.encoding.type": "rawf32",
+            "lumina.diskann.build.ef_construction": "64",
+            "lumina.diskann.build.neighbor_count": "32",
+            "lumina.diskann.build.thread_count": "2",
+        }
+
+        build_options = strip_lumina_options(paimon_options)
+        vectors, ids, raw = _make_vectors(n, dim, seed=123)
+
+        tmp_dir = tempfile.mkdtemp(prefix="paimon_lumina_test_")
+        file_name = "lumina-batch-0.index"
+        index_file = os.path.join(tmp_dir, file_name)
+
+        def scores(result):
+            if result is None:
+                return {}
+            getter = result.score_getter()
+            return {row_id: getter(row_id) for row_id in result.results()}
+
+        query_vectors = [raw[i * dim:(i + 1) * dim] for i in (0, 3, 7, 50, 123)]
+        limit = 5
+
+        try:
+            with LuminaBuilder(build_options) as builder:
+                builder.pretrain(vectors, n, dim)
+                builder.insert(vectors, ids, n, dim)
+                builder.dump(index_file)
+
+            meta = LuminaIndexMeta(build_options)
+            io_meta = GlobalIndexIOMeta(
+                file_name=file_name,
+                file_size=os.path.getsize(index_file),
+                metadata=meta.serialize(),
+            )
+
+            include_ids = RoaringBitmap64()
+            for i in (0, 3, 7, 50, 123, 10, 11, 12):
+                include_ids.add(i)
+
+            def assert_batch_matches_single(reader, batch_results, include_row_ids):
+                for i, query_vector in enumerate(query_vectors):
+                    vs = VectorSearch(
+                        vector=query_vector, limit=limit, field_name="embedding")
+                    if include_row_ids is not None:
+                        vs = vs.with_include_row_ids(include_row_ids)
+                    single = scores(reader.visit_vector_search(vs).result())
+                    batch = scores(batch_results[i])
+                    self.assertEqual(set(batch), set(single))
+                    for row_id in batch:
+                        self.assertAlmostEqual(
+                            batch[row_id], single[row_id], places=5)
+
+            with LuminaVectorGlobalIndexReader(
+                    file_io=_SimpleFileIO(),
+                    index_path=tmp_dir,
+                    io_metas=[io_meta],
+                    options=paimon_options,
+            ) as reader:
+                # Unfiltered: each query's batch slice must equal its single search.
+                unfiltered = reader.visit_batch_vector_search(
+                    BatchVectorSearch(
+                        vectors=query_vectors, limit=limit, field_name="embedding")
+                ).result()
+                assert_batch_matches_single(reader, unfiltered, None)
+
+                # Distinct queries must not collapse to one identical result set,
+                # which would hide a wrong per-query slice.
+                top_sets = {tuple(sorted(scores(r))) for r in unfiltered}
+                self.assertGreater(len(top_sets), 1)
+
+                # Filtered: a single include_row_ids bitmap shared by all queries.
+                filtered = reader.visit_batch_vector_search(
+                    BatchVectorSearch(
+                        vectors=query_vectors, limit=limit, field_name="embedding")
+                    .with_include_row_ids(include_ids)
+                ).result()
+                assert_batch_matches_single(reader, filtered, include_ids)
+        finally:
+            shutil.rmtree(tmp_dir, ignore_errors=True)
diff --git a/paimon-python/pypaimon/tests/vector_search_filter_test.py b/paimon-python/pypaimon/tests/vector_search_filter_test.py
index 0e29e5c1ad04..63f27748feb9 100644
--- a/paimon-python/pypaimon/tests/vector_search_filter_test.py
+++ b/paimon-python/pypaimon/tests/vector_search_filter_test.py
@@ -2126,6 +2126,60 @@ def close(self_inner):
         raw_read.assert_called_once()
         self.assertEqual([1, 8], sorted(list(result.results())))
 
+    def test_batch_merges_raw_search_results(self):
+        from pypaimon.globalindex.global_index_reader import GlobalIndexReader
+        from pypaimon.globalindex.vector_search_result import (
+            DictBasedScoredIndexResult,
+        )
+        from pypaimon.table.source.vector_search_read import (
+            BatchVectorSearchReadImpl,
+        )
+        from pypaimon.table.source.vector_search_split import (
+            IndexVectorSearchSplit,
+            RawVectorSearchSplit,
+        )
+
+        embedding_field = _field(1, "embedding", "FLOAT")
+        entry = _entry(None, field_id=1, index_type="lumina-vector-ann",
+                       file_name="vec.index",
+                       row_range_start=0, row_range_end=4)
+        table = _StubTable(fields=[embedding_field], entries=[entry])
+
+        def _fake_create(index_type, file_io, index_path,
+                         index_io_meta_list, options=None):
+            class _FakeReader(GlobalIndexReader):
+                def visit_vector_search(self_inner, vs):
+                    row_id = int(vs.vector[0])
+                    return _completed_future(
+                        DictBasedScoredIndexResult({row_id: 0.5}))
+
+                def close(self_inner):
+                    pass
+            return _FakeReader()
+
+        split = IndexVectorSearchSplit(
+            row_range_start=0,
+            row_range_end=4,
+            vector_index_files=[entry.index_file],
+        )
+        raw = RawVectorSearchSplit([Range(5, 9)], [], "lumina-vector-ann")
+
+        with mock.patch(
+                "pypaimon.table.source.vector_search_read._create_vector_reader",
+                side_effect=_fake_create):
+            reader = BatchVectorSearchReadImpl(
+                table, limit=5, vector_column=embedding_field,
+                query_vectors=[[1.0], [2.0]], filter_=None)
+            with mock.patch.object(
+                    reader, "_read_raw_search",
+                    return_value=DictBasedScoredIndexResult({8: 0.9})) as raw_read:
+                results = reader.read_batch([split, raw])
+
+        # The raw fallback must be merged into EACH query, not dropped.
+        self.assertEqual(2, raw_read.call_count)
+        self.assertEqual([1, 8], sorted(list(results[0].results())))
+        self.assertEqual([2, 8], sorted(list(results[1].results())))
+
     def test_read_uses_empty_index_prefilter_when_scalar_index_missing(self):
         from pypaimon.table.source.vector_search_read import VectorSearchReadImpl
         from pypaimon.table.source.vector_search_split import IndexVectorSearchSplit
@@ -2341,6 +2395,7 @@ class BatchVectorSearchTest(unittest.TestCase):
     """Batch vector search returns one result per query vector, in input order."""
 
     def test_batch_returns_per_query_results_in_order(self):
+        from pypaimon.globalindex.global_index_reader import GlobalIndexReader
         from pypaimon.globalindex.vector_search_result import (
             DictBasedScoredIndexResult,
         )
@@ -2355,11 +2410,12 @@ def test_batch_returns_per_query_results_in_order(self):
         table = _StubTable(fields=[embedding_field], entries=[entry])
         _patch_snapshot(self, [entry])
 
-        # The fake reader routes each query vector to a distinct row id derived
-        # from the vector itself, so result i must reflect query_vectors[i].
+        # A reader implementing only single search exercises the default batch
+        # fan-out; it routes each query vector to a row id derived from the
+        # vector itself, so result i must reflect query_vectors[i].
         def _fake_create(index_type, file_io, index_path,
                          index_io_meta_list, options=None):
-            class _FakeReader:
+            class _FakeReader(GlobalIndexReader):
                 def visit_vector_search(self_inner, vs):
                     row_id = int(vs.vector[0])
                     return _completed_future(
@@ -2367,12 +2423,6 @@ def visit_vector_search(self_inner, vs):
 
                 def close(self_inner):
                     pass
-
-                def __enter__(self_inner):
-                    return self_inner
-
-                def __exit__(self_inner, *a):
-                    return False
             return _FakeReader()
 
         query_vectors = [[10.0], [20.0], [30.0]]
@@ -2396,6 +2446,63 @@ def __exit__(self_inner, *a):
         self.assertNotEqual(
             list(results[0].results()), list(results[1].results()))
 
+    def test_batch_uses_reader_native_batch_when_available(self):
+        from pypaimon.globalindex.global_index_reader import GlobalIndexReader
+        from pypaimon.globalindex.vector_search_result import (
+            DictBasedScoredIndexResult,
+        )
+        from pypaimon.table.source.batch_vector_search_builder import (
+            BatchVectorSearchBuilderImpl,
+        )
+
+        embedding_field = _field(1, "embedding", "FLOAT")
+        entry = _entry(None, field_id=1, index_type="lumina-vector-ann",
+                       file_name="vec.index",
+                       row_range_start=0, row_range_end=99)
+        table = _StubTable(fields=[embedding_field], entries=[entry])
+        _patch_snapshot(self, [entry])
+
+        calls = {"single": 0, "batch": 0}
+
+        # A reader that implements native batch must be driven through one
+        # batch call per split, not per-vector single calls.
+        def _fake_create(index_type, file_io, index_path,
+                         index_io_meta_list, options=None):
+            class _FakeReader(GlobalIndexReader):
+                def visit_vector_search(self_inner, vs):
+                    calls["single"] += 1
+                    return _completed_future(
+                        DictBasedScoredIndexResult({int(vs.vector[0]): 1.0}))
+
+                def visit_batch_vector_search(self_inner, bvs):
+                    calls["batch"] += 1
+                    return _completed_future([
+                        DictBasedScoredIndexResult({int(bvs.vectors[i][0]): 1.0})
+                        for i in range(bvs.vector_count)
+                    ])
+
+                def close(self_inner):
+                    pass
+            return _FakeReader()
+
+        query_vectors = [[10.0], [20.0], [30.0]]
+        with mock.patch(
+                "pypaimon.table.source.vector_search_read._create_vector_reader",
+                side_effect=_fake_create):
+            results = (
+                BatchVectorSearchBuilderImpl(table)
+                .with_vector_column("embedding")
+                .with_query_vectors(query_vectors)
+                .with_limit(5)
+                .execute_batch_local()
+            )
+
+        self.assertEqual(calls["batch"], 1)
+        self.assertEqual(calls["single"], 0)
+        self.assertEqual(len(results), len(query_vectors))
+        for i, query_vector in enumerate(query_vectors):
+            self.assertTrue(results[i].results().contains(int(query_vector[0])))
+
     def test_batch_empty_splits_returns_empty_per_query(self):
         from pypaimon.table.source.batch_vector_search_builder import (
             BatchVectorSearchBuilderImpl,