Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
val subExprsCode = ctx.subexprFunctionsCode
val (cls, setup, snippet) =
CometBatchKernelCodegenOutput.emitOutputWriter(boundExpr.dataType, ev.value, ctx)
(cls, setup, defaultBody(boundExpr, ev, snippet, subExprsCode))
(cls, setup, defaultBody(boundExpr, inputSchema, ev, snippet, subExprsCode))
}

val typedFieldDecls = CometBatchKernelCodegenInput.emitInputFieldDecls(inputSchema)
Expand Down Expand Up @@ -343,6 +343,7 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
*/
private def defaultBody(
boundExpr: Expression,
inputSchema: Seq[ArrowColumnSpec],
ev: ExprCode,
writeSnippet: String,
subExprsCode: String): String = {
Expand All @@ -353,9 +354,17 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
// make this incorrect (`coalesce(null, x)` is `x`); `allNullIntolerant` rejects those.
val inputOrdinals =
boundExpr.collect { case b: BoundReference => b.ordinal }.distinct
// Primitive Arrow vectors are wrapped in `CometPlainVector` at input-cast time, which
// exposes `isNullAt(int)` rather than the raw Arrow `isNull(int)`. Pick the right method
// per ordinal so the short-circuit compiles for timestamp / int / float columns too,
// not just VarChar / Decimal vectors that stay as raw Arrow types.
def nullCheckCall(ord: Int): String = {
val method = CometBatchKernelCodegenInput.nullCheckMethod(inputSchema(ord))
s"this.col$ord.$method(i)"
}
val nullCheck =
if (inputOrdinals.isEmpty) "false"
else inputOrdinals.map(ord => s"this.col$ord.isNull(i)").mkString(" || ")
else inputOrdinals.map(nullCheckCall).mkString(" || ")
s"""
|if ($nullCheck) {
| output.setNull(i);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,8 +404,10 @@ private[codegen] object CometBatchKernelCodegenInput {
/**
* Java method name for the per-column null check. Primitive scalars wrapped in
* [[CometPlainVector]] expose `isNullAt`; Arrow typed fields expose `isNull`. Same semantics.
* Used both by `emitTypedGetters` (for the kernel's `isNullAt` switch) and by
* `CometBatchKernelCodegen.defaultBody` (for the `NullIntolerant` short-circuit).
*/
private def nullCheckMethod(spec: ArrowColumnSpec): String = spec match {
def nullCheckMethod(spec: ArrowColumnSpec): String = spec match {
case sc: ScalarColumnSpec if wrapsInCometPlainVector(sc.vectorClass) => "isNullAt"
case _ => "isNull"
}
Expand Down
28 changes: 24 additions & 4 deletions spark/src/main/scala/org/apache/comet/serde/CometScalaUDF.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
package org.apache.comet.serde

import org.apache.spark.SparkEnv
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSeq, BindReferences, Literal, ScalaUDF}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSeq, BindReferences, Expression, Literal, ScalaUDF}
import org.apache.spark.sql.types.BinaryType

import org.apache.comet.CometConf
Expand All @@ -45,15 +45,35 @@ import org.apache.comet.udf.codegen.CometScalaUDFCodegen
*
* Gated by [[CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED]]. When disabled, plans containing a
* `ScalaUDF` fall back to Spark for the enclosing operator.
*
* [[emitJvmCodegenDispatch]] exposes the same closure-serialize + dispatcher-proto path to other
* serdes that want to keep a built-in Spark expression inside the Comet pipeline when no native
* lowering is viable. See [[CometDateFormat]] for an example.
*/
object CometScalaUDF extends CometExpressionSerde[ScalaUDF] {

override def convert(expr: ScalaUDF, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = {
override def convert(expr: ScalaUDF, inputs: Seq[Attribute], binding: Boolean): Option[Expr] =
emitJvmCodegenDispatch(expr, inputs, binding)

/**
* Bind `expr`, closure-serialize it, and emit a `JvmScalarUdf` proto routed through
* [[CometScalaUDFCodegen]] so that native execution evaluates the expression inside the
* Arrow-direct codegen dispatcher. The dispatcher will Janino-compile `expr.doGenCode` into a
* batch kernel on first invocation per task.
*
* Returns `None` (with `withInfo` tagging the reason) when the dispatcher is disabled via
* [[CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED]] or when [[CometBatchKernelCodegen.canHandle]]
* refuses the expression tree. Callers should treat `None` as a clean Spark-fallback signal.
*/
def emitJvmCodegenDispatch(
expr: Expression,
inputs: Seq[Attribute],
binding: Boolean): Option[Expr] = {
if (!CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.get()) {
withInfo(
expr,
s"${CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key}=false; ScalaUDF has no native path " +
"so the plan falls back to Spark")
s"${CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key}=false; expression has no native " +
"path so the plan falls back to Spark")
return None
}

Expand Down
103 changes: 47 additions & 56 deletions spark/src/main/scala/org/apache/comet/serde/datetime.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DateType, DoubleType, FloatType, IntegerType, LongType, StringType, TimestampNTZType, TimestampType}
import org.apache.spark.unsafe.types.UTF8String

import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.expressions.{CometCast, CometEvalMode}
import org.apache.comet.serde.CometGetDateField.CometGetDateField
Expand Down Expand Up @@ -593,17 +594,23 @@ object CometTruncTimestamp extends CometExpressionSerde[TruncTimestamp] {
}

/**
* Converts Spark DateFormatClass expression to DataFusion's to_char function.
* Converts Spark `DateFormatClass` to DataFusion's `to_char` when format and timezone are
* mappable, otherwise routes the expression through the Arrow-direct codegen dispatcher so that
* Spark's own `DateFormatClass.doGenCode` runs inside the Comet pipeline.
*
* Spark uses Java SimpleDateFormat patterns while DataFusion uses strftime patterns. This
* implementation supports a whitelist of common format strings that can be reliably mapped
* between the two systems.
* Routing:
* - format is a literal in `supportedFormats` AND timezone is UTC -> native `to_char`
* - format is a literal in `supportedFormats` AND timezone is non-UTC, with the per-expression
* `allowIncompatible` flag set -> native `to_char` (results may differ from Spark)
* - all other cases -> JVM codegen dispatcher ([[CometScalaUDF.emitJvmCodegenDispatch]]), gated
* by [[CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED]]. When that flag is disabled the operator
* falls back to Spark.
*/
object CometDateFormat extends CometExpressionSerde[DateFormatClass] {

/**
* Mapping from Spark SimpleDateFormat patterns to strftime patterns. Only formats in this map
* are supported.
* are supported by the native path.
*/
val supportedFormats: Map[String, String] = Map(
// Full date formats
Expand Down Expand Up @@ -637,66 +644,50 @@ object CometDateFormat extends CometExpressionSerde[DateFormatClass] {
// ISO formats
"yyyy-MM-dd'T'HH:mm:ss" -> "%Y-%m-%dT%H:%M:%S")

override def getIncompatibleReasons(): Seq[String] = Seq(
"Non-UTC timezones may produce different results than Spark")
// Compatibility is decided inside `convert`: the native path covers a subset, and the codegen
// dispatcher covers everything else when enabled. Plan-time tagging happens via `withInfo` on
// the path that returns None.
override def getSupportLevel(expr: DateFormatClass): SupportLevel = Compatible()

override def getUnsupportedReasons(): Seq[String] = Seq(
"Only the following formats are supported:" +
supportedFormats.keys.toSeq.sorted
.map(k => s"`$k`")
.mkString("\n - ", "\n - ", ""))

override def getSupportLevel(expr: DateFormatClass): SupportLevel = {
// Check timezone - only UTC is fully compatible
val timezone = expr.timeZoneId.getOrElse("UTC")
val isUtc = timezone == "UTC" || timezone == "Etc/UTC"

expr.right match {
case Literal(fmt: UTF8String, _) =>
val format = fmt.toString
if (supportedFormats.contains(format)) {
if (isUtc) {
Compatible()
} else {
Incompatible(Some(s"Non-UTC timezone '$timezone' may produce different results"))
}
} else {
Unsupported(
Some(
s"Format '$format' is not supported. Supported formats: " +
supportedFormats.keys.mkString(", ")))
}
case _ =>
Unsupported(Some("Only literal format strings are supported"))
}
}
override def getCompatibleNotes(): Seq[String] = Seq(
"Format strings in a curated allow-list run natively via DataFusion's `to_char` for UTC " +
"sessions. Other format strings (including non-literal formats), as well as non-UTC " +
"sessions, route through Spark's own `DateFormatClass.doGenCode` via the Arrow-direct " +
"codegen dispatcher when `spark.comet.exec.scalaUDF.codegen.enabled=true`. When the " +
"codegen dispatcher is disabled (default) the operator falls back to Spark in those " +
"cases.")

override def convert(
expr: DateFormatClass,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
// Get the format string - must be a literal for us to map it
val strftimeFormat = expr.right match {
case Literal(fmt: UTF8String, _) =>
supportedFormats.get(fmt.toString)
val timezone = expr.timeZoneId.getOrElse("UTC")
val isUtc = timezone == "UTC" || timezone == "Etc/UTC"

val nativeFormat: Option[String] = expr.right match {
case Literal(fmt: UTF8String, _) => supportedFormats.get(fmt.toString)
case _ => None
}

strftimeFormat match {
case Some(format) =>
val childExpr = exprToProtoInternal(expr.left, inputs, binding)
val formatExpr = exprToProtoInternal(Literal(format), inputs, binding)

val optExpr = scalarFunctionExprToProtoWithReturnType(
"to_char",
StringType,
false,
childExpr,
formatExpr)
optExprWithInfo(optExpr, expr, expr.left, expr.right)
case None =>
withInfo(expr, expr.left, expr.right)
None
val canUseNative = nativeFormat.isDefined && {
isUtc || CometConf.isExprAllowIncompat(getExprConfigName(expr))
}

if (canUseNative) {
val childExpr = exprToProtoInternal(expr.left, inputs, binding)
val formatExpr = exprToProtoInternal(Literal(nativeFormat.get), inputs, binding)
val optExpr = scalarFunctionExprToProtoWithReturnType(
"to_char",
StringType,
false,
childExpr,
formatExpr)
optExprWithInfo(optExpr, expr, expr.left, expr.right)
} else {
// Hand the full `DateFormatClass` (with `timeZoneId` already stamped by `ResolveTimeZone`)
// to the codegen dispatcher. It closure-serializes the bound tree, so non-UTC timezones
// and non-whitelisted / non-literal format strings produce Spark-identical results.
CometScalaUDF.emitJvmCodegenDispatch(expr, inputs, binding)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,27 @@
-- specific language governing permissions and limitations
-- under the License.

-- Pin the session timezone so the test exercises the non-UTC path regardless of the JVM
-- default. Enable the codegen dispatcher so non-UTC and non-whitelisted formats stay inside
-- Comet via Spark's own DateFormatClass.doGenCode instead of falling back to Spark.
-- Config: spark.sql.session.timeZone=America/Los_Angeles
-- Config: spark.comet.exec.scalaUDF.codegen.enabled=true

statement
CREATE TABLE test_date_format(ts timestamp) USING parquet

statement
INSERT INTO test_date_format VALUES (timestamp('2024-06-15 10:30:45')), (timestamp('1970-01-01 00:00:00')), (NULL)

query expect_fallback(Non-UTC timezone)
query
SELECT date_format(ts, 'yyyy-MM-dd') FROM test_date_format

query expect_fallback(Non-UTC timezone)
query
SELECT date_format(ts, 'HH:mm:ss') FROM test_date_format

query expect_fallback(Non-UTC timezone)
query
SELECT date_format(ts, 'yyyy-MM-dd HH:mm:ss') FROM test_date_format

-- literal arguments
query expect_fallback(Non-UTC timezone)
query
SELECT date_format(timestamp('2024-06-15 10:30:45'), 'yyyy-MM-dd')
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
.withColumn("arrUnsupportedArgs", expr("array_insert(arr, idx, 1)"))
checkSparkAnswerAndFallbackReasons(
df.select("arrUnsupportedArgs"),
Set("ScalaUDF has no native path", "unsupported arguments for ArrayInsert"))
Set("expression has no native path", "unsupported arguments for ArrayInsert"))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ package org.apache.comet
import org.scalatest.funsuite.AnyFunSuite

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Add, BoundReference, Coalesce, Concat, CreateArray, CreateMap, ElementAt, Expression, GetStructField, LeafExpression, Length, Literal, Nondeterministic, Rand, Size, Unevaluable, Upper}
import org.apache.spark.sql.catalyst.expressions.{Add, BoundReference, Coalesce, Concat, CreateArray, CreateMap, DateFormatClass, ElementAt, Expression, GetStructField, LeafExpression, Length, Literal, Nondeterministic, Rand, Size, Unevaluable, Upper}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import org.apache.comet.codegen.CometBatchKernelCodegen
import org.apache.comet.codegen.CometBatchKernelCodegen.{ArrayColumnSpec, ArrowColumnSpec, MapColumnSpec, ScalarColumnSpec, StructColumnSpec, StructFieldSpec}
Expand Down Expand Up @@ -61,6 +62,26 @@ class CometCodegenSourceSuite extends AnyFunSuite {
specs: ArrowColumnSpec*): String =
CometBatchKernelCodegen.generateSource(expr, specs.toIndexedSeq).body

test("NullIntolerant short-circuit uses isNullAt for CometPlainVector-wrapped columns") {
// Primitive Arrow vectors (timestamp / int / float / ...) are wrapped in `CometPlainVector`
// at input-cast time. The short-circuit must call `isNullAt(i)`, not `isNull(i)`, otherwise
// Janino fails to compile the kernel with "method isNull not declared". Verified end-to-end
// by `CometTemporalExpressionSuite` date_format tests over `TimeStampMicroTZVector` inputs.
val tsVec = CometBatchKernelCodegen.vectorClassBySimpleName("TimeStampMicroTZVector")
val spec = ArrowColumnSpec(tsVec, nullable = true)
val expr = DateFormatClass(
BoundReference(0, TimestampType, nullable = true),
Literal(UTF8String.fromString("yyyy-MM-dd EEEE"), StringType),
Some("UTC"))
val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(spec)).body
assert(
src.contains("if (this.col0.isNullAt(i))"),
s"expected short-circuit to use isNullAt for CometPlainVector-wrapped col0; got:\n$src")
assert(
!src.contains("if (this.col0.isNull(i))"),
s"expected no raw Arrow isNull on the CometPlainVector-wrapped col0; got:\n$src")
}

test("non-nullable column emits literal-false isNullAt case") {
val expr = Length(BoundReference(0, StringType, nullable = false))
val src = gen(expr, nonNullableString)
Expand Down
Loading
Loading