diff --git a/docs/generated/core_configuration.html b/docs/generated/core_configuration.html index 9d1a9624633c..a1ef1427f4c9 100644 --- a/docs/generated/core_configuration.html +++ b/docs/generated/core_configuration.html @@ -909,6 +909,24 @@ Integer Level threshold of lookup to generate remote lookup files. Level files below this threshold will not generate remote lookup files. + +
manifest-sort.enabled
+ false + Boolean + Whether to invoke manifest sort rewrite during commit.
Note: enabling this changes the semantics of 'manifest.merge-min-count'. In the sort rewrite path, small manifest files within the rewrite budget are sorted and merged directly, so the minimum-count gate no longer prevents merging a small number of under-budget manifest files when full compaction is not triggered. + + +
manifest-sort.max-rewrite-size
+ 256 mb + MemorySize + Maximum total size of manifest files to rewrite in a single sort rewrite pass. Sections exceeding this limit are skipped. Set to a larger value to allow more aggressive sort rewriting. The cap only limits the sorted rewrite portion and full/minor cleanup may still happen beyond it. + + +
manifest-sort.partition-field
+ (none) + String + Partition field name to sort manifest entries by. Validated by schema validation, if not configured, defaults to the first partition field. +
manifest.compression
"zstd" @@ -939,24 +957,6 @@ Integer To avoid frequent manifest merges, this parameter specifies the minimum number of ManifestFileMeta to merge.
Note: when 'manifest-sort.enabled' is true, this minimum-count gate is only applied to the trailing sub-segment of a section that exceeds 'manifest-sort.max-rewrite-size'. Small under-budget sections are sorted and rewritten directly, so two small manifest files may be merged into one even when their count is below this threshold and full compaction is not triggered. - -
manifest-sort.enabled
- false - Boolean - Whether to invoke manifest sort rewrite during commit.
Note: enabling this changes the semantics of 'manifest.merge-min-count'. In the sort rewrite path, small manifest files within the rewrite budget are sorted and merged directly, so the minimum-count gate no longer prevents merging a small number of under-budget manifest files when full compaction is not triggered. - - -
manifest-sort.partition-field
- (none) - String - Partition field name to sort manifest entries by. Validated by schema validation, if not configured, defaults to the first partition field. - - -
manifest-sort.max-rewrite-size
- 256 mb - MemorySize - Maximum total size of manifest files to rewrite in a single sort rewrite pass. Sections exceeding this limit are skipped. Set to a larger value to allow more aggressive sort rewriting. The cap only limits the sorted rewrite portion and full/minor cleanup may still happen beyond it. -
manifest.target-file-size
8 mb @@ -1662,6 +1662,12 @@ Boolean Whether to process distributed vector search. + +
vector-search.lateral-join.batch-size
+ 256 + Integer + The batch size for lateral vector search. Each batch executes vector topK search and table lookup for multiple query vectors. +
vector.file.format
(none) @@ -1728,12 +1734,6 @@ Boolean If set to true, compactions and snapshot expiration will be skipped. This option is used along with dedicated compact jobs. - -
write.sequence-number-init-mode
- scan -

Enum

- Specify how to initialize the next sequence number for primary key table writers.

Possible values: -
write.batch-memory
128 mb @@ -1746,6 +1746,12 @@ Integer Write batch size for any file format if it supports. + +
write.sequence-number-init-mode
+ scan +

Enum

+ Specify how to initialize the next sequence number for primary key table writers.

