From 867a83f1b9997846afc6ef313baeb3c6a179f576 Mon Sep 17 00:00:00 2001 From: Alexander Droste Date: Fri, 29 May 2026 13:35:58 +0000 Subject: [PATCH] feat(cuda): widen decimals for Arrow device export Arrow Decimal128/Decimal256 schemas require fixed 16/32-byte value buffers, while Vortex decimals may use narrower storage. Add a CUDA widening kernel for Arrow Device export and cover compact, wide, nullable, and empty decimal cases. Signed-off-by: Alexander Droste --- vortex-cuda/kernels/src/decimal_cast.cu | 76 +++++++ vortex-cuda/src/arrow/canonical.rs | 283 +++++++++++++++++++++--- vortex-test/e2e-cuda/src/lib.rs | 4 +- 3 files changed, 330 insertions(+), 33 deletions(-) create mode 100644 vortex-cuda/kernels/src/decimal_cast.cu diff --git a/vortex-cuda/kernels/src/decimal_cast.cu b/vortex-cuda/kernels/src/decimal_cast.cu new file mode 100644 index 00000000000..4331f15dd4f --- /dev/null +++ b/vortex-cuda/kernels/src/decimal_cast.cu @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +#include "config.cuh" +#include "types.cuh" +#include +#include + +// Arrow decimal schemas fix the physical values buffer width: +// - Decimal128: 16 bytes per value. +// - Decimal256: 32 bytes per value. +// +// Vortex may use narrower decimal storage, so Arrow Device export widens values +// to match the schema-implied physical layout consumed by cuDF and other Arrow +// readers. +// Converts a decimal storage value to Arrow's 128-bit decimal physical representation. +template +__device__ __forceinline__ int128_t decimal_to_i128(Input value) { + if constexpr (std::is_same_v) { + return value; + } else if constexpr (std::is_same_v) { + return int128_t {value.parts[0], value.parts[1]}; + } else { + const int64_t lo = static_cast(value); + const int64_t hi = value < 0 ? -1 : 0; + return int128_t {lo, hi}; + } +} + +// Converts one decimal value to the requested Arrow decimal physical representation. +template +__device__ __forceinline__ Output decimal_cast_value(Input value) { + if constexpr (std::is_same_v && std::is_same_v) { + return value; + } else if constexpr (std::is_same_v) { + return decimal_to_i128(value); + } else { + static_assert(std::is_same_v); + const int128_t value128 = decimal_to_i128(value); + const int64_t sign = value128.hi < 0 ? -1 : 0; + return int256_t {{value128.lo, value128.hi, sign, sign}}; + } +} + +// Widens a contiguous decimal values buffer on the device. +template +__device__ void +decimal_cast_device(const Input *__restrict input, Output *__restrict output, uint64_t array_len) { + const uint64_t worker = blockIdx.x * blockDim.x + threadIdx.x; + const uint64_t startElem = start_elem(worker, array_len); + const uint64_t stopElem = stop_elem(worker, array_len); + + if (startElem >= array_len) { + return; + } + + for (uint64_t idx = startElem; idx < stopElem; idx++) { + output[idx] = decimal_cast_value(input[idx]); + } +} + +// Generates Arrow Decimal128 and Decimal256 widening kernels for one input storage type. +#define GENERATE_DECIMAL_CAST_KERNELS(input_suffix, InputType) \ + extern "C" __global__ void decimal_cast_##input_suffix##_i128(const InputType *__restrict input, \ + int128_t *__restrict output, \ + uint64_t array_len) { \ + decimal_cast_device(input, output, array_len); \ + } \ + extern "C" __global__ void decimal_cast_##input_suffix##_i256(const InputType *__restrict input, \ + int256_t *__restrict output, \ + uint64_t array_len) { \ + decimal_cast_device(input, output, array_len); \ + } + +FOR_EACH_SIGNED_INT(GENERATE_DECIMAL_CAST_KERNELS) +FOR_EACH_LARGE_DECIMAL(GENERATE_DECIMAL_CAST_KERNELS) diff --git a/vortex-cuda/src/arrow/canonical.rs b/vortex-cuda/src/arrow/canonical.rs index ebda0e04d6c..c6f3feb99e6 100644 --- a/vortex-cuda/src/arrow/canonical.rs +++ b/vortex-cuda/src/arrow/canonical.rs @@ -3,11 +3,15 @@ use std::mem; use std::ptr; +use std::sync::Arc; use async_trait::async_trait; +use cudarc::driver::DeviceRepr; +use cudarc::driver::PushKernelArg; use futures::future::BoxFuture; use vortex::array::ArrayRef; use vortex::array::Canonical; +use vortex::array::arrays::DecimalArray; use vortex::array::arrays::PrimitiveArray; use vortex::array::arrays::StructArray; use vortex::array::arrays::bool::BoolDataParts; @@ -17,16 +21,22 @@ use vortex::array::arrays::primitive::PrimitiveDataParts; use vortex::array::arrays::struct_::StructDataParts; use vortex::array::arrays::varbinview::VarBinViewDataParts; use vortex::array::buffer::BufferHandle; +use vortex::array::match_each_decimal_value_type; use vortex::array::validity::Validity; use vortex::buffer::Buffer; use vortex::buffer::ByteBuffer; +use vortex::dtype::DecimalDType; use vortex::dtype::DecimalType; +use vortex::dtype::NativeDecimalType; +use vortex::dtype::i256; use vortex::error::VortexResult; use vortex::error::vortex_bail; use vortex::error::vortex_ensure; use vortex::extension::datetime::AnyTemporal; use vortex::mask::Mask; +use crate::CudaBufferExt; +use crate::CudaDeviceBuffer; use crate::CudaExecutionCtx; use crate::arrow::ARROW_DEVICE_CUDA; use crate::arrow::ArrowArray; @@ -94,27 +104,7 @@ fn export_canonical( // we don't need a sync event for Null since no data is copied. Ok((array, ptr::null_mut())) } - Canonical::Decimal(decimal) => { - let len = decimal.len(); - let DecimalDataParts { - values, - values_type, - validity, - .. - } = decimal.into_data_parts(); - - // TODO(aduffy): GPU kernel for upcasting. - vortex_ensure!( - values_type >= DecimalType::I32, - "cannot export DecimalArray with values type {values_type}. must be i32 or wider." - ); - - let (validity_buffer, null_count) = - export_arrow_validity_buffer(validity, len, 0, ctx).await?; - let buffer = ctx.ensure_on_device(values).await?; - - export_fixed_size(buffer, len, 0, validity_buffer, null_count, ctx) - } + Canonical::Decimal(decimal) => export_decimal(decimal, ctx).await, Canonical::Extension(extension) => { if !extension.ext_dtype().is::() { vortex_bail!("only support temporal extension types currently"); @@ -206,6 +196,101 @@ fn export_canonical( }) } +/// Exports decimals with value buffers widened to Arrow's Decimal128/Decimal256 layout. +/// +/// Vortex may use narrower storage for decimal values, but Arrow consumers interpret the buffer +/// width from the schema, so the exported device buffer must match Arrow's physical layout. +async fn export_decimal( + decimal: DecimalArray, + ctx: &mut CudaExecutionCtx, +) -> VortexResult<(ArrowArray, SyncEvent)> { + let len = decimal.len(); + let DecimalDataParts { + decimal_dtype, + values, + values_type, + validity, + } = decimal.into_data_parts(); + + let (validity_buffer, null_count) = export_arrow_validity_buffer(validity, len, 0, ctx).await?; + let target_type = arrow_decimal_value_type(decimal_dtype); + let values = export_decimal_values(values, values_type, target_type, len, ctx).await?; + + export_fixed_size(values, len, 0, validity_buffer, null_count, ctx) +} + +/// Returns the Arrow physical value type for a decimal dtype. +/// +/// Arrow represents decimal precision 1-38 as Decimal128 and precision 39-76 as Decimal256. +fn arrow_decimal_value_type(decimal_dtype: DecimalDType) -> DecimalType { + if decimal_dtype.precision() <= 38 { + DecimalType::I128 + } else { + DecimalType::I256 + } +} + +async fn export_decimal_values( + values: BufferHandle, + values_type: DecimalType, + target_type: DecimalType, + len: usize, + ctx: &mut CudaExecutionCtx, +) -> VortexResult { + let values = ctx.ensure_on_device(values).await?; + if values_type == target_type { + return Ok(values); + } + + match_each_decimal_value_type!(values_type, |S| { + match target_type { + DecimalType::I128 => decimal_cast::(values, len, ctx).await, + DecimalType::I256 => decimal_cast::(values, len, ctx).await, + target_type => { + vortex_bail!("cannot export DecimalArray as Arrow decimal value type {target_type}") + } + } + }) +} + +/// Launches the CUDA kernel that widens decimal values from `S` to `D`. +/// +/// This preserves the scaled integer value while changing only the physical buffer width. +async fn decimal_cast( + values: BufferHandle, + len: usize, + ctx: &mut CudaExecutionCtx, +) -> VortexResult +where + S: NativeDecimalType + DeviceRepr, + D: NativeDecimalType + DeviceRepr, +{ + if len == 0 { + return ctx + .ensure_on_device(BufferHandle::new_host( + Buffer::::empty().into_byte_buffer(), + )) + .await; + } + + let output_buffer = ctx.device_alloc::(len)?; + let output_device = CudaDeviceBuffer::new(output_buffer); + + let values_view = values.cuda_view::()?; + let output_view = output_device.as_view::(); + let len_u64 = len as u64; + let cuda_function = ctx.load_function_with_suffixes( + "decimal_cast", + &[&S::DECIMAL_TYPE.to_string(), &D::DECIMAL_TYPE.to_string()], + )?; + + ctx.launch_kernel(&cuda_function, len, |args| { + args.arg(&values_view).arg(&output_view).arg(&len_u64); + })?; + + Ok(BufferHandle::new_device(Arc::new(output_device))) +} + /// Export Vortex validity as an Arrow validity byte buffer. /// /// Returns `None` for the buffer when Arrow can omit validity because all rows are valid. @@ -359,10 +444,14 @@ mod tests { use vortex::array::arrays::StructArray; use vortex::array::arrays::TemporalArray; use vortex::array::arrays::VarBinViewArray; + use vortex::array::buffer::BufferHandle; use vortex::array::validity::Validity; + use vortex::buffer::Buffer; use vortex::dtype::DecimalDType; use vortex::dtype::FieldNames; + use vortex::dtype::NativeDecimalType; use vortex::dtype::half::f16; + use vortex::dtype::i256; use vortex::error::VortexExpect; use vortex::error::VortexResult; use vortex::extension::datetime::TimeUnit; @@ -418,6 +507,14 @@ mod tests { Ok(device_array) } + fn assert_exported_decimal_values( + value_buffer: &BufferHandle, + expected: &[T], + ) { + let values = Buffer::::from_byte_buffer(value_buffer.to_host_sync()); + assert_eq!(values.as_slice(), expected); + } + // Build a nested struct fixture with an out-of-line string-view value. fn nested_struct_array() -> ArrayRef { let nested = StructArray::new( @@ -512,22 +609,139 @@ mod tests { Ok(()) } + #[rstest] + #[case::i8( + DecimalArray::from_iter([1i8, -2, 3], DecimalDType::new(2, 1)).into_array(), + DataType::Decimal128(2, 1), + vec![1, -2, 3] + )] + #[case::i16( + DecimalArray::from_iter([100i16, -200, 300], DecimalDType::new(4, 2)).into_array(), + DataType::Decimal128(4, 2), + vec![100, -200, 300] + )] + #[case::i32( + DecimalArray::from_iter([10_000i32, -20_000, 30_000], DecimalDType::new(9, 2)).into_array(), + DataType::Decimal128(9, 2), + vec![10_000, -20_000, 30_000] + )] + #[case::i64( + DecimalArray::from_iter([1_000_000i64, -2_000_000, 3_000_000], DecimalDType::new(18, 2)).into_array(), + DataType::Decimal128(18, 2), + vec![1_000_000, -2_000_000, 3_000_000] + )] + #[case::i128( + DecimalArray::from_iter([1i128, -2, 3], DecimalDType::new(38, 2)).into_array(), + DataType::Decimal128(38, 2), + vec![1, -2, 3] + )] #[crate::test] - async fn test_export_decimal() -> VortexResult<()> { + async fn test_export_decimal128( + #[case] array: ArrayRef, + #[case] expected_data_type: DataType, + #[case] expected_values: Vec, + ) -> VortexResult<()> { let mut ctx = CudaSession::create_execution_ctx(&VortexSession::empty()) .vortex_expect("failed to create execution context"); - let array = DecimalArray::from_iter(0i128..5, DecimalDType::new(38, 2)).into_array(); - let mut device_array = array.export_device_array(&mut ctx).await?; + let mut exported = array.export_device_array_with_schema(&mut ctx).await?; - assert_eq!(device_array.array.length, 5); - assert_eq!(device_array.array.null_count, 0); - assert_eq!(device_array.array.n_buffers, 2); - assert_eq!(device_array.array.n_children, 0); - assert!(device_array.array.release.is_some()); - assert_eq!(device_array.device_type, ARROW_DEVICE_CUDA); + let field = Field::try_from(&exported.schema)?; + assert_eq!(field, Field::new("", expected_data_type, false)); + assert_eq!(exported.array.array.length, 3); + assert_eq!(exported.array.array.null_count, 0); + assert_eq!(exported.array.array.n_buffers, 2); + assert_eq!(exported.array.array.n_children, 0); + assert!(exported.array.array.release.is_some()); + assert_eq!(exported.array.device_type, ARROW_DEVICE_CUDA); - unsafe { release_exported_array(&raw mut device_array.array) }; + let private_data = unsafe { &*exported.array.array.private_data.cast::() }; + let value_buffer = private_data.buffers[1] + .as_ref() + .vortex_expect("value buffer should be present"); + assert_eq!(value_buffer.len(), 3 * size_of::()); + assert_exported_decimal_values(value_buffer, &expected_values); + + unsafe { release_exported_array(&raw mut exported.array.array) }; + Ok(()) + } + + #[crate::test] + async fn test_export_empty_decimal128() -> VortexResult<()> { + let mut ctx = CudaSession::create_execution_ctx(&VortexSession::empty()) + .vortex_expect("failed to create execution context"); + + let array = DecimalArray::new( + Buffer::::empty(), + DecimalDType::new(9, 2), + Validity::NonNullable, + ) + .into_array(); + let mut exported = array.export_device_array_with_schema(&mut ctx).await?; + + let field = Field::try_from(&exported.schema)?; + assert_eq!(field, Field::new("", DataType::Decimal128(9, 2), false)); + assert_eq!(exported.array.array.length, 0); + assert_eq!(exported.array.array.null_count, 0); + assert_eq!(exported.array.array.n_buffers, 2); + assert_eq!(exported.array.array.n_children, 0); + assert!(exported.array.array.release.is_some()); + assert_eq!(exported.array.device_type, ARROW_DEVICE_CUDA); + + let private_data = unsafe { &*exported.array.array.private_data.cast::() }; + let value_buffer = private_data.buffers[1] + .as_ref() + .vortex_expect("value buffer should be present"); + assert_eq!(value_buffer.len(), 0); + assert_exported_decimal_values::(value_buffer, &[]); + + unsafe { release_exported_array(&raw mut exported.array.array) }; + Ok(()) + } + + #[rstest] + #[case::i128( + DecimalArray::from_iter([1i128, -2, 3], DecimalDType::new(39, 2)).into_array(), + DataType::Decimal256(39, 2), + vec![i256::from_i128(1), i256::from_i128(-2), i256::from_i128(3)] + )] + #[case::i256( + DecimalArray::from_iter( + [i256::from_i128(10), i256::from_i128(-20), i256::from_i128(30)], + DecimalDType::new(76, 2), + ) + .into_array(), + DataType::Decimal256(76, 2), + vec![i256::from_i128(10), i256::from_i128(-20), i256::from_i128(30)] + )] + #[crate::test] + async fn test_export_decimal256( + #[case] array: ArrayRef, + #[case] expected_data_type: DataType, + #[case] expected_values: Vec, + ) -> VortexResult<()> { + let mut ctx = CudaSession::create_execution_ctx(&VortexSession::empty()) + .vortex_expect("failed to create execution context"); + + let mut exported = array.export_device_array_with_schema(&mut ctx).await?; + + let field = Field::try_from(&exported.schema)?; + assert_eq!(field, Field::new("", expected_data_type, false)); + assert_eq!(exported.array.array.length, 3); + assert_eq!(exported.array.array.null_count, 0); + assert_eq!(exported.array.array.n_buffers, 2); + assert_eq!(exported.array.array.n_children, 0); + assert!(exported.array.array.release.is_some()); + assert_eq!(exported.array.device_type, ARROW_DEVICE_CUDA); + + let private_data = unsafe { &*exported.array.array.private_data.cast::() }; + let value_buffer = private_data.buffers[1] + .as_ref() + .vortex_expect("value buffer should be present"); + assert_eq!(value_buffer.len(), 3 * size_of::()); + assert_exported_decimal_values(value_buffer, &expected_values); + + unsafe { release_exported_array(&raw mut exported.array.array) }; Ok(()) } @@ -745,6 +959,13 @@ mod tests { &mut ctx, ) .await?; + + let private_data = unsafe { &*decimal.array.private_data.cast::() }; + let value_buffer = private_data.buffers[1] + .as_ref() + .vortex_expect("value buffer should be present"); + assert_eq!(value_buffer.len(), 3 * size_of::()); + unsafe { release_exported_array(&raw mut decimal.array) }; Ok(()) diff --git a/vortex-test/e2e-cuda/src/lib.rs b/vortex-test/e2e-cuda/src/lib.rs index 0a8ca91f021..e96d7a51239 100644 --- a/vortex-test/e2e-cuda/src/lib.rs +++ b/vortex-test/e2e-cuda/src/lib.rs @@ -111,8 +111,8 @@ pub unsafe extern "C" fn export_array( } }; let decimal = DecimalArray::from_option_iter( - [Some(0i128), Some(1), None, Some(3), Some(4)], - DecimalDType::new(38, 2), + [Some(0i32), Some(1), None, Some(3), Some(4)], + DecimalDType::new(10, 2), ); let strings = VarBinViewArray::from_iter_nullable_str([ Some("one"),