diff --git a/datafusion/functions/src/core/cast_to_type.rs b/datafusion/functions/src/core/cast_to_type.rs new file mode 100644 index 000000000000..22b91e470da5 --- /dev/null +++ b/datafusion/functions/src/core/cast_to_type.rs @@ -0,0 +1,147 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`CastToTypeFunc`]: Implementation of the `cast_to_type` + +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::{ + Result, datatype::DataTypeExt, internal_err, utils::take_function_args, +}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, +}; +use datafusion_macros::user_doc; + +/// Casts the first argument to the data type of the second argument. +/// +/// Only the type of the second argument is used; its value is ignored. +/// This is useful in macros or generic SQL where you need to preserve +/// or match types dynamically. +/// +/// For example: +/// ```sql +/// select cast_to_type('42', NULL::INTEGER); +/// ``` +#[user_doc( + doc_section(label = "Other Functions"), + description = "Casts the first argument to the data type of the second argument. Only the type of the second argument is used; its value is ignored.", + syntax_example = "cast_to_type(expression, reference)", + sql_example = r#"```sql +> select cast_to_type('42', NULL::INTEGER) as a; ++----+ +| a | ++----+ +| 42 | ++----+ + +> select cast_to_type(1 + 2, NULL::DOUBLE) as b; ++-----+ +| b | ++-----+ +| 3.0 | ++-----+ +```"#, + argument( + name = "expression", + description = "Expression to cast. The expression can be a constant, column, or function, and any combination of operators." + ), + argument( + name = "reference", + description = "Reference expression whose data type determines the target cast type. The value is ignored." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct CastToTypeFunc { + signature: Signature, +} + +impl Default for CastToTypeFunc { + fn default() -> Self { + Self::new() + } +} + +impl CastToTypeFunc { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_exact(TypeSignatureClass::Any), + Coercion::new_exact(TypeSignatureClass::Any), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for CastToTypeFunc { + fn name(&self) -> &str { + "cast_to_type" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be called instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + let [_, reference_field] = take_function_args(self.name(), args.arg_fields)?; + let target_type = reference_field.data_type().clone(); + Ok(Field::new(self.name(), target_type, nullable).into()) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("cast_to_type should have been simplified to cast") + } + + fn simplify( + &self, + mut args: Vec, + info: &SimplifyContext, + ) -> Result { + let [_, type_arg] = take_function_args(self.name(), &args)?; + let target_type = info.get_data_type(type_arg)?; + + // remove second (reference) argument + args.pop().unwrap(); + let arg = args.pop().unwrap(); + + let source_type = info.get_data_type(&arg)?; + let new_expr = if source_type == target_type { + // the argument's data type is already the correct type + arg + } else { + // Use an actual cast to get the correct type + Expr::Cast(datafusion_expr::Cast { + expr: Box::new(arg), + field: target_type.into_nullable_field_ref(), + }) + }; + Ok(ExprSimplifyResult::Simplified(new_expr)) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index e8737612a1dc..d3c48573667c 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -24,6 +24,7 @@ pub mod arrow_cast; pub mod arrow_metadata; pub mod arrow_try_cast; pub mod arrowtypeof; +pub mod cast_to_type; pub mod coalesce; pub mod expr_ext; pub mod getfield; @@ -37,6 +38,7 @@ pub mod nvl2; pub mod overlay; pub mod planner; pub mod r#struct; +pub mod try_cast_to_type; pub mod union_extract; pub mod union_tag; pub mod version; @@ -44,6 +46,8 @@ pub mod version; // create UDFs make_udf_function!(arrow_cast::ArrowCastFunc, arrow_cast); make_udf_function!(arrow_try_cast::ArrowTryCastFunc, arrow_try_cast); +make_udf_function!(cast_to_type::CastToTypeFunc, cast_to_type); +make_udf_function!(try_cast_to_type::TryCastToTypeFunc, try_cast_to_type); make_udf_function!(nullif::NullIfFunc, nullif); make_udf_function!(nvl::NVLFunc, nvl); make_udf_function!(nvl2::NVL2Func, nvl2); @@ -75,6 +79,14 @@ pub mod expr_fn { arrow_try_cast, "Casts a value to a specific Arrow data type, returning NULL if the cast fails", arg1 arg2 + ),( + cast_to_type, + "Casts the first argument to the data type of the second argument", + arg1 arg2 + ),( + try_cast_to_type, + "Casts the first argument to the data type of the second argument, returning NULL on failure", + arg1 arg2 ),( nvl, "Returns value2 if value1 is NULL; otherwise it returns value1", @@ -147,6 +159,8 @@ pub fn functions() -> Vec> { nullif(), arrow_cast(), arrow_try_cast(), + cast_to_type(), + try_cast_to_type(), arrow_metadata(), nvl(), nvl2(), diff --git a/datafusion/functions/src/core/try_cast_to_type.rs b/datafusion/functions/src/core/try_cast_to_type.rs new file mode 100644 index 000000000000..4eed4ab7ddd2 --- /dev/null +++ b/datafusion/functions/src/core/try_cast_to_type.rs @@ -0,0 +1,135 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`TryCastToTypeFunc`]: Implementation of the `try_cast_to_type` + +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::{ + Result, datatype::DataTypeExt, internal_err, utils::take_function_args, +}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, +}; +use datafusion_macros::user_doc; + +/// Like [`cast_to_type`](super::cast_to_type::CastToTypeFunc) but returns NULL +/// on cast failure instead of erroring. +/// +/// This is implemented by simplifying `try_cast_to_type(expr, ref)` into +/// `Expr::TryCast` during optimization. +#[user_doc( + doc_section(label = "Other Functions"), + description = "Casts the first argument to the data type of the second argument, returning NULL if the cast fails. Only the type of the second argument is used; its value is ignored.", + syntax_example = "try_cast_to_type(expression, reference)", + sql_example = r#"```sql +> select try_cast_to_type('123', NULL::INTEGER) as a, + try_cast_to_type('not_a_number', NULL::INTEGER) as b; + ++-----+------+ +| a | b | ++-----+------+ +| 123 | NULL | ++-----+------+ +```"#, + argument( + name = "expression", + description = "Expression to cast. The expression can be a constant, column, or function, and any combination of operators." + ), + argument( + name = "reference", + description = "Reference expression whose data type determines the target cast type. The value is ignored." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct TryCastToTypeFunc { + signature: Signature, +} + +impl Default for TryCastToTypeFunc { + fn default() -> Self { + Self::new() + } +} + +impl TryCastToTypeFunc { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_exact(TypeSignatureClass::Any), + Coercion::new_exact(TypeSignatureClass::Any), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for TryCastToTypeFunc { + fn name(&self) -> &str { + "try_cast_to_type" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be called instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + // TryCast can always return NULL (on cast failure), so always nullable + let [_, reference_field] = take_function_args(self.name(), args.arg_fields)?; + let target_type = reference_field.data_type().clone(); + Ok(Field::new(self.name(), target_type, true).into()) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("try_cast_to_type should have been simplified to try_cast") + } + + fn simplify( + &self, + mut args: Vec, + info: &SimplifyContext, + ) -> Result { + let [_, type_arg] = take_function_args(self.name(), &args)?; + let target_type = info.get_data_type(type_arg)?; + + // remove second (reference) argument + args.pop().unwrap(); + let arg = args.pop().unwrap(); + + let source_type = info.get_data_type(&arg)?; + let new_expr = if source_type == target_type { + arg + } else { + Expr::TryCast(datafusion_expr::TryCast { + expr: Box::new(arg), + field: target_type.into_nullable_field_ref(), + }) + }; + Ok(ExprSimplifyResult::Simplified(new_expr)) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} diff --git a/datafusion/sqllogictest/test_files/cast_to_type.slt b/datafusion/sqllogictest/test_files/cast_to_type.slt new file mode 100644 index 000000000000..e48da7cb2d87 --- /dev/null +++ b/datafusion/sqllogictest/test_files/cast_to_type.slt @@ -0,0 +1,257 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +####### +## Tests for cast_to_type function +####### + +# Basic string to integer cast +query I +SELECT cast_to_type('42', NULL::INTEGER); +---- +42 + +# String to double cast +query R +SELECT cast_to_type('3.14', NULL::DOUBLE); +---- +3.14 + +# Integer to string cast +query T +SELECT cast_to_type(42, NULL::VARCHAR); +---- +42 + +# Integer to double cast +query R +SELECT cast_to_type(42, NULL::DOUBLE); +---- +42 + +# Same-type is a no-op +query I +SELECT cast_to_type(42, 0::INTEGER); +---- +42 + +# NULL first argument +query I +SELECT cast_to_type(NULL, 0::INTEGER); +---- +NULL + +# NULL reference (type still applies) +query I +SELECT cast_to_type('42', NULL::INTEGER); +---- +42 + +# CASE expression as first argument +query I +SELECT cast_to_type(CASE WHEN true THEN '1' ELSE '2' END, NULL::INTEGER); +---- +1 + +# Arithmetic expression as first argument +query R +SELECT cast_to_type(1 + 2, NULL::DOUBLE); +---- +3 + +# Nested cast_to_type +query T +SELECT cast_to_type(cast_to_type('3.14', NULL::DOUBLE), NULL::VARCHAR); +---- +3.14 + +# Subquery as second argument +query I +SELECT cast_to_type('42', (SELECT NULL::INTEGER)); +---- +42 + +# Column reference as second argument +statement ok +CREATE TABLE t1 (int_col INTEGER, text_col VARCHAR, double_col DOUBLE); + +statement ok +INSERT INTO t1 VALUES (1, 'hello', 3.14), (2, 'world', 2.72); + +query I +SELECT cast_to_type('99', int_col) FROM t1 LIMIT 1; +---- +99 + +query T +SELECT cast_to_type(123, text_col) FROM t1 LIMIT 1; +---- +123 + +query R +SELECT cast_to_type('1.5', double_col) FROM t1 LIMIT 1; +---- +1.5 + +# Use with column values as first argument +query R +SELECT cast_to_type(int_col, NULL::DOUBLE) FROM t1; +---- +1 +2 + +# Cast column to match another column's type +query T +SELECT cast_to_type(int_col, text_col) FROM t1; +---- +1 +2 + +# Boolean cast +query B +SELECT cast_to_type(1, NULL::BOOLEAN); +---- +true + +# String to date cast +query D +SELECT cast_to_type('2024-01-15', NULL::DATE); +---- +2024-01-15 + +# Error on invalid cast +statement error +SELECT cast_to_type('not_a_number', NULL::INTEGER); + +statement ok +DROP TABLE t1; + +####### +## Tests for try_cast_to_type function (fallible variant returning NULL) +####### + +# Basic string to integer cast +query I +SELECT try_cast_to_type('42', NULL::INTEGER); +---- +42 + +# Invalid cast returns NULL instead of error +query I +SELECT try_cast_to_type('not_a_number', NULL::INTEGER); +---- +NULL + +# String to double cast +query R +SELECT try_cast_to_type('3.14', NULL::DOUBLE); +---- +3.14 + +# Invalid double returns NULL +query R +SELECT try_cast_to_type('abc', NULL::DOUBLE); +---- +NULL + +# Integer to string cast (always succeeds) +query T +SELECT try_cast_to_type(42, NULL::VARCHAR); +---- +42 + +# Same-type is a no-op +query I +SELECT try_cast_to_type(42, 0::INTEGER); +---- +42 + +# NULL first argument +query I +SELECT try_cast_to_type(NULL, 0::INTEGER); +---- +NULL + +# CASE expression as first argument +query I +SELECT try_cast_to_type(CASE WHEN true THEN '1' ELSE '2' END, NULL::INTEGER); +---- +1 + +# Arithmetic expression as first argument +query R +SELECT try_cast_to_type(1 + 2, NULL::DOUBLE); +---- +3 + +# Nested: try_cast_to_type inside cast_to_type +query T +SELECT cast_to_type(try_cast_to_type('3.14', NULL::DOUBLE), NULL::VARCHAR); +---- +3.14 + +# Subquery as second argument +query I +SELECT try_cast_to_type('42', (SELECT NULL::INTEGER)); +---- +42 + +# Column reference as second argument +statement ok +CREATE TABLE t2 (int_col INTEGER, text_col VARCHAR); + +statement ok +INSERT INTO t2 VALUES (1, 'hello'), (2, 'world'); + +query I +SELECT try_cast_to_type('99', int_col) FROM t2 LIMIT 1; +---- +99 + +query I +SELECT try_cast_to_type(text_col, int_col) FROM t2; +---- +NULL +NULL + +# Cast column to match another column's type +query T +SELECT try_cast_to_type(int_col, text_col) FROM t2; +---- +1 +2 + +# Boolean cast +query B +SELECT try_cast_to_type(1, NULL::BOOLEAN); +---- +true + +# String to date - valid +query D +SELECT try_cast_to_type('2024-01-15', NULL::DATE); +---- +2024-01-15 + +# String to date - invalid returns NULL +query D +SELECT try_cast_to_type('not_a_date', NULL::DATE); +---- +NULL + +statement ok +DROP TABLE t2; diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 022b0f9daec8..8d8fe0e2c816 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -5191,7 +5191,9 @@ union_tag(union_expression) - [arrow_metadata](#arrow_metadata) - [arrow_try_cast](#arrow_try_cast) - [arrow_typeof](#arrow_typeof) +- [cast_to_type](#cast_to_type) - [get_field](#get_field) +- [try_cast_to_type](#try_cast_to_type) - [version](#version) ### `arrow_cast` @@ -5311,6 +5313,37 @@ arrow_typeof(expression) +---------------------------+------------------------+ ``` +### `cast_to_type` + +Casts the first argument to the data type of the second argument. Only the type of the second argument is used; its value is ignored. + +```sql +cast_to_type(expression, reference) +``` + +#### Arguments + +- **expression**: Expression to cast. The expression can be a constant, column, or function, and any combination of operators. +- **reference**: Reference expression whose data type determines the target cast type. The value is ignored. + +#### Example + +```sql +> select cast_to_type('42', NULL::INTEGER) as a; ++----+ +| a | ++----+ +| 42 | ++----+ + +> select cast_to_type(1 + 2, NULL::DOUBLE) as b; ++-----+ +| b | ++-----+ +| 3.0 | ++-----+ +``` + ### `get_field` Returns a field within a map or a struct with the given key. @@ -5363,6 +5396,32 @@ get_field(expression, field_name[, field_name2, ...]) +--------+ ``` +### `try_cast_to_type` + +Casts the first argument to the data type of the second argument, returning NULL if the cast fails. Only the type of the second argument is used; its value is ignored. + +```sql +try_cast_to_type(expression, reference) +``` + +#### Arguments + +- **expression**: Expression to cast. The expression can be a constant, column, or function, and any combination of operators. +- **reference**: Reference expression whose data type determines the target cast type. The value is ignored. + +#### Example + +```sql +> select try_cast_to_type('123', NULL::INTEGER) as a, + try_cast_to_type('not_a_number', NULL::INTEGER) as b; + ++-----+------+ +| a | b | ++-----+------+ +| 123 | NULL | ++-----+------+ +``` + ### `version` Returns the version of DataFusion.