Possible values: +
zorder.var-length-contribution
8 diff --git a/paimon-api/src/main/java/org/apache/paimon/CoreOptions.java b/paimon-api/src/main/java/org/apache/paimon/CoreOptions.java index bf85dd66dc85..58cffaf3e187 100644 --- a/paimon-api/src/main/java/org/apache/paimon/CoreOptions.java +++ b/paimon-api/src/main/java/org/apache/paimon/CoreOptions.java @@ -2627,6 +2627,14 @@ public InlineElement getDescription() { .defaultValue(false) .withDescription("Whether to process distributed vector search."); + public static final ConfigOption VECTOR_SEARCH_LATERAL_JOIN_BATCH_SIZE = + key("vector-search.lateral-join.batch-size") + .intType() + .defaultValue(256) + .withDescription( + "The batch size for lateral vector search. Each batch executes vector " + + "topK search and table lookup for multiple query vectors."); + @Immutable public static final ConfigOption PK_CLUSTERING_OVERRIDE = key("pk-clustering-override") @@ -4120,6 +4128,10 @@ public boolean vectorSearchDistributeEnabled() { return options.get(VECTOR_SEARCH_DISTRIBUTE_ENABLED); } + public int vectorSearchLateralJoinBatchSize() { + return options.get(VECTOR_SEARCH_LATERAL_JOIN_BATCH_SIZE); + } + /** Specifies the merge engine for table with primary key. 
*/
     public enum MergeEngine implements DescribedEnum {
         DEDUPLICATE("deduplicate", "De-duplicate and keep the last row."),
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonIncompatibleResolutionRules.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonIncompatibleResolutionRules.scala
index 9824597c1ef3..1fafb7e6a037 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonIncompatibleResolutionRules.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonIncompatibleResolutionRules.scala
@@ -21,6 +21,8 @@ package org.apache.paimon.spark.catalyst.analysis
 import org.apache.paimon.spark.catalyst.plans.logical.{PaimonTableValuedFunctions, PaimonTableValueFunction}
 
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.LateralSubquery
+import org.apache.spark.sql.catalyst.plans.logical.LateralJoin
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 
@@ -32,6 +34,15 @@ case class PaimonIncompatibleResolutionRules(session: SparkSession) extends Rule
     case func: PaimonTableValueFunction if func.args.forall(_.resolved) =>
       PaimonTableValuedFunctions.resolvePaimonTableValuedFunction(session, func)
 
+    // Rewrite a LATERAL join over a dynamic vector_search into a lateral vector search plan.
+    // resolveLateralVectorSearch returns None for every other LATERAL join; those must be left
+    // untouched, so the fallback is the matched join node itself.
+    case join @ LateralJoin(left, lateralSubquery: LateralSubquery, joinType, condition)
+        if left.resolved && lateralSubquery.plan.resolved =>
+      PaimonTableValuedFunctions
+        .resolveLateralVectorSearch(left, lateralSubquery.plan, joinType, condition)
+        .getOrElse(join)
+
   }
 
 }
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/optimizer/PushDownLateralVectorSearchFilter.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/optimizer/PushDownLateralVectorSearchFilter.scala
new file mode 100644
index 
000000000000..678c95fe4b4c --- /dev/null +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/optimizer/PushDownLateralVectorSearchFilter.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark.catalyst.optimizer + +import org.apache.paimon.spark.catalyst.plans.logical.LateralVectorSearch + +import org.apache.spark.sql.catalyst.expressions.{And, PredicateHelper} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule + +/** Pushes filters on the query side below lateral vector search. 
*/
+object PushDownLateralVectorSearchFilter extends Rule[LogicalPlan] with PredicateHelper {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
+    case original @ Filter(condition, search: LateralVectorSearch) =>
+      // Deterministic conjuncts that only read the child's (query side) output.
+      val (querySide, remaining) = splitConjunctivePredicates(condition).partition(
+        p => p.deterministic && p.references.subsetOf(search.child.outputSet))
+      // Of the rest, deterministic conjuncts that only read searchFilterOutputSet attributes.
+      val (searchSide, residual) = remaining.partition(
+        p =>
+          p.deterministic &&
+            p.references.nonEmpty &&
+            p.references.subsetOf(search.searchFilterOutputSet))
+
+      if (querySide.isEmpty && searchSide.isEmpty) {
+        // Nothing can be moved: hand back the very same Filter node.
+        original
+      } else {
+        val withQueryFilter =
+          if (querySide.nonEmpty) {
+            search.copy(left = Filter(buildBalancedPredicate(querySide, And), search.child))
+          } else {
+            search
+          }
+        val rewritten =
+          if (searchSide.nonEmpty) {
+            withQueryFilter.copy(searchFilters = withQueryFilter.searchFilters ++ searchSide)
+          } else {
+            withQueryFilter
+          }
+
+        // Conjuncts that could not be pushed stay in a Filter above the rewritten node.
+        if (residual.nonEmpty) Filter(buildBalancedPredicate(residual, And), rewritten)
+        else rewritten
+      }
+  }
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/PaimonTableValuedFunctions.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/PaimonTableValuedFunctions.scala
index bf3e53bf2f1d..17a4f15e0c61 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/PaimonTableValuedFunctions.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/PaimonTableValuedFunctions.scala
@@ -21,18 +21,20 @@ package org.apache.paimon.spark.catalyst.plans.logical
 import org.apache.paimon.CoreOptions
 import org.apache.paimon.globalindex.HybridSearchRanker
 import 
