diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
index 55d72013d66c..6009b6a80761 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
@@ -43,7 +43,7 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
   import DataSourceV2Implicits._
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsDown {
-    case a @ PaimonV2WriteCommand(table) if !paimonWriteResolved(a.query, table) =>
+    case a @ PaimonV2WriteCommand(table) if !paimonWriteResolved(a.query, table, a.isByName) =>
       val mergeSchemaEnabled =
         writeOptions(a).get(SparkConnectorOptions.MERGE_SCHEMA.key()).contains("true") ||
           OptionUtils.writeMergeSchemaEnabled()
@@ -77,13 +77,16 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
     }
   }
 
-  private def paimonWriteResolved(query: LogicalPlan, table: NamedRelation): Boolean = {
+  private def paimonWriteResolved(
+      query: LogicalPlan,
+      table: NamedRelation,
+      isByName: Boolean): Boolean = {
     query.output.size == table.output.size &&
     query.output.zip(table.output).forall {
       case (inAttr, outAttr) =>
         val inType = CharVarcharUtils.getRawType(inAttr.metadata).getOrElse(inAttr.dataType)
         val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType)
-        inAttr.name == outAttr.name && schemaCompatible(inType, outType)
+        inAttr.name == outAttr.name && schemaCompatible(inType, outType, isByName)
     }
   }
 
@@ -176,21 +179,42 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
     Project(project, query)
   }
 
-  private def schemaCompatible(dataSchema: DataType, tableSchema: DataType): Boolean = {
+  private def schemaCompatible(
+      dataSchema: DataType,
+      tableSchema: DataType,
+      checkFieldNames: Boolean): Boolean = {
     (dataSchema, tableSchema) match {
       case (s1: StructType, s2: StructType) =>
-        s1.zip(s2).forall { case (d1, d2) => schemaCompatible(d1.dataType, d2.dataType) }
+        s1.length == s2.length &&
+        (!checkFieldNames ||
+          (!hasResolverConflicts(s1) &&
+            !hasResolverConflicts(s2) &&
+            structFieldsResolved(s1, s2))) &&
+        s1.zip(s2).forall {
+          case (d1, d2) => schemaCompatible(d1.dataType, d2.dataType, checkFieldNames)
+        }
      case (a1: ArrayType, a2: ArrayType) =>
        // todo: support array type nullable evaluation
-        schemaCompatible(a1.elementType, a2.elementType)
+        schemaCompatible(a1.elementType, a2.elementType, checkFieldNames)
      case (m1: MapType, m2: MapType) =>
        m1.valueContainsNull == m2.valueContainsNull &&
-        schemaCompatible(m1.keyType, m2.keyType) &&
-        schemaCompatible(m1.valueType, m2.valueType)
+        schemaCompatible(m1.keyType, m2.keyType, checkFieldNames) &&
+        schemaCompatible(m1.valueType, m2.valueType, checkFieldNames)
      case (d1, d2) => d1 == d2
    }
  }
 
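+  // Examples for the two helpers below, assuming the default case-insensitive
+  // resolver (spark.sql.caseSensitive = false):
+  //   structFieldsResolved(struct<age: int, NAME: string>, struct<AGE: int, name: string>)
+  // is true, since the zip is positional and each name pair matches under
+  // conf.resolver, while
+  //   hasResolverConflicts(struct<name: string, Name: string>)
+  // is true because `name` and `Name` resolve to each other, which makes
+  // by-name matching inside that struct ambiguous.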
+  private def structFieldsResolved(source: StructType, target: StructType): Boolean = {
+    source.zip(target).forall {
+      case (sourceField, targetField) =>
+        conf.resolver(sourceField.name, targetField.name)
+    }
+  }
+
+  private def hasResolverConflicts(struct: StructType): Boolean = {
+    struct.fields.combinations(2).exists { case Array(a, b) => conf.resolver(a.name, b.name) }
+  }
+
   private def addCastToColumn(
       attr: Attribute,
       targetAttr: Attribute,
@@ -223,34 +247,76 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
       parent: NamedExpression,
       source: StructType,
       target: StructType): NamedExpression = {
+    // Reject target fields that collide under the current resolver
+    // (e.g. `name` and `Name` with case-insensitive resolution), otherwise
+    // we would silently map both target fields to the same source field.
+    val targetConflicts = target.fields
+      .combinations(2)
+      .collect {
+        case Array(a, b) if conf.resolver(a.name, b.name) => (a.name, b.name)
+      }
+      .toSeq
+    if (targetConflicts.nonEmpty) {
+      throw new RuntimeException(
+        "Cannot write incompatible data: nested struct has conflicting target field names: " +
+          targetConflicts.map { case (a, b) => s"`$a` vs `$b`" }.mkString(", ") + ".")
+    }
+
+    // Single pass: resolve each target field to its source match(es) and track
+    // which source indices were consumed, so we can detect extras without
+    // rescanning source and target repeatedly.
+    val sourceWithIndex = source.fields.zipWithIndex
+    val consumed = scala.collection.mutable.BitSet.empty
+    val resolved = target.fields.map {
+      tgt =>
+        val matches = sourceWithIndex.filter { case (f, _) => conf.resolver(f.name, tgt.name) }
+        matches.foreach { case (_, i) => consumed += i }
+        (tgt, matches)
+    }
+
     // If source struct has fields not in target, reject so that merge-schema
     // can handle the evolution instead of silently dropping the extra fields.
-    val targetFieldNames = target.fieldNames.toSet
-    val extraFields = source.fieldNames.filterNot(targetFieldNames.contains)
+    val extraFields = sourceWithIndex.collect {
+      case (f, i) if !consumed(i) => f.name
+    }
     if (extraFields.nonEmpty) {
       throw new RuntimeException(
         s"Cannot write incompatible data: nested struct has extra fields: ${extraFields.mkString(", ")}.")
     }
 
-    val fields = target.map {
-      case targetField @ StructField(name, nested: StructType, _, _) =>
-        val sourceIndex = source.fieldIndex(name)
-        val sourceField = source(sourceIndex)
-        sourceField.dataType match {
-          case s: StructType =>
-            val subField = castStructField(parent, sourceIndex, sourceField.name, targetField)
+    val fields = resolved.map {
+      case (targetField, matches) =>
+        val (sourceIndex, sourceField) = resolveSingleSourceField(matches, targetField.name, source)
+        (targetField.dataType, sourceField.dataType) match {
+          case (nested: StructType, s: StructType) =>
+            val subField = extractStructField(parent, sourceIndex, sourceField.name, targetField)
             addCastToStructByName(subField, s, nested)
-          case o =>
+          case (_: StructType, o) =>
             throw new RuntimeException(s"Can not support to cast $o to StructType.")
+          case _ =>
+            castStructField(parent, sourceIndex, sourceField.name, targetField)
         }
-      case targetField =>
-        val sourceIndex = source.fieldIndex(targetField.name)
-        val sourceField = source(sourceIndex)
-        castStructField(parent, sourceIndex, sourceField.name, targetField)
     }
     structAlias(fields, parent)
   }
 
+  private def resolveSingleSourceField(
+      matches: Array[(StructField, Int)],
+      name: String,
+      source: StructType): (Int, StructField) = {
+    if (matches.length == 1) {
+      val (field, index) = matches(0)
+      (index, field)
+    } else if (matches.isEmpty) {
+      throw new RuntimeException(
+        s"""Field "$name" does not exist in source struct type: ${source.simpleString}.""")
+    } else {
+      throw new RuntimeException(
+        s"""Cannot resolve nested field "$name" due to name conflicts: """ +
+          matches.map(_._1.name).mkString(", ") + ".")
+    }
+  }
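+
+  // A worked example of the resolution above, assuming the default
+  // case-insensitive resolver: writing source struct<ID: int, extra: int>
+  // by name into target struct<id: int> matches ID to id, leaves `extra`
+  // unconsumed, and fails with "nested struct has extra fields: extra".
+  // Writing source struct<id: int, ID: int> into the same target instead
+  // produces two matches for `id`, so resolveSingleSourceField raises the
+  // "name conflicts" error.
+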
   private def addCastToStructByPosition(
       parent: NamedExpression,
       source: StructType,
@@ -300,6 +366,15 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
       targetField.name)(explicitMetadata = Option(targetField.metadata))
   }
 
+  private def extractStructField(
+      parent: NamedExpression,
+      i: Int,
+      sourceFieldName: String,
+      targetField: StructField): NamedExpression = {
+    Alias(GetStructField(parent, i, Option(sourceFieldName)), targetField.name)(
+      explicitMetadata = Option(targetField.metadata))
+  }
+
   private def castToArrayStruct(
       parent: NamedExpression,
       source: StructType,
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTestBase.scala
index 92e2c3ee1983..baecdaf997ea 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTestBase.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTestBase.scala
@@ -126,6 +126,116 @@ abstract class InsertOverwriteTableTestBase extends PaimonSparkTestBase {
     }
   }
 
+  test("Paimon: insert by name with case-insensitive nested struct field matching") {
+    assume(gteqSpark3_5)
+    withTable("t1", "t2") {
+      // Source struct field order and name casing differ from the target, so
+      // paimonWriteResolved returns false and the write actually exercises
+      // addCastToStructByName.
+      spark.sql("""CREATE TABLE t1 (id INT NOT NULL, info STRUCT<age: INT, name: STRING>)
+                  |TBLPROPERTIES ('write-only' = 'true')""".stripMargin)
+      spark.sql("""CREATE TABLE t2 (id INT NOT NULL, info STRUCT<NAME: STRING, AGE: INT>)
+                  |TBLPROPERTIES ('write-only' = 'true')""".stripMargin)
+
+      sql("INSERT INTO t1 VALUES (1, struct(30, 'Alice')), (2, struct(25, 'Bob'))")
+
+      sql("INSERT INTO t2 BY NAME SELECT * FROM t1")
+      checkAnswer(
+        sql("SELECT * FROM t2 ORDER BY id"),
+        Row(1, Row("Alice", 30)) :: Row(2, Row("Bob", 25)) :: Nil)
+    }
+  }
+
+  test("Paimon: insert by name reorders same-type nested struct fields") {
+    assume(gteqSpark3_5)
+    withTable("t1", "t2") {
+      spark.sql("""CREATE TABLE t1 (id INT NOT NULL, info STRUCT<nick: STRING, name: STRING>)
+                  |TBLPROPERTIES ('write-only' = 'true')""".stripMargin)
+      spark.sql("""CREATE TABLE t2 (id INT NOT NULL, info STRUCT<name: STRING, nick: STRING>)
+                  |TBLPROPERTIES ('write-only' = 'true')""".stripMargin)
+
+      sql("INSERT INTO t1 VALUES (1, struct('Ally', 'Alice'))")
+
+      sql("INSERT INTO t2 BY NAME SELECT * FROM t1")
+      checkAnswer(sql("SELECT * FROM t2"), Row(1, Row("Alice", "Ally")) :: Nil)
+    }
+  }
+
+  test("Paimon: insert by name with case-insensitive matching inside array<struct>") {
+    assume(gteqSpark3_5)
+    withTable("t1", "t2") {
+      spark.sql("""CREATE TABLE t1 (id INT NOT NULL, items ARRAY<STRUCT<age: INT, name: STRING>>)
+                  |TBLPROPERTIES ('write-only' = 'true')""".stripMargin)
+      spark.sql("""CREATE TABLE t2 (id INT NOT NULL, items ARRAY<STRUCT<NAME: STRING, AGE: INT>>)
+                  |TBLPROPERTIES ('write-only' = 'true')""".stripMargin)
+
+      sql("INSERT INTO t1 VALUES (1, array(struct(30, 'Alice'), struct(25, 'Bob')))")
+
+      sql("INSERT INTO t2 BY NAME SELECT * FROM t1")
+      checkAnswer(sql("SELECT * FROM t2"), Row(1, Seq(Row("Alice", 30), Row("Bob", 25))) :: Nil)
+    }
+  }
+
+  test("Paimon: insert by name with case-insensitive matching inside nested struct") {
+    assume(gteqSpark3_5)
+    withTable("t1", "t2") {
+      spark.sql("""CREATE TABLE t1 (
+                  |  id INT NOT NULL,
+                  |  info STRUCT<inner: STRUCT<age: INT, name: STRING>>)
+                  |TBLPROPERTIES ('write-only' = 'true')""".stripMargin)
+      spark.sql("""CREATE TABLE t2 (
+                  |  id INT NOT NULL,
+                  |  info STRUCT<INNER: STRUCT<NAME: STRING, AGE: INT>>)
+                  |TBLPROPERTIES ('write-only' = 'true')""".stripMargin)
+
+      sql("INSERT INTO t1 VALUES (1, struct(struct(30, 'Alice')))")
+
+      sql("INSERT INTO t2 BY NAME SELECT * FROM t1")
+      checkAnswer(sql("SELECT * FROM t2"), Row(1, Row(Row("Alice", 30))) :: Nil)
+    }
+  }
+
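+  // The two tests below need structs whose field names collide only under
+  // case-insensitive resolution, so they create the tables while
+  // spark.sql.caseSensitive is true and then flip it to false for the
+  // BY NAME write.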
+  test("Paimon: insert by name rejects ambiguous source nested struct fields") {
+    assume(gteqSpark3_5)
+    withSparkSQLConf("spark.sql.caseSensitive" -> "true") {
+      withTable("t1", "t2") {
+        // source has both `name` and `Name`, which is legal while the session
+        // is case-sensitive
+        spark.sql("""CREATE TABLE t1 (id INT NOT NULL, info STRUCT<name: STRING, Name: STRING>)
+                    |TBLPROPERTIES ('write-only' = 'true')""".stripMargin)
+        spark.sql("""CREATE TABLE t2 (id INT NOT NULL, info STRUCT<name: STRING>)
+                    |TBLPROPERTIES ('write-only' = 'true')""".stripMargin)
+        sql("INSERT INTO t1 VALUES (1, struct('Alice', 'Bob'))")
+
+        withSparkSQLConf("spark.sql.caseSensitive" -> "false") {
+          val msg = intercept[Exception] {
+            sql("INSERT INTO t2 BY NAME SELECT * FROM t1")
+          }.getMessage
+          assert(msg.contains("name conflicts"))
+        }
+      }
+    }
+  }
+
+  test("Paimon: insert by name rejects nested struct target fields colliding under resolver") {
+    assume(gteqSpark3_5)
+    withSparkSQLConf("spark.sql.caseSensitive" -> "true") {
+      withTable("t1", "t2") {
+        spark.sql("""CREATE TABLE t1 (id INT NOT NULL, info STRUCT<name: STRING>)
+                    |TBLPROPERTIES ('write-only' = 'true')""".stripMargin)
+        // target has both `name` and `Name`, which is legal while the session
+        // is case-sensitive
+        spark.sql("""CREATE TABLE t2 (id INT NOT NULL, info STRUCT<name: STRING, Name: STRING>)
+                    |TBLPROPERTIES ('write-only' = 'true')""".stripMargin)
+        sql("INSERT INTO t1 VALUES (1, struct('Alice'))")
+
+        withSparkSQLConf("spark.sql.caseSensitive" -> "false") {
+          val msg = intercept[Exception] {
+            sql("INSERT INTO t2 BY NAME SELECT * FROM t1")
+          }.getMessage
+          assert(msg.contains("conflicting target field names"))
+        }
+      }
+    }
+  }
+
   withPk.foreach {
     hasPk =>
       bucketModes.foreach {