Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
import DataSourceV2Implicits._
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsDown {

case a @ PaimonV2WriteCommand(table) if !paimonWriteResolved(a.query, table) =>
case a @ PaimonV2WriteCommand(table) if !paimonWriteResolved(a.query, table, a.isByName) =>
val mergeSchemaEnabled =
writeOptions(a).get(SparkConnectorOptions.MERGE_SCHEMA.key()).contains("true") ||
OptionUtils.writeMergeSchemaEnabled()
Expand Down Expand Up @@ -77,13 +77,16 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
}
}

/**
 * Checks whether the write query is already resolved against the target table:
 * same number of output columns, matching top-level column names, and
 * compatible types.
 *
 * @param query    the logical plan producing the rows to write
 * @param table    the target table relation
 * @param isByName true for `INSERT ... BY NAME` writes; nested struct field
 *                 names are then also validated (see `schemaCompatible`)
 * @return true when no further rewrite of the query is needed
 */
private def paimonWriteResolved(
    query: LogicalPlan,
    table: NamedRelation,
    isByName: Boolean): Boolean = {
  query.output.size == table.output.size &&
  query.output.zip(table.output).forall {
    case (inAttr, outAttr) =>
      // Strip char/varchar padding metadata so the raw types are compared.
      val inType = CharVarcharUtils.getRawType(inAttr.metadata).getOrElse(inAttr.dataType)
      val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType)
      inAttr.name == outAttr.name && schemaCompatible(inType, outType, isByName)
  }
}

Expand Down Expand Up @@ -176,21 +179,42 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
Project(project, query)
}

/**
 * Recursively decides whether a type produced by the write query can be
 * written to the corresponding table type.
 *
 * @param dataSchema      type coming from the write query
 * @param tableSchema     matching type declared by the table
 * @param checkFieldNames when true (BY NAME writes), nested struct field names
 *                        must resolve one-to-one under the session resolver and
 *                        neither side may contain names that collide under it
 */
private def schemaCompatible(
    dataSchema: DataType,
    tableSchema: DataType,
    checkFieldNames: Boolean): Boolean = {
  (dataSchema, tableSchema) match {
    case (s1: StructType, s2: StructType) =>
      // Same width, optionally name-resolved without ambiguity, and
      // compatible field-by-field.
      s1.length == s2.length &&
      (!checkFieldNames ||
        (!hasResolverConflicts(s1) &&
          !hasResolverConflicts(s2) &&
          structFieldsResolved(s1, s2))) &&
      s1.zip(s2).forall {
        case (d1, d2) => schemaCompatible(d1.dataType, d2.dataType, checkFieldNames)
      }
    case (a1: ArrayType, a2: ArrayType) =>
      // todo: support array type nullable evaluation
      schemaCompatible(a1.elementType, a2.elementType, checkFieldNames)
    case (m1: MapType, m2: MapType) =>
      m1.valueContainsNull == m2.valueContainsNull &&
      schemaCompatible(m1.keyType, m2.keyType, checkFieldNames) &&
      schemaCompatible(m1.valueType, m2.valueType, checkFieldNames)
    case (d1, d2) => d1 == d2
  }
}

/**
 * True when every positionally-paired field name of the two structs matches
 * under the session's name resolver (case-sensitive or not per config).
 */
private def structFieldsResolved(source: StructType, target: StructType): Boolean = {
  val pairedFields = source.fields.zip(target.fields)
  pairedFields.forall { case (srcField, tgtField) => conf.resolver(srcField.name, tgtField.name) }
}

/**
 * True when any two distinct fields of the struct have names that the session
 * resolver treats as equal (e.g. `name` vs `Name` under case-insensitive
 * resolution), which would make by-name matching ambiguous.
 */
private def hasResolverConflicts(struct: StructType): Boolean = {
  val fields = struct.fields
  fields.indices.exists { i =>
    ((i + 1) until fields.length).exists(j => conf.resolver(fields(i).name, fields(j).name))
  }
}

