diff --git a/paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/CreatePaimonSQLFunctionCommand.scala b/paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/CreatePaimonSQLFunctionCommand.scala new file mode 100644 index 000000000000..f06ae2ca6f35 --- /dev/null +++ b/paimon-spark/paimon-spark-4.0/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/CreatePaimonSQLFunctionCommand.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.parser.extensions + +import org.apache.paimon.spark.catalog.SupportV1Function +import org.apache.paimon.spark.catalog.functions.SQLFunctionConverter +import org.apache.paimon.spark.leafnode.PaimonLeafRunnableCommand + +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.FunctionIdentifier + +/** + * Simplified version of CreatePaimonSQLFunctionCommand for Spark 4.0. Persists the function without + * full body analysis (no type inference, no validation). The full version in spark4-common is used + * by Spark 4.1+. + */ +case class CreatePaimonSQLFunctionCommand( + catalog: SupportV1Function, + name: FunctionIdentifier, + inputParamText: Option[String], + returnTypeText: String, + exprText: Option[String], + queryText: Option[String], + comment: Option[String], + isDeterministic: Option[Boolean], + containsSQL: Option[Boolean], + isTableFunc: Boolean, + ignoreIfExists: Boolean, + replace: Boolean) + extends PaimonLeafRunnableCommand { + + override def run(sparkSession: SparkSession): Seq[Row] = { + require( + returnTypeText != null && returnTypeText.trim.nonEmpty, + s"SQL function $name requires an explicit RETURNS clause on Spark 4.0.") + + val parser = sparkSession.sessionState.sqlParser + val paimonFunction = SQLFunctionConverter.toPaimonFunction( + name, + inputParamText, + returnTypeText, + exprText, + queryText, + comment, + isDeterministic, + containsSQL, + parser) + + if (replace) { + catalog.dropV1Function(name, true) + } + catalog.createV1Function(paimonFunction, ignoreIfExists) + Nil + } + + override def simpleString(maxFields: Int): String = { + s"CreatePaimonSQLFunctionCommand: $name" + } +} diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/PaimonFunctionExec.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/PaimonFunctionExec.scala index 3ca46a43de07..0cb9e1cfb546 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/PaimonFunctionExec.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/execution/PaimonFunctionExec.scala @@ -88,25 +88,46 @@ case class DescribePaimonV1FunctionCommand( s"File Resources: ${functionDefinition.fileResources().asScala.map(_.uri()).mkString(", ")}") } case sqlFunctionDefinition: FunctionDefinition.SQLFunctionDefinition => - rows += Row(s"Function: ${function.fullName()}") - rows += Row("Type: SCALAR") + val buffer = new ArrayBuffer[(String, String)] + buffer += ("Function:" -> function.fullName()) + buffer += ("Type:" -> "SCALAR") val inputParams = function.inputParams() if (inputParams.isPresent && !inputParams.get().isEmpty) { - val params = inputParams - .get() - .asScala - .map(field => s"${field.name()} ${field.`type`().asSQLString()}") - .mkString(", ") - rows += Row(s"Input: $params") + val params = formatInputParams(inputParams.get().asScala) + buffer += ("Input:" -> params.head) + params.tail.foreach(s => buffer += ("" -> s)) + } else { + buffer += ("Input:" -> "()") } val returnParams = function.returnParams() if (returnParams.isPresent && !returnParams.get().isEmpty) { - rows += Row(s"Returns: ${returnParams.get().get(0).`type`().asSQLString()}") + buffer += ("Returns:" -> returnParams.get().get(0).`type`().asSQLString()) } if (isExtended) { - Option(function.comment()).foreach(c => rows += Row(s"Comment: $c")) - rows += Row(s"Body: ${sqlFunctionDefinition.definition()}") + Option(function.comment()).foreach(c => buffer += ("Comment:" -> c)) + buffer += ("Deterministic:" -> function.isDeterministic.toString) + val options = function.options() + Option(options.get("spark.sql-function.contains-sql")) + .map(_.toBoolean) + .foreach { + c => + val dataAccess = if (c) "CONTAINS SQL" else "READS SQL DATA" + buffer += ("Data Access:" -> dataAccess) + } + val configs = options.asScala + .filter(_._1.startsWith("sqlConfig.")) + .toSeq + .sortBy(_._1) + .map { case (k, v) => s"${k.stripPrefix("sqlConfig.")}=$v" } + if (configs.nonEmpty) { + buffer += ("Configs:" -> configs.head) + configs.tail.foreach(s => buffer += ("" -> s)) + } + buffer += ("Body:" -> sqlFunctionDefinition.definition()) } + val keys = tabulate(buffer.map(_._1).toSeq) + val values = buffer.map(_._2) + keys.zip(values).foreach { case (key, value) => rows += Row(s"$key $value") } case other => throw new UnsupportedOperationException(s"Unsupported function definition $other") } @@ -114,6 +135,27 @@ case class DescribePaimonV1FunctionCommand( rows.toSeq } + private def tabulate(inputs: Seq[String]): Seq[String] = { + val maxLen = inputs.map(_.length).max + inputs.map(_.padTo(maxLen, ' ')) + } + + private def formatInputParams( + params: Iterable[org.apache.paimon.types.DataField]): Seq[String] = { + val fields = params.toSeq + val names = tabulate(fields.map(_.name())) + val types = tabulate(fields.map(_.`type`().asSQLString())) + val defaults = fields.map { + f => if (isExtended) Option(f.defaultValue()).map(d => s" DEFAULT $d").getOrElse("") else "" + } + val comments = fields.map { + f => if (isExtended) Option(f.description()).map(c => s" '$c'").getOrElse("") else "" + } + names.zip(types).zip(defaults).zip(comments).map { + case (((name, dataType), default), comment) => s"$name $dataType$default$comment" + } + } + override def simpleString(maxFields: Int): String = { s"DescribePaimonV1FunctionCommand: ${function.fullName()}" } diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonSQLFunctionTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonSQLFunctionTestBase.scala index 230ec6f3dc87..8d13da7d60d8 100644 --- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonSQLFunctionTestBase.scala +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonSQLFunctionTestBase.scala @@ -101,16 +101,30 @@ abstract class PaimonSQLFunctionTestBase extends PaimonSparkTestWithRestCatalogB sql("CREATE FUNCTION area(width DOUBLE, height DOUBLE) RETURNS DOUBLE RETURN width * height") val desc = sql("DESCRIBE FUNCTION area").collect().map(_.getString(0)) - assert(desc.exists(_.contains("Type: SCALAR")), desc.mkString("\n")) - assert(desc.exists(_.contains("Input:")), desc.mkString("\n")) + assert(desc.exists(_.contains("SCALAR")), desc.mkString("\n")) + assert(desc.exists(_.contains("Input")), desc.mkString("\n")) assert(desc.exists(_.contains("width")), desc.mkString("\n")) - assert(desc.exists(_.contains("Returns: DOUBLE")), desc.mkString("\n")) + assert(desc.exists(_.contains("DOUBLE")), desc.mkString("\n")) val descExt = sql("DESCRIBE FUNCTION EXTENDED area").collect().map(_.getString(0)) + assert(descExt.exists(_.contains("Deterministic")), descExt.mkString("\n")) assert(descExt.exists(_.contains("width * height")), descExt.mkString("\n")) } } + test("Paimon SQL Function: describe function with comment") { + withUserDefinedFunction("inc" -> false) { + sql("CREATE FUNCTION inc(x INT) RETURNS INT COMMENT 'increment by one' RETURN x + 1") + + val desc = sql("DESCRIBE FUNCTION inc").collect().map(_.getString(0)) + assert(desc.exists(_.contains("SCALAR")), desc.mkString("\n")) + + val descExt = sql("DESCRIBE FUNCTION EXTENDED inc").collect().map(_.getString(0)) + assert(descExt.exists(_.contains("increment by one")), descExt.mkString("\n")) + assert(descExt.exists(_.contains("x + 1")), descExt.mkString("\n")) + } + } + test("Paimon SQL Function: show functions lists the created function") { withUserDefinedFunction("area" -> false) { sql("CREATE FUNCTION area(w DOUBLE, h DOUBLE) RETURNS DOUBLE RETURN w * h") @@ -132,6 +146,79 @@ abstract class PaimonSQLFunctionTestBase extends PaimonSparkTestWithRestCatalogB } } + test("Paimon SQL Function: SQL configs captured at creation time") { + assume(gteqSpark4_1) + withUserDefinedFunction("div_func" -> false) { + // Create with ANSI enabled — division by zero should throw at query time. + sql("SET spark.sql.ansi.enabled=true") + sql("CREATE FUNCTION div_func(x INT) RETURNS DOUBLE RETURN 1 / x") + sql("SET spark.sql.ansi.enabled=false") + + // Even though ANSI is now disabled in the session, the function was created with ANSI=true, + // so division by zero should still throw ArithmeticException. + val e = intercept[Exception] { + sql("SELECT div_func(0)").collect() + } + assert( + e.getMessage.contains("Division by zero") || + e.getMessage.contains("ArithmeticException") || + e.getMessage.contains("DIVIDE_BY_ZERO")) + + sql("RESET spark.sql.ansi.enabled") + } + } + + test("Paimon SQL Function: non-deterministic function body") { + assume(gteqSpark4_1) + withUserDefinedFunction("rnd" -> false) { + sql("CREATE FUNCTION rnd() RETURNS DOUBLE RETURN rand()") + val r1 = sql("SELECT rnd()").collect()(0).getDouble(0) + val r2 = sql("SELECT rnd()").collect()(0).getDouble(0) + assert(r1 >= 0.0 && r1 < 1.0) + assert(r2 >= 0.0 && r2 < 1.0) + } + } + + test("Paimon SQL Function: reject aggregate in scalar function body") { + assume(gteqSpark4_1) + val e = intercept[Exception] { + sql("CREATE FUNCTION bad_agg(x INT) RETURNS INT RETURN SUM(x)") + } + assert(e.getMessage.contains("CANNOT_CONTAIN_COMPLEX_FUNCTIONS")) + } + + test("Paimon SQL Function: reject window function in scalar function body") { + assume(gteqSpark4_1) + val e = intercept[Exception] { + sql("CREATE FUNCTION bad_win(x INT) RETURNS INT RETURN ROW_NUMBER() OVER (ORDER BY x)") + } + assert(e.getMessage.contains("CANNOT_CONTAIN_COMPLEX_FUNCTIONS")) + } + + test("Paimon SQL Function: reject duplicate parameter names") { + assume(gteqSpark4_1) + val e = intercept[Exception] { + sql("CREATE FUNCTION bad_dup(x INT, x INT) RETURNS INT RETURN x + x") + } + assert(e.getMessage.toLowerCase.contains("duplicate")) + } + + test("Paimon SQL Function: reject non-trailing defaults") { + assume(gteqSpark4_1) + val e = intercept[Exception] { + sql("CREATE FUNCTION bad_def(x INT DEFAULT 1, y INT) RETURNS INT RETURN x + y") + } + assert(e.getMessage.toLowerCase.contains("default")) + } + + test("Paimon SQL Function: omitting RETURNS clause") { + assume(gteqSpark4_1) + withUserDefinedFunction("inc" -> false) { + sql("CREATE FUNCTION inc(x INT) RETURN x + 1") + checkAnswer(sql("SELECT inc(10)"), Row(11)) + } + } + test("Paimon SQL Function: table function is not supported yet") { val e = intercept[Exception] { sql(""" diff --git a/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/paimon/spark/catalog/functions/SQLFunctionConverter.scala b/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/paimon/spark/catalog/functions/SQLFunctionConverter.scala index a81baff8f354..4aee17aa2ea9 100644 --- a/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/paimon/spark/catalog/functions/SQLFunctionConverter.scala +++ b/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/paimon/spark/catalog/functions/SQLFunctionConverter.scala @@ -32,13 +32,16 @@ import org.apache.spark.sql.types.{DataType => SparkDataType, StructType} import java.util.{Collections, HashMap => JHashMap, List => JList} +import scala.collection.JavaConverters._ + /** Converts between Spark SQLFunction and Paimon Function with a SQLFunctionDefinition body. */ object SQLFunctionConverter { - // Spark-specific metadata stored in Paimon Function.options(). - private val IS_QUERY = "spark.sql-function.is-query" - private val DETERMINISTIC = "spark.sql-function.deterministic" - private val CONTAINS_SQL = "spark.sql-function.contains-sql" + // Paimon-specific option keys (prefixed to avoid collision with Spark properties). + private val PAIMON_OPTION_PREFIX = "spark.sql-function." + private val IS_QUERY = PAIMON_OPTION_PREFIX + "is-query" + private val DETERMINISTIC = PAIMON_OPTION_PREFIX + "deterministic" + private val CONTAINS_SQL = PAIMON_OPTION_PREFIX + "contains-sql" /** Build a Paimon function from a parsed CREATE FUNCTION ... RETURN statement. */ def toPaimonFunction( @@ -50,10 +53,11 @@ object SQLFunctionConverter { comment: Option[String], isDeterministic: Option[Boolean], containsSQL: Option[Boolean], - parser: ParserInterface): PaimonFunction = { + parser: ParserInterface, + properties: Map[String, String] = Map.empty): PaimonFunction = { require( returnTypeText != null && returnTypeText.trim.nonEmpty, - s"SQL function $funcIdent must declare an explicit RETURNS type.") + s"SQL function $funcIdent must have a return type (explicit or inferred).") val identifier = FunctionIdentifierConverter.toPaimonIdentifier(funcIdent) val inputParams: JList[DataField] = inputParamText.filter(_.trim.nonEmpty) match { @@ -78,12 +82,13 @@ object SQLFunctionConverter { options.put(IS_QUERY, isQuery.toString) isDeterministic.foreach(d => options.put(DETERMINISTIC, d.toString)) containsSQL.foreach(c => options.put(CONTAINS_SQL, c.toString)) + properties.foreach { case (k, v) => options.put(k, v) } new FunctionImpl( identifier, inputParams, returnParams, - isDeterministic.getOrElse(true), + isDeterministic.getOrElse(true), // caller should always pass Some after analysis Collections.singletonMap(FUNCTION_DEFINITION_NAME, FunctionDefinition.sql(body)), comment.orNull, options @@ -139,7 +144,7 @@ object SQLFunctionConverter { deterministic = deterministic, containsSQL = Option(options.get(CONTAINS_SQL)).map(_.toBoolean), isTableFunc = false, - properties = Map.empty + properties = options.asScala.filterNot(_._1.startsWith(PAIMON_OPTION_PREFIX)).toMap ) SQLFunctionExpression( diff --git a/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/CreatePaimonSQLFunctionCommand.scala b/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/CreatePaimonSQLFunctionCommand.scala new file mode 100644 index 000000000000..de4195ac3fcd --- /dev/null +++ b/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/CreatePaimonSQLFunctionCommand.scala @@ -0,0 +1,514 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.parser.extensions + +import org.apache.paimon.spark.catalog.SupportV1Function +import org.apache.paimon.spark.catalog.functions.SQLFunctionConverter +import org.apache.paimon.spark.leafnode.PaimonLeafRunnableCommand + +import org.apache.spark.SparkException +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.catalyst.CapturesConfig +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.analysis.{withPosition, Analyzer, SQLFunctionExpression, SQLFunctionNode, SQLScalarFunction, SQLTableFunction, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedTableValuedFunction} +import org.apache.spark.sql.catalyst.catalog.{SessionCatalog, SQLFunction, UserDefinedFunction, UserDefinedFunctionErrors} +import org.apache.spark.sql.catalyst.catalog.UserDefinedFunction._ +import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, Expression, Generator, LateralSubquery, Literal, ScalarSubquery, SubqueryExpression, WindowExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical.{LateralJoin, LocalRelation, LogicalPlan, OneRowRelation, Project, Range, UnresolvedWith, View} +import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_ATTRIBUTE +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.command.CreateUserDefinedFunctionCommand._ +import org.apache.spark.sql.execution.command.ViewHelper +import org.apache.spark.sql.types.{DataType, StructField, StructType} + +/** + * Adapted from Spark's CreateSQLFunctionCommand. Analyzes the function body, validates, derives + * deterministic/containsSQL, then persists to the Paimon catalog instead of the session catalog. + */ +case class CreatePaimonSQLFunctionCommand( + catalog: SupportV1Function, + name: FunctionIdentifier, + inputParamText: Option[String], + returnTypeText: String, + exprText: Option[String], + queryText: Option[String], + comment: Option[String], + isDeterministic: Option[Boolean], + containsSQL: Option[Boolean], + isTableFunc: Boolean, + ignoreIfExists: Boolean, + replace: Boolean) + extends PaimonLeafRunnableCommand + with CapturesConfig { + + import SQLFunction._ + + override def run(sparkSession: SparkSession): Seq[Row] = { + val parser = sparkSession.sessionState.sqlParser + val analyzer = sparkSession.sessionState.analyzer + val sessionCatalog = sparkSession.sessionState.catalog + val conf = sparkSession.sessionState.conf + + val inputParam = inputParamText.map(UserDefinedFunction.parseRoutineParam(_, parser)) + val returnType = parseReturnTypeText(returnTypeText, isTableFunc, parser) + + val function = SQLFunction( + name, + inputParam, + returnType.getOrElse(if (isTableFunc) Right(null) else Left(null)), + exprText, + queryText, + comment, + isDeterministic, + containsSQL, + isTableFunc, + Map.empty + ) + + val newFunction = { + val (expression, query) = function.getExpressionAndQuery(parser, isTableFunc) + assert(query.nonEmpty || expression.nonEmpty) + + // Build function input. + val inputPlan = if (inputParam.isDefined) { + val param = inputParam.get + checkParameterNotNull(param, inputParamText.get) + checkParameterNameDuplication(param, conf, name) + checkDefaultsTrailing(param, name) + + // Qualify the input parameters with the function name so that attributes referencing + // the function input parameters can be resolved correctly. + val qualifier = Seq(name.funcName) + val input = param.map( + p => + Alias( + { + val defaultExpr = p.getDefault() + if (defaultExpr.isEmpty) { + Literal.create(null, p.dataType) + } else { + val defaultPlan = parseDefault(defaultExpr.get, parser) + if (SubqueryExpression.hasSubquery(defaultPlan)) { + throw new AnalysisException( + errorClass = "USER_DEFINED_FUNCTIONS.NOT_A_VALID_DEFAULT_EXPRESSION", + messageParameters = + Map("functionName" -> name.funcName, "parameterName" -> p.name)) + } else if (defaultPlan.containsPattern(UNRESOLVED_ATTRIBUTE)) { + // TODO(SPARK-50698): use parsed expression instead of expression string. + defaultPlan.collect { + case a: UnresolvedAttribute => + throw QueryCompilationErrors.unresolvedAttributeError( + "UNRESOLVED_COLUMN", + a.sql, + Seq.empty, + a.origin) + } + } + Cast(defaultPlan, p.dataType) + } + }, + p.name + )(qualifier = qualifier)) + Project(input, OneRowRelation()) + } else { + OneRowRelation() + } + + // Build the function body and check if the function body can be analyzed successfully. + val (unresolvedPlan, analyzedPlan, inferredReturnType) = if (!isTableFunc) { + // Build SQL scalar function plan. + val outputExpr = if (query.isDefined) ScalarSubquery(query.get) else expression.get + val plan: LogicalPlan = returnType + .map { + t => + val retType: DataType = t match { + case Left(t) => t + case _ => + throw SparkException.internalError("Unexpected return type for a scalar SQL UDF.") + } + val outputCast = Seq(Alias(Cast(outputExpr, retType), name.funcName)()) + Project(outputCast, inputPlan) + } + .getOrElse { + // If no explicit RETURNS clause is present, infer the result type from the function body. + val outputAlias = Seq(Alias(outputExpr, name.funcName)()) + Project(outputAlias, inputPlan) + } + + // Check cyclic function reference before running the analyzer. + checkCyclicFunctionReference(sessionCatalog, name, plan) + + // Check the function body can be analyzed correctly. + val analyzed = analyzer.execute(plan) + val (resolved, resolvedReturnType) = analyzed match { + case p @ Project(expr :: Nil, _) if expr.resolved => + (p, Left(expr.dataType)) + case other => + (other, function.returnType) + } + + // Check if the SQL function body contains aggregate/window functions. + // This check needs to be performed before checkAnalysis to provide better error messages. + checkAggOrWindowOrGeneratorExpr(resolved) + + // Check if the SQL function body can be analyzed. + checkFunctionBodyAnalysis(analyzer, function, resolved) + + (plan, resolved, resolvedReturnType) + } else { + // Build SQL table function plan. + if (query.isEmpty) { + throw UserDefinedFunctionErrors.bodyIsNotAQueryForSqlTableUdf(name.funcName) + } + // Check cyclic function reference before running the analyzer. + checkCyclicFunctionReference(sessionCatalog, name, query.get) + + // Construct a lateral join to analyze the function body. + val plan = LateralJoin(inputPlan, LateralSubquery(query.get), Inner, None) + val analyzed = analyzer.execute(plan) + val newPlan = analyzed match { + case Project(_, j: LateralJoin) => j + case j: LateralJoin => j + case _ => + throw SparkException.internalError( + "Unexpected plan returned when " + + s"creating a SQL TVF: ${analyzed.getClass.getSimpleName}.") + } + val maybeResolved = newPlan.asInstanceOf[LateralJoin].right.plan + + // Check if the function body can be analyzed. + checkFunctionBodyAnalysis(analyzer, function, maybeResolved) + + // Get the function's return schema. + val returnParam: StructType = returnType + .map { + case Right(t) => t + case Left(_) => + throw SparkException.internalError( + "Unexpected return schema for a SQL table function.") + } + .getOrElse { + query.get match { + case Project(projectList, _) if projectList.exists(_.isInstanceOf[UnresolvedAlias]) => + throw UserDefinedFunctionErrors.missingColumnNamesForSqlTableUdf(name.funcName) + case _ => + StructType(analyzed.asInstanceOf[LateralJoin].right.plan.output.map { + col => StructField(col.name, col.dataType) + }) + } + } + + // Check the return columns cannot have NOT NULL specified. + checkParameterNotNull(returnParam, returnTypeText) + + // Check duplicated return column names. + checkReturnsColumnDuplication(returnParam, conf, name) + + // Check if the actual output size equals to the number of return parameters. + val outputSize = maybeResolved.output.size + if (outputSize != returnParam.size) { + throw new AnalysisException( + errorClass = "USER_DEFINED_FUNCTIONS.RETURN_COLUMN_COUNT_MISMATCH", + messageParameters = Map( + "outputSize" -> s"$outputSize", + "returnParamSize" -> s"${returnParam.size}", + "name" -> s"$name" + ) + ) + } + + (plan, analyzed, Right(returnParam)) + } + + // A permanent function is not allowed to reference temporary objects. + verifyTemporaryObjectsNotExists(sessionCatalog, name, unresolvedPlan, analyzedPlan) + + // Generate function properties. + val properties = generateFunctionProperties(sparkSession, unresolvedPlan, analyzedPlan) + + // Derive determinism of the SQL function. + val deterministic = analyzedPlan.deterministic + + // Derive and check a SQL function with CONTAINS SQL data access should not reads SQL data. + val readsSQLData = deriveSQLDataAccess(analyzedPlan) + + function.copy( + // Assign the return type, inferring from the function body if needed. + returnType = inferredReturnType, + deterministic = Some(function.deterministic.getOrElse(deterministic)), + containsSQL = Some(function.containsSQL.getOrElse(!readsSQLData)), + properties = properties + ) + } + + // ---- Paimon-specific: persist to Paimon catalog ---- + val resolvedReturnTypeText = newFunction.returnType match { + case Left(dt) if dt != null => dt.sql + case _ => + throw new UnsupportedOperationException( + s"Cannot infer return type for SQL function ${name.funcName}. " + + "Please add an explicit RETURNS clause.") + } + + val paimonFunction = SQLFunctionConverter.toPaimonFunction( + name, + inputParamText, + if (returnTypeText != null && returnTypeText.trim.nonEmpty) returnTypeText + else resolvedReturnTypeText, + exprText, + queryText, + comment, + newFunction.deterministic, + newFunction.containsSQL, + parser, + newFunction.properties + ) + + if (replace) { + catalog.dropV1Function(name, true) + } + catalog.createV1Function(paimonFunction, ignoreIfExists) + Nil + } + + /** Check if the function body can be analyzed. */ + private def checkFunctionBodyAnalysis( + analyzer: Analyzer, + function: SQLFunction, + body: LogicalPlan): Unit = { + analyzer.checkAnalysis(SQLFunctionNode(function, body)) + } + + /** Collect all temporary views and functions and return the identifiers separately */ + private def collectTemporaryObjectsInUnresolvedPlan( + catalog: SessionCatalog, + child: LogicalPlan): (Seq[Seq[String]], Seq[String]) = { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + def collectTempViews(child: LogicalPlan): Seq[Seq[String]] = { + child.flatMap { + case UnresolvedRelation(nameParts, _, _) if catalog.isTempView(nameParts) => + Seq(nameParts) + case w: UnresolvedWith if !w.resolved => w.innerChildren.flatMap(collectTempViews) + case plan if !plan.resolved => + plan.expressions.flatMap(_.flatMap { + case e: SubqueryExpression => collectTempViews(e.plan) + case _ => Seq.empty + }) + case _ => Seq.empty + }.distinct + } + + def collectTempFunctions(child: LogicalPlan): Seq[String] = { + child.flatMap { + case w: UnresolvedWith if !w.resolved => w.innerChildren.flatMap(collectTempFunctions) + case plan if !plan.resolved => + plan.expressions.flatMap(_.flatMap { + case e: SubqueryExpression => collectTempFunctions(e.plan) + case e: UnresolvedFunction + if catalog.isTemporaryFunction(e.nameParts.asFunctionIdentifier) => + Seq(e.nameParts.asFunctionIdentifier.funcName) + case _ => Seq.empty + }) + case _ => Seq.empty + }.distinct + } + (collectTempViews(child), collectTempFunctions(child)) + } + + /** + * Permanent functions are not allowed to reference temp objects, including temp functions and + * temp views. + */ + private def verifyTemporaryObjectsNotExists( + catalog: SessionCatalog, + name: FunctionIdentifier, + child: LogicalPlan, + analyzed: LogicalPlan): Unit = { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + val (tempViews, tempFunctions) = collectTemporaryObjectsInUnresolvedPlan(catalog, child) + tempViews.foreach { + nameParts => + throw UserDefinedFunctionErrors.invalidTempViewReference( + routineName = name.asMultipart, + tempViewName = nameParts) + } + tempFunctions.foreach { + funcName => + throw UserDefinedFunctionErrors.invalidTempFuncReference( + routineName = name.asMultipart, + tempFuncName = funcName) + } + val tempVars = ViewHelper.collectTemporaryVariables(analyzed) + tempVars.foreach { + varName => + throw UserDefinedFunctionErrors.invalidTempVarReference( + routineName = name.asMultipart, + varName = varName) + } + } + + /** Check if the given plan contains cyclic function references. */ + private def checkCyclicFunctionReference( + catalog: SessionCatalog, + identifier: FunctionIdentifier, + plan: LogicalPlan): Unit = { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + def checkPlan(plan: LogicalPlan, path: Seq[FunctionIdentifier]): Unit = { + plan.foreach { + case u @ UnresolvedTableValuedFunction(nameParts, arguments, _) => + try { + val funcId = nameParts.asFunctionIdentifier + val info = catalog.lookupFunctionInfo(funcId) + if (isSQLFunction(info.getClassName)) { + val f = withPosition(u) { + catalog.lookupTableFunction(funcId, arguments).asInstanceOf[SQLTableFunction] + } + val newPath = path :+ f.function.name + if (f.function.name == name) { + throw UserDefinedFunctionErrors.cyclicFunctionReference(newPath.mkString(" -> ")) + } + val plan = catalog.makeSQLTableFunctionPlan(f.name, f.function, f.inputs, f.output) + checkPlan(plan, newPath) + } + } catch { + case _: AnalysisException => + } + case p: LogicalPlan => + p.expressions.foreach(checkExpression(_, path)) + } + } + + def checkExpression(expression: Expression, path: Seq[FunctionIdentifier]): Unit = { + expression.foreach { + case s: SubqueryExpression => checkPlan(s.plan, path) + case u @ UnresolvedFunction(nameParts, arguments, _, _, _, _, _) => + try { + val funcId = nameParts.asFunctionIdentifier + val info = catalog.lookupFunctionInfo(funcId) + if (isSQLFunction(info.getClassName)) { + val f = withPosition(u) { + catalog.lookupFunction(funcId, arguments).asInstanceOf[SQLFunctionExpression] + } + val newPath = path :+ f.function.name + if (f.function.name == name) { + throw UserDefinedFunctionErrors.cyclicFunctionReference(newPath.mkString(" -> ")) + } + val plan = catalog.makeSQLFunctionPlan(f.name, f.function, f.inputs) + checkPlan(plan, newPath) + } + } catch { + case _: AnalysisException => + } + case _ => + } + } + + checkPlan(plan, Seq(identifier)) + } + + /** + * Check if the SQL function body contains aggregate/window/generate functions. Note subqueries + * inside the SQL function body can contain aggregate/window/generate functions. + */ + private def checkAggOrWindowOrGeneratorExpr(plan: LogicalPlan): Unit = { + if (plan.resolved) { + plan.transformAllExpressions { + case e + if e.isInstanceOf[WindowExpression] || e.isInstanceOf[Generator] || + e.isInstanceOf[AggregateExpression] => + throw new AnalysisException( + errorClass = "USER_DEFINED_FUNCTIONS.CANNOT_CONTAIN_COMPLEX_FUNCTIONS", + messageParameters = Map("queryText" -> s"${exprText.orElse(queryText).get}") + ) + } + } + } + + /** + * Derive the SQL data access routine of the function and check if the SQL function matches its + * data access routine. If the data access is CONTAINS SQL, the expression should not access + * operators and expressions that read SQL data. + * + * Returns true is SQL data access routine is READS SQL DATA, otherwise returns false. + */ + private def deriveSQLDataAccess(plan: LogicalPlan): Boolean = { + // Find logical plan nodes that read SQL data. + val readsSQLData = plan.find { + case _: View => true + case p if p.children.isEmpty => + p match { + case _: OneRowRelation | _: LocalRelation | _: Range => false + case _ => true + } + case f: SQLFunctionNode => f.function.containsSQL.contains(false) + case p: LogicalPlan => + lazy val sub = p.subqueries.exists(deriveSQLDataAccess) + // If the SQL function contains another SQL function that has SQL data access routine + // to be READS SQL DATA, then this SQL function will also be READS SQL DATA. + p.expressions.exists( + expr => + expr.find { + case f: SQLScalarFunction => f.function.containsSQL.contains(false) + case sub: SubqueryExpression => deriveSQLDataAccess(sub.plan) + case _ => false + }.isDefined) + }.isDefined + + if (containsSQL.contains(true) && readsSQLData) { + throw new AnalysisException( + errorClass = "INVALID_SQL_FUNCTION_DATA_ACCESS", + messageParameters = Map.empty + ) + } + + readsSQLData + } + + /** + * Generate the function properties, including: + * 1. the SQL configs when creating the function. + * 2. the catalog and database name when creating the function. This will be used to provide + * context during nested function resolution. + * 3. referred temporary object names if the function is a temp function. + */ + private def generateFunctionProperties( + session: SparkSession, + plan: LogicalPlan, + analyzed: LogicalPlan): Map[String, String] = { + val catalog = session.sessionState.catalog + val conf = session.sessionState.conf + val manager = session.sessionState.catalogManager + + val tempVars = ViewHelper.collectTemporaryVariables(analyzed) + + sqlConfigsToProps(conf, SQL_CONFIG_PREFIX) ++ + catalogAndNamespaceToProps( + manager.currentCatalog.name, + manager.currentNamespace.toIndexedSeq) ++ + referredTempNamesToProps(Nil, Nil, tempVars) + } + + override def simpleString(maxFields: Int): String = { + s"CreatePaimonSQLFunctionCommand: $name" + } +} diff --git a/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/RewritePaimonSQLFunctionCommands.scala b/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/RewritePaimonSQLFunctionCommands.scala index cc2c4df6a6a2..099f6d84e971 100644 --- a/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/RewritePaimonSQLFunctionCommands.scala +++ b/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/RewritePaimonSQLFunctionCommands.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.catalyst.parser.extensions import org.apache.paimon.spark.catalog.SupportV1Function -import org.apache.paimon.spark.catalog.functions.SQLFunctionConverter -import org.apache.paimon.spark.execution.CreatePaimonV1FunctionCommand import org.apache.paimon.spark.util.OptionUtils import org.apache.spark.sql.SparkSession @@ -29,10 +27,10 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.CatalogManager /** - * Parser-stage rule that rewrites a Paimon-catalog `CREATE FUNCTION ... RETURN ...` - * (`CreateUserDefinedFunction`) into [[CreatePaimonV1FunctionCommand]], before Spark's - * `ResolveSessionCatalog` throws `MISSING_CATALOG_ABILITY.CREATE_FUNCTION`. Fields are read by name - * (not positional unapply) since `CreateUserDefinedFunction`'s arity differs across Spark 4.0/4.1. + * Parser-stage rule that rewrites a Paimon-catalog CREATE FUNCTION ... RETURN + * (CreateUserDefinedFunction) into CreatePaimonSQLFunctionCommand before Spark's + * ResolveSessionCatalog rejects it. Only does plan rewriting; analysis, validation, and derivation + * happen in CreatePaimonSQLFunctionCommand.run(). */ case class RewritePaimonSQLFunctionCommands(spark: SparkSession) extends Rule[LogicalPlan] { @@ -56,7 +54,8 @@ case class RewritePaimonSQLFunctionCommands(spark: SparkSession) extends Rule[Lo throw new UnsupportedOperationException( s"Paimon does not support creating SQL table functions yet: $funcIdent") } - val paimonFunction = SQLFunctionConverter.toPaimonFunction( + CreatePaimonSQLFunctionCommand( + catalog, funcIdent, c.inputParamText, c.returnTypeText, @@ -65,13 +64,10 @@ case class RewritePaimonSQLFunctionCommands(spark: SparkSession) extends Rule[Lo c.comment, c.isDeterministic, c.containsSQL, - spark.sessionState.sqlParser) - CreatePaimonV1FunctionCommand( - catalog, - funcIdent, - paimonFunction, + isTableFunc = false, c.ignoreIfExists, - c.replace) + c.replace + ) case _ => c } }