Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions datafusion/functions/src/core/cast_to_type.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`CastToTypeFunc`] implements the `cast_to_type` scalar function.

use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::{
Result, datatype::DataTypeExt, internal_err, utils::take_function_args,
};
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
use datafusion_expr::{
Coercion, ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs,
ScalarUDFImpl, Signature, TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;

/// Scalar UDF `cast_to_type(expression, reference)`: casts the first argument
/// to the data type of the second argument.
///
/// Only the type of the second argument is used; its value is ignored.
/// This is useful in macros or generic SQL where you need to preserve
/// or match types dynamically.
///
/// The function is not meant to be evaluated directly: the `simplify`
/// implementation in this file rewrites every call into a plain cast
/// expression (or into the bare first argument when the types already
/// match), and `invoke_with_args` returns an internal error.
///
/// For example:
/// ```sql
/// select cast_to_type('42', NULL::INTEGER);
/// ```
#[user_doc(
doc_section(label = "Other Functions"),
description = "Casts the first argument to the data type of the second argument. Only the type of the second argument is used; its value is ignored.",
syntax_example = "cast_to_type(expression, reference)",
sql_example = r#"```sql
> select cast_to_type('42', NULL::INTEGER) as a;
+----+
| a |
+----+
| 42 |
+----+

> select cast_to_type(1 + 2, NULL::DOUBLE) as b;
+-----+
| b |
+-----+
| 3.0 |
+-----+
```"#,
argument(
name = "expression",
description = "Expression to cast. The expression can be a constant, column, or function, and any combination of operators."
),
argument(
name = "reference",
description = "Reference expression whose data type determines the target cast type. The value is ignored."
)
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct CastToTypeFunc {
/// Two-argument signature built in `new`; returned by `signature()`.
signature: Signature,
}

impl Default for CastToTypeFunc {
fn default() -> Self {
Self::new()
}
}

impl CastToTypeFunc {
    /// Creates the `cast_to_type` function.
    ///
    /// The signature has exactly two positions, each declared with
    /// [`TypeSignatureClass::Any`], and the function is marked
    /// [`Volatility::Immutable`].
    pub fn new() -> Self {
        // Both the value and the reference argument use the same coercion rule.
        let any_arg = || Coercion::new_exact(TypeSignatureClass::Any);
        let signature =
            Signature::coercible(vec![any_arg(), any_arg()], Volatility::Immutable);
        Self { signature }
    }
}

impl ScalarUDFImpl for CastToTypeFunc {
fn name(&self) -> &str {
"cast_to_type"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
internal_err!("return_field_from_args should be called instead")
}

fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
let [_, reference_field] = take_function_args(self.name(), args.arg_fields)?;
let target_type = reference_field.data_type().clone();
Ok(Field::new(self.name(), target_type, nullable).into())
}

fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
internal_err!("cast_to_type should have been simplified to cast")
}

fn simplify(
&self,
mut args: Vec<Expr>,
info: &SimplifyContext,
) -> Result<ExprSimplifyResult> {
let [_, type_arg] = take_function_args(self.name(), &args)?;
let target_type = info.get_data_type(type_arg)?;

// remove second (reference) argument
args.pop().unwrap();
let arg = args.pop().unwrap();

let source_type = info.get_data_type(&arg)?;
let new_expr = if source_type == target_type {
// the argument's data type is already the correct type
arg
} else {
// Use an actual cast to get the correct type
Expr::Cast(datafusion_expr::Cast {
expr: Box::new(arg),
field: target_type.into_nullable_field_ref(),
})
};
Ok(ExprSimplifyResult::Simplified(new_expr))
}

fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
}
14 changes: 14 additions & 0 deletions datafusion/functions/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pub mod arrow_cast;
pub mod arrow_metadata;
pub mod arrow_try_cast;
pub mod arrowtypeof;
pub mod cast_to_type;
pub mod coalesce;
pub mod expr_ext;
pub mod getfield;
Expand All @@ -37,13 +38,16 @@ pub mod nvl2;
pub mod overlay;
pub mod planner;
pub mod r#struct;
pub mod try_cast_to_type;
pub mod union_extract;
pub mod union_tag;
pub mod version;

// create UDFs
make_udf_function!(arrow_cast::ArrowCastFunc, arrow_cast);
make_udf_function!(arrow_try_cast::ArrowTryCastFunc, arrow_try_cast);
make_udf_function!(cast_to_type::CastToTypeFunc, cast_to_type);
make_udf_function!(try_cast_to_type::TryCastToTypeFunc, try_cast_to_type);
make_udf_function!(nullif::NullIfFunc, nullif);
make_udf_function!(nvl::NVLFunc, nvl);
make_udf_function!(nvl2::NVL2Func, nvl2);
Expand Down Expand Up @@ -75,6 +79,14 @@ pub mod expr_fn {
arrow_try_cast,
"Casts a value to a specific Arrow data type, returning NULL if the cast fails",
arg1 arg2
),(
cast_to_type,
"Casts the first argument to the data type of the second argument",
arg1 arg2
),(
try_cast_to_type,
"Casts the first argument to the data type of the second argument, returning NULL on failure",
arg1 arg2
),(
nvl,
"Returns value2 if value1 is NULL; otherwise it returns value1",
Expand Down Expand Up @@ -147,6 +159,8 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
nullif(),
arrow_cast(),
arrow_try_cast(),
cast_to_type(),
try_cast_to_type(),
arrow_metadata(),
nvl(),
nvl2(),
Expand Down
135 changes: 135 additions & 0 deletions datafusion/functions/src/core/try_cast_to_type.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`TryCastToTypeFunc`] implements the `try_cast_to_type` scalar function.

use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::{
Result, datatype::DataTypeExt, internal_err, utils::take_function_args,
};
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
use datafusion_expr::{
Coercion, ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs,
ScalarUDFImpl, Signature, TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;

/// Scalar UDF `try_cast_to_type(expression, reference)`: like
/// [`cast_to_type`](super::cast_to_type::CastToTypeFunc) but returns NULL
/// on cast failure instead of erroring.
///
/// Only the data type of the second argument is used; its value is ignored.
///
/// This is implemented by simplifying `try_cast_to_type(expr, ref)` into
/// `Expr::TryCast` during optimization (or into the bare `expr` when it
/// already has the target type). `invoke_with_args` returns an internal
/// error, so the function is not meant to be evaluated directly.
///
/// For example:
/// ```sql
/// select try_cast_to_type('not_a_number', NULL::INTEGER);
/// ```
#[user_doc(
doc_section(label = "Other Functions"),
description = "Casts the first argument to the data type of the second argument, returning NULL if the cast fails. Only the type of the second argument is used; its value is ignored.",
syntax_example = "try_cast_to_type(expression, reference)",
sql_example = r#"```sql
> select try_cast_to_type('123', NULL::INTEGER) as a,
try_cast_to_type('not_a_number', NULL::INTEGER) as b;

+-----+------+
| a | b |
+-----+------+
| 123 | NULL |
+-----+------+
```"#,
argument(
name = "expression",
description = "Expression to cast. The expression can be a constant, column, or function, and any combination of operators."
),
argument(
name = "reference",
description = "Reference expression whose data type determines the target cast type. The value is ignored."
)
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct TryCastToTypeFunc {
/// Two-argument signature built in `new`; returned by `signature()`.
signature: Signature,
}

impl Default for TryCastToTypeFunc {
fn default() -> Self {
Self::new()
}
}

impl TryCastToTypeFunc {
    /// Creates the `try_cast_to_type` function.
    ///
    /// The signature has exactly two positions, each declared with
    /// [`TypeSignatureClass::Any`], and the function is marked
    /// [`Volatility::Immutable`].
    pub fn new() -> Self {
        // Both the value and the reference argument use the same coercion rule.
        let any_arg = || Coercion::new_exact(TypeSignatureClass::Any);
        let signature =
            Signature::coercible(vec![any_arg(), any_arg()], Volatility::Immutable);
        Self { signature }
    }
}

impl ScalarUDFImpl for TryCastToTypeFunc {
    /// Name under which the function is registered.
    fn name(&self) -> &str {
        "try_cast_to_type"
    }

    /// Two-argument signature built in [`TryCastToTypeFunc::new`].
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Always returns an internal error: the output field is computed by
    /// [`Self::return_field_from_args`] instead.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        internal_err!("return_field_from_args should be called instead")
    }

    /// Computes the output field: the data type of the second (reference)
    /// argument, named after the function and always nullable.
    ///
    /// # Errors
    /// Returns an error if the number of argument fields is not exactly two.
    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
        let [_, reference_field] = take_function_args(self.name(), args.arg_fields)?;
        // A failed try-cast produces NULL, so the output is nullable no
        // matter how the inputs are declared.
        let output =
            Field::new(self.name(), reference_field.data_type().clone(), true);
        Ok(output.into())
    }

    /// Always returns an internal error: calls are expected to be rewritten
    /// by [`Self::simplify`] before execution.
    fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        internal_err!("try_cast_to_type should have been simplified to try_cast")
    }

    /// Rewrites `try_cast_to_type(expr, reference)` into `expr` when `expr`
    /// already has the reference's data type, or into an `Expr::TryCast` to
    /// that type otherwise. The reference expression itself is dropped.
    ///
    /// # Errors
    /// Returns an error if `args` does not hold exactly two expressions or if
    /// the data type of either expression cannot be resolved through `info`.
    fn simplify(
        &self,
        args: Vec<Expr>,
        info: &SimplifyContext,
    ) -> Result<ExprSimplifyResult> {
        // Arity check and ownership transfer of both expressions in one step.
        let [value, reference] = take_function_args(self.name(), args)?;
        let target_type = info.get_data_type(&reference)?;

        // Nothing to convert: hand back the value expression untouched.
        if info.get_data_type(&value)? == target_type {
            return Ok(ExprSimplifyResult::Simplified(value));
        }

        let try_cast = datafusion_expr::TryCast {
            expr: Box::new(value),
            field: target_type.into_nullable_field_ref(),
        };
        Ok(ExprSimplifyResult::Simplified(Expr::TryCast(try_cast)))
    }

    /// Documentation generated by the `user_doc` attribute on the struct.
    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
}
Loading
Loading