private def addCastToColumn(
attr: Attribute,
targetAttr: Attribute,
Expand Down Expand Up @@ -223,34 +247,76 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
parent: NamedExpression,
source: StructType,
target: StructType): NamedExpression = {
// Reject target fields that collide under the current resolver
// (e.g. `name` and `Name` with case-insensitive resolution), otherwise
// we would silently map both target fields to the same source field.
val targetConflicts = target.fields
.combinations(2)
.collect {
case Array(a, b) if conf.resolver(a.name, b.name) => (a.name, b.name)
}
.toSeq
if (targetConflicts.nonEmpty) {
throw new RuntimeException(
"Cannot write incompatible data: nested struct has conflicting target field names: " +
targetConflicts.map { case (a, b) => s"`$a` vs `$b`" }.mkString(", ") + ".")
}

// Single pass: resolve each target field to its source match(es) and track
// which source indices were consumed, so we can detect extras without
// rescanning source and target repeatedly.
val sourceWithIndex = source.fields.zipWithIndex
val consumed = mutable.BitSet.empty
val resolved = target.fields.map {
tgt =>
val matches = sourceWithIndex.filter { case (f, _) => conf.resolver(f.name, tgt.name) }
matches.foreach { case (_, i) => consumed += i }
(tgt, matches)
}

// If source struct has fields not in target, reject so that merge-schema
// can handle the evolution instead of silently dropping the extra fields.
val targetFieldNames = target.fieldNames.toSet
val extraFields = source.fieldNames.filterNot(targetFieldNames.contains)
val extraFields = sourceWithIndex.collect {
case (f, i) if !consumed(i) => f.name
}
if (extraFields.nonEmpty) {
throw new RuntimeException(
s"Cannot write incompatible data: nested struct has extra fields: ${extraFields.mkString(", ")}.")
}

val fields = target.map {
case targetField @ StructField(name, nested: StructType, _, _) =>
val sourceIndex = source.fieldIndex(name)
val sourceField = source(sourceIndex)
sourceField.dataType match {
case s: StructType =>
val subField = castStructField(parent, sourceIndex, sourceField.name, targetField)
val fields = resolved.map {
case (targetField, matches) =>
val (sourceIndex, sourceField) = resolveSingleSourceField(matches, targetField.name, source)
(targetField.dataType, sourceField.dataType) match {
case (nested: StructType, s: StructType) =>
val subField = extractStructField(parent, sourceIndex, sourceField.name, targetField)
addCastToStructByName(subField, s, nested)
case o =>
case (_: StructType, o) =>
throw new RuntimeException(s"Can not support to cast $o to StructType.")
case _ =>
castStructField(parent, sourceIndex, sourceField.name, targetField)
}
case targetField =>
val sourceIndex = source.fieldIndex(targetField.name)
val sourceField = source(sourceIndex)
castStructField(parent, sourceIndex, sourceField.name, targetField)
}
structAlias(fields, parent)
}

/**
 * Picks the unique source field resolved for a target field name.
 *
 * @param matches candidate (field, index) pairs from the source struct whose
 *                names resolved to `name`
 * @param name    the target field name being resolved
 * @param source  the source struct, used only for error reporting
 * @return the (index, field) of the single match
 * @throws RuntimeException when there is no match or more than one match
 */
private def resolveSingleSourceField(
    matches: Array[(StructField, Int)],
    name: String,
    source: StructType): (Int, StructField) = {
  matches match {
    case Array((onlyField, onlyIndex)) =>
      (onlyIndex, onlyField)
    case Array() =>
      throw new RuntimeException(
        s"""Field "$name" does not exist in source struct type: ${source.simpleString}.""")
    case _ =>
      throw new RuntimeException(
        s"""Cannot resolve nested field "$name" due to name conflicts: """ +
          matches.map(_._1.name).mkString(", ") + ".")
  }
}

private def addCastToStructByPosition(
parent: NamedExpression,
source: StructType,
Expand Down Expand Up @@ -300,6 +366,15 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
targetField.name)(explicitMetadata = Option(targetField.metadata))
}

/**
 * Extracts field `i` of the parent struct expression and aliases it with the
 * target field's name and metadata, without inserting any cast.
 */
private def extractStructField(
    parent: NamedExpression,
    i: Int,
    sourceFieldName: String,
    targetField: StructField): NamedExpression = {
  val extracted = GetStructField(parent, i, Option(sourceFieldName))
  Alias(extracted, targetField.name)(explicitMetadata = Option(targetField.metadata))
}

private def castToArrayStruct(
parent: NamedExpression,
source: StructType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,116 @@ abstract class InsertOverwriteTableTestBase extends PaimonSparkTestBase {
}
}

// Happy path: BY NAME insert where the source nested struct uses different
// casing (`Age`, `Name`) and different field order than the target
// (`name`, `age`); values must land under the correctly-matched names.
test("Paimon: insert by name with case-insensitive nested struct field matching") {
assume(gteqSpark3_5)
withTable("t1", "t2") {
// Source struct field order / types differ from target so paimonWriteResolved
// falls through schemaCompatible and we actually exercise addCastToStructByName.
spark.sql("""CREATE TABLE t1 (id INT NOT NULL, info STRUCT<Age: INT, Name: STRING>)
|TBLPROPERTIES ('write-only' = 'true')""".stripMargin)
spark.sql("""CREATE TABLE t2 (id INT NOT NULL, info STRUCT<name: STRING, age: INT>)
|TBLPROPERTIES ('write-only' = 'true')""".stripMargin)

sql("INSERT INTO t1 VALUES (1, struct(30, 'Alice')), (2, struct(25, 'Bob'))")

sql("INSERT INTO t2 BY NAME SELECT * FROM t1")
// Age/Name values must be swapped into the target's name/age order.
checkAnswer(
sql("SELECT * FROM t2 ORDER BY id"),
Row(1, Row("Alice", 30)) :: Row(2, Row("Bob", 25)) :: Nil)
}
}

// Both nested fields are STRING, so a purely positional match would succeed
// silently with swapped values; this pins down that matching is by name.
test("Paimon: insert by name reorders same-type nested struct fields") {
assume(gteqSpark3_5)
withTable("t1", "t2") {
spark.sql("""CREATE TABLE t1 (id INT NOT NULL, info STRUCT<Nick: STRING, Name: STRING>)
|TBLPROPERTIES ('write-only' = 'true')""".stripMargin)
spark.sql("""CREATE TABLE t2 (id INT NOT NULL, info STRUCT<name: STRING, nick: STRING>)
|TBLPROPERTIES ('write-only' = 'true')""".stripMargin)

sql("INSERT INTO t1 VALUES (1, struct('Ally', 'Alice'))")

sql("INSERT INTO t2 BY NAME SELECT * FROM t1")
// name must receive 'Alice' (source Name) and nick 'Ally' (source Nick).
checkAnswer(sql("SELECT * FROM t2"), Row(1, Row("Alice", "Ally")) :: Nil)
}
}

// Same case-insensitive by-name matching, but the struct sits inside an
// ARRAY element type, exercising the array branch of schemaCompatible.
test("Paimon: insert by name with case-insensitive matching inside array<struct<...>>") {
assume(gteqSpark3_5)
withTable("t1", "t2") {
spark.sql("""CREATE TABLE t1 (id INT NOT NULL, items ARRAY<STRUCT<Age: INT, Name: STRING>>)
|TBLPROPERTIES ('write-only' = 'true')""".stripMargin)
spark.sql("""CREATE TABLE t2 (id INT NOT NULL, items ARRAY<STRUCT<name: STRING, age: INT>>)
|TBLPROPERTIES ('write-only' = 'true')""".stripMargin)

sql("INSERT INTO t1 VALUES (1, array(struct(30, 'Alice'), struct(25, 'Bob')))")

sql("INSERT INTO t2 BY NAME SELECT * FROM t1")
// Every array element's struct must be reordered by name.
checkAnswer(sql("SELECT * FROM t2"), Row(1, Seq(Row("Alice", 30), Row("Bob", 25))) :: Nil)
}
}

// Case-insensitive by-name matching two levels deep: the mismatched-case
// struct is itself nested inside an outer struct field.
test("Paimon: insert by name with case-insensitive matching inside nested struct") {
assume(gteqSpark3_5)
withTable("t1", "t2") {
spark.sql("""CREATE TABLE t1 (
| id INT NOT NULL,
| info STRUCT<details: STRUCT<Age: INT, Name: STRING>>)
|TBLPROPERTIES ('write-only' = 'true')""".stripMargin)
spark.sql("""CREATE TABLE t2 (
| id INT NOT NULL,
| info STRUCT<details: STRUCT<name: STRING, age: INT>>)
|TBLPROPERTIES ('write-only' = 'true')""".stripMargin)

sql("INSERT INTO t1 VALUES (1, struct(struct(30, 'Alice')))")

sql("INSERT INTO t2 BY NAME SELECT * FROM t1")
checkAnswer(sql("SELECT * FROM t2"), Row(1, Row(Row("Alice", 30))) :: Nil)
}
}

// Negative case: two source fields (`name`, `Name`) collapse to the same name
// under the case-insensitive resolver, so the write must fail rather than
// pick one arbitrarily.
test("Paimon: insert by name rejects ambiguous source nested struct fields") {
assume(gteqSpark3_5)
// Create the tables case-sensitively so the duplicate-modulo-case field
// names are accepted at DDL time.
withSparkSQLConf("spark.sql.caseSensitive" -> "true") {
withTable("t1", "t2") {
// source has both `name` and `Name`, legal when session is case-sensitive
spark.sql("""CREATE TABLE t1 (id INT NOT NULL, info STRUCT<name: STRING, Name: STRING>)
|TBLPROPERTIES ('write-only' = 'true')""".stripMargin)
spark.sql("""CREATE TABLE t2 (id INT NOT NULL, info STRUCT<name: STRING>)
|TBLPROPERTIES ('write-only' = 'true')""".stripMargin)
sql("INSERT INTO t1 VALUES (1, struct('Alice', 'Bob'))")

// Flip to case-insensitive resolution for the write itself.
withSparkSQLConf("spark.sql.caseSensitive" -> "false") {
val msg = intercept[Exception] {
sql("INSERT INTO t2 BY NAME SELECT * FROM t1")
}.getMessage
assert(msg.contains("name conflicts"))
}
}
}
}

// Mirror of the previous test: the *target* struct has fields that collide
// under the case-insensitive resolver, so both would map to the same source
// field; the write must be rejected.
test("Paimon: insert by name rejects nested struct target fields colliding under resolver") {
assume(gteqSpark3_5)
// DDL runs case-sensitively so `name`/`Name` can coexist in the target.
withSparkSQLConf("spark.sql.caseSensitive" -> "true") {
withTable("t1", "t2") {
spark.sql("""CREATE TABLE t1 (id INT NOT NULL, info STRUCT<name: STRING>)
|TBLPROPERTIES ('write-only' = 'true')""".stripMargin)
// target has both `name` and `Name`, legal when session is case-sensitive
spark.sql("""CREATE TABLE t2 (id INT NOT NULL, info STRUCT<name: STRING, Name: STRING>)
|TBLPROPERTIES ('write-only' = 'true')""".stripMargin)
sql("INSERT INTO t1 VALUES (1, struct('Alice'))")

// The write itself resolves names case-insensitively.
withSparkSQLConf("spark.sql.caseSensitive" -> "false") {
val msg = intercept[Exception] {
sql("INSERT INTO t2 BY NAME SELECT * FROM t1")
}.getMessage
assert(msg.contains("conflicting target field names"))
}
}
}
}

withPk.foreach {
hasPk =>
bucketModes.foreach {
Expand Down