org.apache.paimon.predicate.{FullTextQuery, FullTextSearch, HybridSearch, HybridSearchRoute, VectorSearch} -import org.apache.paimon.spark.SparkTable +import org.apache.paimon.spark.{SparkTable, SparkTypeUtils} import org.apache.paimon.spark.catalyst.plans.logical.PaimonTableValuedFunctions._ +import org.apache.paimon.spark.schema.PaimonMetadataColumn import org.apache.paimon.table.{DataTable, FullTextSearchTable, HybridSearchTable, InnerTable, VectorSearchTable} import org.apache.paimon.table.source.snapshot.TimeTravelUtil.InconsistentTagBucketException -import org.apache.spark.sql.PaimonUtils.createDataset +import org.apache.spark.sql.PaimonUtils.{createDataset, toAttributes} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistryBase import org.apache.spark.sql.catalyst.analysis.TableFunctionRegistry.TableFunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateArray, CreateMap, CreateNamedStruct, Expression, ExpressionInfo, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, CreateArray, CreateMap, CreateNamedStruct, Expression, ExpressionInfo, Literal, NamedExpression, OuterReference} +import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LogicalPlan, Project, SubqueryAlias, UnaryNode} import org.apache.spark.sql.catalyst.util.MapData import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation @@ -153,6 +155,9 @@ object PaimonTableValuedFunctions { argsWithoutTable: Seq[Expression]): LogicalPlan = { sparkTable match { case st @ SparkTable(innerTable: InnerTable) => + if (vsq.hasOuterReference(argsWithoutTable)) { + return 
vsq.createDynamicVectorSearch(innerTable, argsWithoutTable) + } val vectorSearch = vsq.createVectorSearch(innerTable, argsWithoutTable) val vectorSearchTable = VectorSearchTable.create(innerTable, vectorSearch) DataSourceV2Relation.create( @@ -189,6 +194,57 @@ object PaimonTableValuedFunctions { } } + def resolveLateralVectorSearch( + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + condition: Option[Expression]): Option[LogicalPlan] = { + extractDynamicVectorSearch(right) match { + case None => + None + case Some(_) if joinType != Inner => + throw new RuntimeException( + s"LATERAL vector_search only supports INNER join, but got: $joinType.") + case Some((relation, projectList, projectOutput)) => + val vectorSearchOutput = vectorSearchOutputForProject(relation, projectList) + val lateralVectorSearch = + LateralVectorSearch( + left, + relation.innerTable, + relation.columnName, + relation.queryVectorExpr, + relation.limit, + relation.options, + vectorSearchOutput, + projectList, + projectOutput) + Some(condition.map(Filter(_, lateralVectorSearch)).getOrElse(lateralVectorSearch)) + } + } + + private def vectorSearchOutputForProject( + relation: DynamicVectorSearchRelation, + projectList: Seq[NamedExpression]): Seq[Attribute] = { + val projectReferences = AttributeSet.fromAttributeSets(projectList.map(_.references)) + relation.output.filter(projectReferences.contains) + } + + private def extractDynamicVectorSearch(plan: LogicalPlan) + : Option[(DynamicVectorSearchRelation, Seq[NamedExpression], Seq[Attribute])] = { + plan match { + case SubqueryAlias(_, child) => + extractDynamicVectorSearch(child).map { + case (relation, projectList, _) => (relation, projectList, plan.output) + } + case Project(projectList, relation: DynamicVectorSearchRelation) + if projectList.forall(_.resolved) => + Some((relation, projectList, plan.output)) + case relation: DynamicVectorSearchRelation => + Some((relation, relation.output, relation.output)) + case _ => None + } + } 
+
   private def resolveFullTextSearchQuery(
       sparkTable: Table,
       sparkCatalog: TableCatalog,
@@ -447,6 +503,65 @@ case class VectorSearchQuery(override val args: Seq[Expression])
     }
     value.toString
   }
+
+  /**
+   * Whether the query-vector argument (second argument after the table name) depends on an
+   * enclosing query, i.e. it references attributes or contains an [[OuterReference]].
+   *
+   * Only the three- and four-argument forms qualify. The arity is checked before the query
+   * vector is read, so a shorter argument list yields `false` instead of failing with an
+   * IndexOutOfBoundsException.
+   */
+  def hasOuterReference(argsWithoutTable: Seq[Expression]): Boolean = {
+    val hasSupportedArity = argsWithoutTable.size == 3 || argsWithoutTable.size == 4
+    hasSupportedArity && {
+      val queryVector = argsWithoutTable(1)
+      queryVector.references.nonEmpty || containsOuterReference(queryVector)
+    }
+  }
+
+  /** Recursively checks whether `expr` is, or contains, an [[OuterReference]]. */
+  private def containsOuterReference(expr: Expression): Boolean = {
+    expr.isInstanceOf[OuterReference] || expr.children.exists(containsOuterReference)
+  }
+
+  /**
+   * Builds a [[DynamicVectorSearchRelation]] from the arguments that follow the table name.
+   *
+   * Expects `column_name, query_vector, limit[, options]`. The column name and the limit are
+   * evaluated here; the query vector expression is passed on as is.
+   */
+  def createDynamicVectorSearch(
+      innerTable: InnerTable,
+      argsWithoutTable: Seq[Expression]): DynamicVectorSearchRelation = {
+    if (argsWithoutTable.size != 3 && argsWithoutTable.size != 4) {
+      throw new RuntimeException(
+        s"$VECTOR_SEARCH needs three or four parameters after table_name: " +
+          s"column_name, query_vector, limit[, options]. " +
+          s"Got ${argsWithoutTable.size} parameters after table_name." 
+ ) + } + val columnName = argsWithoutTable.head.eval().toString + if (!innerTable.rowType().containsField(columnName)) { + throw new RuntimeException( + s"Column $columnName does not exist in table ${innerTable.name()}" + ) + } + val limit = parsePositiveLimit(argsWithoutTable(2).eval()) + val options: Map[String, String] = + if (argsWithoutTable.size == 4) { + extractOptions(argsWithoutTable(3)) + } else { + Map.empty[String, String] + } + DynamicVectorSearchRelation( + innerTable, + columnName, + argsWithoutTable(1), + limit, + options, + toAttributes(SparkTypeUtils.fromPaimonRowType(innerTable.rowType()))) + } } /** @@ -634,6 +732,61 @@ case class HybridSearchQuery(override val args: Seq[Expression]) } +case class DynamicVectorSearchRelation( + innerTable: InnerTable, + columnName: String, + queryVectorExpr: Expression, + limit: Int, + options: Map[String, String], + relationOutput: Seq[Attribute]) + extends LeafNode { + + private lazy val outputWithScore: Seq[Attribute] = + relationOutput ++ + Seq(PaimonMetadataColumn.SEARCH_SCORE.toAttribute) + + override def output: Seq[Attribute] = outputWithScore +} + +case class LateralVectorSearch( + left: LogicalPlan, + innerTable: InnerTable, + columnName: String, + queryVectorExpr: Expression, + limit: Int, + options: Map[String, String], + vectorSearchOutput: Seq[Attribute], + projectList: Seq[NamedExpression], + projectOutput: Seq[Attribute], + searchFilters: Seq[Expression] = Nil) + extends UnaryNode { + + override def child: LogicalPlan = left + + override def output: Seq[Attribute] = left.output ++ projectOutput + + lazy val searchFilterOutputSet: AttributeSet = { + val tableOutputSet = AttributeSet( + vectorSearchOutput.filterNot(_.name == PaimonMetadataColumn.SEARCH_SCORE_COLUMN)) + AttributeSet(projectList.zip(projectOutput).collect { + case (expr, attr) if expr.references.nonEmpty && expr.references.subsetOf(tableOutputSet) => + attr + }) + } + + override lazy val producedAttributes: AttributeSet = { + 
AttributeSet(vectorSearchOutput ++ output.filterNot(attr => inputSet.contains(attr))) + } + + override lazy val references: AttributeSet = { + AttributeSet.fromAttributeSets(expressions.map(_.references)) -- producedAttributes + } + + override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = { + copy(left = newChild) + } +} + /** * Plan for the [[FULL_TEXT_SEARCH]] table-valued function. * diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/PaimonStrategy.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/PaimonStrategy.scala index 321e61f2cb59..ce1a28df50c7 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/PaimonStrategy.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/PaimonStrategy.scala @@ -18,18 +18,31 @@ package org.apache.paimon.spark.execution +import org.apache.paimon.CoreOptions +import org.apache.paimon.globalindex.{GlobalIndexResult, ScoredGlobalIndexResult} import org.apache.paimon.partition.PartitionPredicate -import org.apache.paimon.spark.{SparkCatalog, SparkGenericCatalog, SparkTable, SparkUtils} +import org.apache.paimon.partition.PartitionPredicate.splitPartitionPredicatesAndDataPredicates +import org.apache.paimon.predicate.{Predicate, PredicateBuilder} +import org.apache.paimon.spark.{PaimonRecordReaderIterator, SparkCatalog, SparkGenericCatalog, SparkTable, SparkUtils, SparkV2FilterConverter} import org.apache.paimon.spark.catalog.{SparkBaseCatalog, SupportView} import org.apache.paimon.spark.catalyst.analysis.ResolvedPaimonView -import org.apache.paimon.spark.catalyst.plans.logical.{CopyIntoLocationCommand, CopyIntoLocationSource, CopyIntoTableCommand, CreateOrReplaceTagCommand, CreatePaimonView, DeleteTagCommand, DropPaimonView, PaimonCallCommand, PaimonDropPartitions, RenameTagCommand, ResolvedIdentifier, ShowPaimonViews, ShowTagsCommand, 
TruncatePaimonTableWithFilter} -import org.apache.paimon.table.Table +import org.apache.paimon.spark.catalyst.plans.logical.{CopyIntoLocationCommand, CopyIntoLocationSource, CopyIntoTableCommand, CreateOrReplaceTagCommand, CreatePaimonView, DeleteTagCommand, DropPaimonView, LateralVectorSearch, PaimonCallCommand, PaimonDropPartitions, RenameTagCommand, ResolvedIdentifier, ShowPaimonViews, ShowTagsCommand, TruncatePaimonTableWithFilter} +import org.apache.paimon.spark.data.SparkInternalRow +import org.apache.paimon.spark.schema.PaimonMetadataColumn +import org.apache.paimon.table.{InnerTable, SpecialFields, Table} +import org.apache.paimon.table.source.{BatchVectorSearchBuilder, InnerTableScan, ReadBuilder, VectorScan} +import org.apache.paimon.types.RowType +import org.apache.paimon.utils.RoaringNavigableMap64 +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.PaimonUtils.{normalizeExprs, translateFilterV2} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{ResolvedNamespace, ResolvedTable} -import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow, PredicateHelper} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Expression, GenericInternalRow, JoinedRow, NamedExpression, OuterReference, PredicateHelper, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, DescribeRelation, LogicalPlan, ReplaceTable, ReplaceTableAsSelect, ShowCreateTable} +import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.connector.catalog.{Identifier, PaimonLookupCatalog, TableCatalog} import org.apache.spark.sql.execution.{PaimonDescribeTableExec, SparkPlan, SparkStrategy} import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Implicits, DataSourceV2Relation} @@ -37,6 +50,7 @@ import 
org.apache.spark.sql.execution.shim.{PaimonCreateTableAsSelectStrategy, P
 import org.apache.spark.sql.paimon.shims.SparkShimLoader
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
 
 case class PaimonStrategy(spark: SparkSession)
   extends SparkStrategy
@@ -61,6 +75,20 @@ case class PaimonStrategy(spark: SparkSession)
       val input = buildInternalRow(args)
       PaimonCallExec(c.output, procedure, input) :: Nil
 
+    case lvs: LateralVectorSearch =>
+      LateralVectorSearchExec(
+        lvs.innerTable,
+        lvs.columnName,
+        lvs.queryVectorExpr,
+        lvs.limit,
+        lvs.options,
+        lvs.vectorSearchOutput,
+        lvs.projectList,
+        lvs.projectOutput,
+        lvs.searchFilters,
+        planLater(lvs.left)
+      ) :: Nil
+
     case t @ ShowTagsCommand(PaimonCatalogAndIdentifier(catalog, ident)) =>
       ShowTagsExec(catalog, ident, t.output) :: Nil
 
@@ -215,3 +243,326 @@ case class PaimonStrategy(spark: SparkSession)
     SparkShimLoader.shim.classicApi.recacheByPlan(spark, v2Relation)
   }
 }
+
+case class LateralVectorSearchExec(
+    innerTable: InnerTable,
+    columnName: String,
+    queryVectorExpr: Expression,
+    limit: Int,
+    options: Map[String, String],
+    vectorSearchOutput: Seq[Attribute],
+    projectList: Seq[NamedExpression],
+    projectOutput: Seq[Attribute],
+    searchFilters: Seq[Expression],
+    child: SparkPlan)
+  extends SparkPlan
+  with PredicateHelper {
+
+  override def children: Seq[SparkPlan] = Seq(child)
+
+  override def output: Seq[Attribute] = child.output ++ projectOutput
+
+  @transient override lazy val producedAttributes: AttributeSet = {
+    AttributeSet(vectorSearchOutput ++ output.filterNot(attr => inputSet.contains(attr)))
+  }
+
+  @transient
+  override lazy val references: AttributeSet = {
+    AttributeSet.fromAttributeSets(expressions.map(_.references)) -- producedAttributes
+  }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan = {
+    copy(child = newChildren.head)
+  }
+
+  /**
+   * Runs the vector search once per batch of outer rows and emits one joined row per match.
+   *
+   * Outer rows whose query vector evaluates to null produce no output. Result rows are
+   * materialized through an [[UnsafeProjection]] over `output`, so the operator emits
+   * `UnsafeRow`s instead of [[JoinedRow]] wrappers.
+   */
+  override protected def doExecute(): RDD[InternalRow] = {
+    child.execute().mapPartitions {
+      outerRows =>
+        // Replace OuterReference wrappers by the attributes they wrap so that the query vector
+        // can be evaluated directly against the child's rows.
+        val strippedQueryExpr = queryVectorExpr.transform {
+          case OuterReference(namedExpression) => namedExpression.toAttribute
+        }
+        val queryVectorProjection = UnsafeProjection.create(Seq(strippedQueryExpr), child.output)
+        val rightProjection = UnsafeProjection.create(projectList, vectorSearchOutput)
+        val joinedRow = new JoinedRow
+        val resultProjection = UnsafeProjection.create(output, output)
+        val searchContext = createSearchContext(rightProjection)
+        val batchSize = searchContext.batchSize
+
+        outerRows.map(_.copy()).grouped(batchSize).flatMap {
+          outerRowBatch =>
+            val searchBatch = ArrayBuffer[LateralVectorSearchQuery]()
+            outerRowBatch.foreach {
+              outerRow =>
+                toFloatArray(queryVectorProjection(outerRow).get(0, strippedQueryExpr.dataType))
+                  .foreach(
+                    queryVector => searchBatch += LateralVectorSearchQuery(outerRow, queryVector))
+            }
+
+            if (searchBatch.isEmpty) {
+              Iterator.empty
+            } else {
+              search(searchBatch.toVector, searchContext).map {
+                case (outerRow, rightRow) =>
+                  // The projection reuses its result row between calls, hence the copy.
+                  resultProjection(joinedRow(outerRow, rightRow)).copy()
+              }
+            }
+        }
+    }
+  }
+
+  private def createSearchContext(rightProjection: UnsafeProjection): LateralVectorSearchContext = {
+    val rowType = innerTable.rowType()
+    val readFieldNames = vectorSearchOutput
+      .filterNot(_.name == PaimonMetadataColumn.SEARCH_SCORE_COLUMN)
+      .map(_.name)
+    val readFieldNamesWithRowId =
+      if (readFieldNames.contains(SpecialFields.ROW_ID.name())) {
+        readFieldNames
+      } else {
+        readFieldNames :+ SpecialFields.ROW_ID.name()
+      }
+    val rowTypeWithRowId = SpecialFields.rowTypeWithRowId(rowType)
+    val readRowType = rowType.project(readFieldNames.asJava)
+    val readRowTypeWithRowId = SpecialFields.rowTypeWithRowId(readRowType)
+    val readBuilder = innerTable
+      .newReadBuilder()
+      .withReadType(rowTypeWithRowId.project(readFieldNamesWithRowId.asJava))
+    val scoreMetadataColumns =
+      if (vectorSearchOutput.exists(_.name == PaimonMetadataColumn.SEARCH_SCORE_COLUMN)) {
+        Seq(PaimonMetadataColumn.SEARCH_SCORE)
+      } else {
+        Seq.empty
+      }
+    val 
resultRowType = + if (scoreMetadataColumns.isEmpty) { + readRowTypeWithRowId + } else { + new RowType( + (readRowTypeWithRowId.getFields.asScala ++ scoreMetadataColumns.map( + _.toPaimonDataField)).asJava) + } + val sparkRow = SparkInternalRow.create(resultRowType) + val vectorSearchBuilder = innerTable + .newBatchVectorSearchBuilder() + .withVectorColumn(columnName) + .withLimit(limit) + .withOptions(options.asJava) + pushSearchFilters(readBuilder, vectorSearchBuilder) + + val vectorPlan = vectorSearchBuilder.newVectorScan().scan() + val batchSize = + Math.max(1, new CoreOptions(innerTable.options()).vectorSearchLateralJoinBatchSize()) + + LateralVectorSearchContext( + readBuilder, + vectorSearchBuilder, + vectorPlan, + scoreMetadataColumns, + sparkRow, + rowIdOrdinal = resultRowType.getFieldIndex(SpecialFields.ROW_ID.name()), + projectionInputOrdinals = vectorSearchOutput.map { + attr => + if (attr.name == PaimonMetadataColumn.SEARCH_SCORE_COLUMN) { + -1 + } else { + resultRowType.getFieldIndex(attr.name) + } + }, + rightProjection, + batchSize + ) + } + + private def pushSearchFilters( + readBuilder: ReadBuilder, + vectorSearchBuilder: BatchVectorSearchBuilder): Unit = { + val predicates = convertSearchFilters() + if (predicates.nonEmpty) { + val split = splitPartitionPredicatesAndDataPredicates( + predicates.asJava, + innerTable.rowType(), + innerTable.partitionKeys()) + if (split.getLeft.isPresent) { + val partitionFilter = split.getLeft.get() + readBuilder.withPartitionFilter(partitionFilter) + vectorSearchBuilder.withPartitionFilter(partitionFilter) + } + if (!split.getRight.isEmpty) { + val dataFilter = PredicateBuilder.and(split.getRight) + readBuilder.withFilter(dataFilter) + vectorSearchBuilder.withFilter(dataFilter) + } + } + } + + private def convertSearchFilters(): Seq[Predicate] = { + if (searchFilters.isEmpty) { + Seq.empty + } else { + val converter = SparkV2FilterConverter(innerTable.rowType()) + 
normalizeExprs(searchFilters.map(rewriteSearchFilter), vectorSearchOutput) + .flatMap(splitConjunctivePredicates) + .map { + filter => + val sparkPredicate = translateFilterV2(filter).getOrElse { + throw new UnsupportedOperationException( + s"Cannot push down searched-table predicate for LATERAL vector_search: $filter") + } + converter.convert(sparkPredicate, ignoreFailure = false).getOrElse { + throw new UnsupportedOperationException( + s"Cannot convert searched-table predicate for LATERAL vector_search: $filter") + } + } + } + } + + private def rewriteSearchFilter(filter: Expression): Expression = { + val projectionByExprId = projectList + .zip(projectOutput) + .map { case (project, outputAttr) => outputAttr.exprId -> stripAlias(project) } + .toMap + filter.transform { case attr: Attribute => projectionByExprId.getOrElse(attr.exprId, attr) } + } + + private def stripAlias(expression: Expression): Expression = { + expression match { + case Alias(child, _) => child + case other => other + } + } + + private def search( + queries: Seq[LateralVectorSearchQuery], + context: LateralVectorSearchContext): Iterator[(InternalRow, InternalRow)] = { + val vectors = queries.map(_.queryVector).toArray + val globalIndexResults = context.vectorSearchBuilder + .withVectors(vectors) + .newBatchVectorRead() + .readBatch(context.vectorPlan) + .asScala + .toVector + val rowIdToMatches = createRowIdToMatches(queries, globalIndexResults) + val batchGlobalIndexResult = createBatchGlobalIndexResult(globalIndexResults) + val scan = context.readBuilder + .newScan() + .withGlobalIndexResult(batchGlobalIndexResult) + .asInstanceOf[InnerTableScan] + val read = context.readBuilder.newRead() + + scan.plan().splits().asScala.iterator.flatMap { + split => + val reader = + PaimonRecordReaderIterator(read.createReader(split), context.scoreMetadataColumns, split) + new Iterator[Iterator[(InternalRow, InternalRow)]] { + private var closed = false + + 
Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => closeOnce())) + + private def closeOnce(): Unit = { + if (!closed) { + closed = true + reader.close() + } + } + + override def hasNext: Boolean = { + val hasNext = reader.hasNext + if (!hasNext) { + closeOnce() + } + hasNext + } + + override def next(): Iterator[(InternalRow, InternalRow)] = { + val rightRow = context.sparkRow.replace(reader.next()) + val rowId = rightRow.getLong(context.rowIdOrdinal) + rowIdToMatches.getOrElse(rowId, Seq.empty).iterator.map { + searchMatch => + val projectedRow = projectRightRow(rightRow, searchMatch, context) + (searchMatch.outerRow, projectedRow) + } + } + }.flatMap(identity) + } + } + + private def projectRightRow( + rightRow: InternalRow, + searchMatch: LateralVectorSearchMatch, + context: LateralVectorSearchContext): InternalRow = { + val values = new Array[Any](vectorSearchOutput.size) + vectorSearchOutput.zipWithIndex.foreach { + case (attr, index) => + val ordinal = context.projectionInputOrdinals(index) + values(index) = if (ordinal < 0) { + searchMatch.score + } else { + rightRow.get(ordinal, attr.dataType) + } + } + context.rightProjection(new GenericInternalRow(values)) + } + + private def createRowIdToMatches( + queries: Seq[LateralVectorSearchQuery], + globalIndexResults: Seq[GlobalIndexResult]): Map[Long, Seq[LateralVectorSearchMatch]] = { + val rowIdToMatches = + scala.collection.mutable.LinkedHashMap[Long, ArrayBuffer[LateralVectorSearchMatch]]() + queries.zip(globalIndexResults).foreach { + case (query, result) => + val scoreGetter = result match { + case scored: ScoredGlobalIndexResult => Some(scored.scoreGetter()) + case _ => None + } + result.results().iterator().asScala.foreach { + rowId => + rowIdToMatches.getOrElseUpdate(rowId, ArrayBuffer()) += + LateralVectorSearchMatch( + query.outerRow, + scoreGetter.map(_.score(rowId)).getOrElse(Float.NaN)) + } + } + rowIdToMatches.iterator.map { case (rowId, matches) => rowId -> matches.toSeq 
}.toMap
+  }
+
+  private def createBatchGlobalIndexResult(
+      globalIndexResults: Seq[GlobalIndexResult]): GlobalIndexResult = {
+    val rowIds = new RoaringNavigableMap64() // accumulates the OR of every batch result
+    globalIndexResults.foreach(result => rowIds.or(result.results()))
+    GlobalIndexResult.create(rowIds) // wrap the combined row ids as a single result
+  }
+
+  private def toFloatArray(value: Any): Option[Array[Float]] = {
+    value match {
+      case null => None // a null value yields no vector instead of an error
+      case arrayData: ArrayData => Some(arrayData.toFloatArray())
+      case _ => // every other runtime type is rejected
+        throw new RuntimeException(s"Cannot extract query vector from expression value: $value")
+    }
+  }
+
+  private case class LateralVectorSearchContext( // state bundle; consumers not in view
+      readBuilder: ReadBuilder,
+      vectorSearchBuilder: BatchVectorSearchBuilder,
+      vectorPlan: VectorScan.Plan,
+      scoreMetadataColumns: Seq[PaimonMetadataColumn],
+      sparkRow: SparkInternalRow,
+      rowIdOrdinal: Int,
+      projectionInputOrdinals: Seq[Int],
+      rightProjection: UnsafeProjection,
+      batchSize: Int)
+
+  private case class LateralVectorSearchQuery(outerRow: InternalRow, queryVector: Array[Float])
+
+  private case class LateralVectorSearchMatch(outerRow: InternalRow, score: Float)
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/extensions/PaimonSparkSessionExtensions.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/extensions/PaimonSparkSessionExtensions.scala
index e433a5f7d49f..61481e201c0e 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/extensions/PaimonSparkSessionExtensions.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/extensions/PaimonSparkSessionExtensions.scala
@@ -19,7 +19,7 @@
 package org.apache.paimon.spark.extensions
 
 import org.apache.paimon.spark.catalyst.analysis.{PaimonAnalysis, PaimonDeleteTable, PaimonFunctionResolver, PaimonIncompatibleResolutionRules, PaimonMergeInto, PaimonPostHocResolutionRules, PaimonProcedureResolver, PaimonUpdateTable, PaimonViewResolver, ReplacePaimonFunctions, RewriteUpsertTable}
-import org.apache.paimon.spark.catalyst.optimizer.{MergePaimonScalarSubqueries, OptimizeMetadataOnlyDeleteFromPaimonTable}
+import org.apache.paimon.spark.catalyst.optimizer.{MergePaimonScalarSubqueries, OptimizeMetadataOnlyDeleteFromPaimonTable, PushDownLateralVectorSearchFilter}
 import org.apache.paimon.spark.catalyst.plans.logical.PaimonTableValuedFunctions
 import org.apache.paimon.spark.commands.BucketExpression
 import org.apache.paimon.spark.execution.{OldCompatibleStrategy, PaimonStrategy}
@@ -102,6 +102,7 @@ class PaimonSparkSessionExtensions extends (SparkSessionExtensions => Unit) {
     extensions.injectOptimizerRule(spark => ReplacePaimonFunctions(spark))
     extensions.injectOptimizerRule(_ => OptimizeMetadataOnlyDeleteFromPaimonTable)
     extensions.injectOptimizerRule(_ => MergePaimonScalarSubqueries)
+    extensions.injectOptimizerRule(_ => PushDownLateralVectorSearchFilter)
 
     // planner extensions
     extensions.injectPlannerStrategy(spark => PaimonStrategy(spark))
diff --git a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkMultimodalITCase.java b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkMultimodalITCase.java
index af3401ea2afa..3a4dd41b1685 100644
--- a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkMultimodalITCase.java
+++ b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkMultimodalITCase.java
@@ -170,6 +170,29 @@ public void testVector(@TempDir java.nio.file.Path tempDir) throws IOException {
                         .collect(Collectors.toList()));
         spark.close();
 
+        spark = builder.getOrCreate(); // session for the LATERAL vector_search check
+        spark.sql("SET `spark.paimon.vector-search.distribute.enabled`=`false`");
+        rows =
+                spark.sql(
+                                "SELECT q.gid AS query_gid, q.embs AS query_embs, r.gid AS result_gid FROM my_db1.vector_test AS q, LATERAL (SELECT gid FROM vector_search('my_db1.vector_test', 'embs', q.embs, 5)) AS r WHERE q.`date` = '20260420';")
+                        .collectAsList();
+        assertThat(rows).hasSize(40); // 8 query gids x 5 matches each, as asserted below
+        assertThat(
+                        rows.stream()
+                                .collect(
+                                        Collectors.groupingBy(
+                                                row -> row.getLong(0), Collectors.counting())))
+                .hasSize(8)
+                .containsEntry(1L, 5L)
+                .containsEntry(2L, 5L)
+                .containsEntry(3L, 5L)
+                .containsEntry(4L, 5L)
+                .containsEntry(5L, 5L)
+                .containsEntry(6L, 5L)
+                .containsEntry(7L, 5L)
+                .containsEntry(8L, 5L);
+        spark.close();
+
         spark = builder.getOrCreate();
         spark.sql("DROP TABLE IF EXISTS `my_db1`.`vector_test`;");
         spark.close();
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/TableValuedFunctionsTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/TableValuedFunctionsTest.scala
index 8114766a9911..0e60feaeb507 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/TableValuedFunctionsTest.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/TableValuedFunctionsTest.scala
@@ -21,10 +21,11 @@ package org.apache.paimon.spark.sql
 import org.apache.paimon.data.{BinaryString, GenericRow, Timestamp}
 import org.apache.paimon.manifest.ManifestCommittable
 import org.apache.paimon.spark.PaimonHiveTestBase
-import org.apache.paimon.spark.catalyst.plans.logical.PaimonTableValuedFunctions
+import org.apache.paimon.spark.catalyst.plans.logical.{LateralVectorSearch, PaimonTableValuedFunctions}
 import org.apache.paimon.utils.DateTimeUtils
 
 import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.plans.logical.Filter
 
 import java.time.LocalDateTime
 import java.util.Collections
@@ -41,6 +42,123 @@ class TableValuedFunctionsTest extends PaimonHiveTestBase {
     assert(error.getMessage.contains("Limit must be no greater than"))
   }
 
+  test("lateral vector search preserves subquery alias qualifiers") { // plan-shape checks only
+    withTable("vector_search_source", "vector_search_result") {
+      spark.sql("""
+                  |CREATE TABLE vector_search_source (gid BIGINT, embs ARRAY<FLOAT>, dt STRING)
+                  |USING paimon
+                  |TBLPROPERTIES (
+                  | 'vector.file.format' = 'lance',
+                  | 'vector-field' = 'embs',
+                  | 'field.embs.vector-dim' = '3',
+                  | 'row-tracking.enabled' = 'true',
+                  | 'data-evolution.enabled' = 'true')
+                  |PARTITIONED BY (dt)
+                  |""".stripMargin)
+      spark.sql("""
+                  |CREATE TABLE vector_search_result (
+                  | query_gid BIGINT,
+                  | query_embs ARRAY<FLOAT>,
+                  | result_gid BIGINT,
+                  | result_embs ARRAY<FLOAT>,
+                  | score FLOAT,
+                  | dt STRING)
+                  |USING paimon
+                  |PARTITIONED BY (dt)
+                  |""".stripMargin)
+
+      val insertOptimizedPlan = spark // case 1: score projected, outer filter on q.dt
+        .sql("""
+               |SELECT q.gid AS query_gid, q.embs AS query_embs,
+               | r.gid AS result_gid, r.embs AS result_embs,
+               | r.__paimon_search_score AS score
+               |FROM vector_search_source AS q,
+               |LATERAL (
+               | SELECT gid, embs, __paimon_search_score
+               | FROM vector_search('vector_search_source', 'embs', q.embs, 5)
+               |) AS r
+               |WHERE q.dt = '20260608'
+               |""".stripMargin)
+        .queryExecution
+        .optimizedPlan
+      val lateralVectorSearches = insertOptimizedPlan.collect {
+        case lvs: LateralVectorSearch => lvs
+      }
+      assert(lateralVectorSearches.size == 1, insertOptimizedPlan.toString)
+
+      val optimizedPlanWithoutScore = spark // case 2: score column not selected
+        .sql("""
+               |SELECT q.gid AS query_gid, r.embs AS result_embs
+               |FROM vector_search_source AS q,
+               |LATERAL (
+               | SELECT embs
+               | FROM vector_search('vector_search_source', 'embs', q.embs, 5)
+               |) AS r
+               |""".stripMargin)
+        .queryExecution
+        .optimizedPlan
+      assert(
+        optimizedPlanWithoutScore.exists(_.isInstanceOf[LateralVectorSearch]),
+        optimizedPlanWithoutScore.toString)
+
+      val analyzedPlanWithJoinCondition = spark // case 3: WHERE uses outer and search columns
+        .sql("""
+               |SELECT q.gid AS query_gid, r.result_gid, r.score
+               |FROM vector_search_source AS q,
+               |LATERAL (
+               | SELECT gid AS result_gid, __paimon_search_score AS score
+               | FROM vector_search('vector_search_source', 'embs', q.embs, 5)
+               |) AS r
+               |WHERE q.gid = r.result_gid AND r.score >= 0.0
+               |""".stripMargin)
+        .queryExecution
+        .analyzed
+      val lateralVectorSearchFilters = analyzedPlanWithJoinCondition.collect {
+        case filter @ Filter(_, _: LateralVectorSearch) => filter
+      }
+      assert(lateralVectorSearchFilters.size == 1, analyzedPlanWithJoinCondition.toString)
+      assert( // the filter may only reference attributes its child actually outputs
+        lateralVectorSearchFilters.head.condition.references
+          .subsetOf(lateralVectorSearchFilters.head.child.outputSet),
+        analyzedPlanWithJoinCondition.toString
+      )
+
+      val optimizedPlanWithSearchFilter = spark // case 4: predicate on a search-side column
+        .sql("""
+               |SELECT q.gid AS query_gid, r.result_gid, r.dt
+               |FROM vector_search_source AS q,
+               |LATERAL (
+               | SELECT gid AS result_gid, dt
+               | FROM vector_search('vector_search_source', 'embs', q.embs, 5)
+               |) AS r
+               |WHERE r.dt = '20260608'
+               |""".stripMargin)
+        .queryExecution
+        .optimizedPlan
+      val lateralVectorSearchesWithSearchFilter = optimizedPlanWithSearchFilter.collect {
+        case lvs: LateralVectorSearch => lvs
+      }
+      assert(
+        lateralVectorSearchesWithSearchFilter.size == 1,
+        optimizedPlanWithSearchFilter.toString)
+      assert(
+        lateralVectorSearchesWithSearchFilter.head.searchFilters.nonEmpty,
+        optimizedPlanWithSearchFilter.toString)
+
+      val constantVectorPlan = spark // case 5: literal query vector, no lateral join
+        .sql("""
+               |SELECT gid
+               |FROM vector_search(
+               | 'vector_search_source', 'embs', array(1.0f, 2.0f, 3.0f), 5)
+               |""".stripMargin)
+        .queryExecution
+        .optimizedPlan
+      assert(
+        !constantVectorPlan.exists(_.isInstanceOf[LateralVectorSearch]),
+        constantVectorPlan.toString)
+    }
+  }
+
   withPk.foreach { hasPk =>
     bucketModes.foreach {