Skip to content
141 changes: 69 additions & 72 deletions datafusion-examples/examples/extension_types/temperature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use arrow::array::{
};
use arrow::datatypes::{Float32Type, Float64Type};
use arrow::util::display::{ArrayFormatter, DisplayIndex, FormatOptions, FormatResult};
use arrow_schema::extension::ExtensionType;
use arrow_schema::extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY};
use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
use datafusion::dataframe::DataFrame;
use datafusion::error::Result;
Expand All @@ -30,8 +30,9 @@ use datafusion::prelude::SessionContext;
use datafusion_common::internal_err;
use datafusion_common::types::DFExtensionType;
use datafusion_expr::registry::{
DefaultExtensionTypeRegistration, ExtensionTypeRegistry, MemoryExtensionTypeRegistry,
ExtensionTypeRegistration, ExtensionTypeRegistry, MemoryExtensionTypeRegistry,
};
use std::collections::HashMap;
use std::fmt::{Display, Write};
use std::sync::Arc;

Expand All @@ -50,13 +51,15 @@ fn create_session_context() -> Result<SessionContext> {
let registry = MemoryExtensionTypeRegistry::new_empty();

// The registration creates a new instance of the extension type with the deserialized metadata.
let temp_registration =
DefaultExtensionTypeRegistration::new_arc(|storage_type, metadata| {
Ok(TemperatureExtensionType::new(
storage_type.clone(),
metadata,
))
});
let temp_registration = ExtensionTypeRegistration::new_arc(
TemperatureExtensionType::NAME,
|storage_type, metadata| {
Ok(Arc::new(TemperatureExtensionType::try_new(
storage_type,
TemperatureUnit::try_from(metadata)?,
)?))
},
);
registry.add_extension_type_registration(temp_registration)?;

let state = SessionStateBuilder::default()
Expand Down Expand Up @@ -96,26 +99,15 @@ async fn register_temperature_table(ctx: &SessionContext) -> Result<DataFrame> {
fn example_schema() -> SchemaRef {
Arc::new(Schema::new(vec![
Field::new("city", DataType::Utf8, false),
Field::new("celsius", DataType::Float64, false).with_extension_type(
TemperatureExtensionType::new(DataType::Float64, TemperatureUnit::Celsius),
),
Field::new("fahrenheit", DataType::Float64, false).with_extension_type(
TemperatureExtensionType::new(DataType::Float64, TemperatureUnit::Fahrenheit),
),
Field::new("kelvin", DataType::Float32, false).with_extension_type(
TemperatureExtensionType::new(DataType::Float32, TemperatureUnit::Kelvin),
),
Field::new("celsius", DataType::Float64, false)
.with_metadata(create_metadata(TemperatureUnit::Celsius)),
Field::new("fahrenheit", DataType::Float64, false)
.with_metadata(create_metadata(TemperatureUnit::Fahrenheit)),
Field::new("kelvin", DataType::Float32, false)
.with_metadata(create_metadata(TemperatureUnit::Kelvin)),
]))
}

/// Represents the unit of a temperature reading.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TemperatureUnit {
Celsius,
Fahrenheit,
Kelvin,
}

/// Represents a float that semantically represents a temperature. The temperature can be one of
/// the supported [`TemperatureUnit`]s.
///
Expand Down Expand Up @@ -143,46 +135,61 @@ pub struct TemperatureExtensionType {
}

impl TemperatureExtensionType {
/// The name of the extension type.
pub const NAME: &'static str = "custom.temperature";

/// Creates a new [`TemperatureExtensionType`].
pub fn new(storage_type: DataType, temperature_unit: TemperatureUnit) -> Self {
Self {
storage_type,
temperature_unit,
pub fn try_new(
storage_type: &DataType,
temperature_unit: TemperatureUnit,
) -> Result<Self, ArrowError> {
match storage_type {
DataType::Float32 | DataType::Float64 => {}
_ => {
return Err(ArrowError::InvalidArgumentError(format!(
"Invalid data type: {storage_type} for temperature type, expected Float32 or Float64",
)));
}
}

let result = Self {
storage_type: storage_type.clone(),
temperature_unit,
};
Ok(result)
}
}

/// Implementation of [`ExtensionType`] for [`TemperatureExtensionType`].
///
/// This implements the arrow-rs trait for reading, writing, and validating extension types.
impl ExtensionType for TemperatureExtensionType {
/// Arrow extension type name that is stored in the `ARROW:extension:name` field.
const NAME: &'static str = "custom.temperature";
type Metadata = TemperatureUnit;

fn metadata(&self) -> &Self::Metadata {
&self.temperature_unit
}
/// Represents the unit of a temperature reading.
///
/// Each variant is serialized to a lowercase string when it is written as extension metadata.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TemperatureUnit {
/// Degrees Celsius; serialized as `celsius`.
Celsius,
/// Degrees Fahrenheit; serialized as `fahrenheit`.
Fahrenheit,
/// Kelvin; serialized as `kelvin`.
Kelvin,
}

impl TemperatureUnit {
/// Arrow extension type metadata is encoded as a string and stored using the
/// `ARROW:extension:metadata` key. As we only store the name of the unit, a simple string
/// suffices. Extension types can store more complex metadata using serialization formats like
/// JSON.
fn serialize_metadata(&self) -> Option<String> {
let s = match self.temperature_unit {
pub fn serialize(self) -> String {
let result = match self {
TemperatureUnit::Celsius => "celsius",
TemperatureUnit::Fahrenheit => "fahrenheit",
TemperatureUnit::Kelvin => "kelvin",
};
Some(s.to_string())
result.to_owned()
}
}

/// Inverse operation of [`Self::serialize_metadata`]. This creates the [`TemperatureUnit`]
/// value from the serialized string.
fn deserialize_metadata(
metadata: Option<&str>,
) -> std::result::Result<Self::Metadata, ArrowError> {
match metadata {
/// Inverse operation of [`TemperatureUnit::serialize`]. This creates the [`TemperatureUnit`]
/// value from the serialized string.
impl TryFrom<Option<&str>> for TemperatureUnit {
type Error = ArrowError;

fn try_from(value: Option<&str>) -> std::result::Result<Self, Self::Error> {
match value {
Some("celsius") => Ok(TemperatureUnit::Celsius),
Some("fahrenheit") => Ok(TemperatureUnit::Fahrenheit),
Some("kelvin") => Ok(TemperatureUnit::Kelvin),
Expand All @@ -194,28 +201,18 @@ impl ExtensionType for TemperatureExtensionType {
)),
}
}
}

/// Checks that the extension type supports a given [`DataType`].
fn supports_data_type(
&self,
data_type: &DataType,
) -> std::result::Result<(), ArrowError> {
match data_type {
DataType::Float32 | DataType::Float64 => Ok(()),
_ => Err(ArrowError::InvalidArgumentError(format!(
"Invalid data type: {data_type} for temperature type, expected Float32 or Float64",
))),
}
}

fn try_new(
data_type: &DataType,
metadata: Self::Metadata,
) -> std::result::Result<Self, ArrowError> {
let instance = Self::new(data_type.clone(), metadata);
instance.supports_data_type(data_type)?;
Ok(instance)
}
/// Builds the field metadata that marks a column as a temperature extension type.
///
/// The returned map has exactly two entries: the extension type name stored under
/// `EXTENSION_TYPE_NAME_KEY` and the serialized `unit` stored under
/// `EXTENSION_TYPE_METADATA_KEY`. The same metadata could also be written through arrow-rs'
/// [`ExtensionType`](arrow_schema::extension::ExtensionType) trait.
fn create_metadata(unit: TemperatureUnit) -> HashMap<String, String> {
    let mut metadata = HashMap::with_capacity(2);
    metadata.insert(
        EXTENSION_TYPE_NAME_KEY.to_owned(),
        TemperatureExtensionType::NAME.to_owned(),
    );
    metadata.insert(EXTENSION_TYPE_METADATA_KEY.to_owned(), unit.serialize());
    metadata
}

/// Implementation of [`DFExtensionType`] for [`TemperatureExtensionType`].
Expand All @@ -227,7 +224,7 @@ impl DFExtensionType for TemperatureExtensionType {
}

fn serialize_metadata(&self) -> Option<String> {
ExtensionType::serialize_metadata(self)
Some(self.temperature_unit.serialize())
}

fn create_array_formatter<'fmt>(
Expand Down
114 changes: 114 additions & 0 deletions datafusion/common/src/types/canonical_extensions/bool8.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::Result;
use crate::error::_internal_err;
use crate::types::extension::DFExtensionType;
use arrow::array::{Array, Int8Array};
use arrow::datatypes::DataType;
use arrow::util::display::{ArrayFormatter, DisplayIndex, FormatOptions, FormatResult};
use arrow_schema::extension::{Bool8, ExtensionType};
use std::fmt::Write;

/// Defines the extension type logic for the canonical `arrow.bool8` extension type.
///
/// This is a newtype around arrow's [`Bool8`] extension type. Values are stored as
/// [`DataType::Int8`]; when formatted, a stored value of zero is displayed as `false` and any
/// other value is displayed as `true`.
///
/// See [`DFExtensionType`] for information on DataFusion's extension type mechanism.
#[derive(Debug, Clone)]
pub struct DFBool8(Bool8);

impl DFBool8 {
    /// Builds a [`DFBool8`] for the given storage type and metadata.
    ///
    /// Validation is delegated to arrow's [`Bool8`] implementation of
    /// [`ExtensionType::try_new`].
    ///
    /// # Errors
    ///
    /// Returns the (converted) arrow error if [`Bool8`] rejects `data_type` or `metadata`.
    pub fn try_new(
        data_type: &DataType,
        metadata: <Bool8 as ExtensionType>::Metadata,
    ) -> Result<Self> {
        let inner = <Bool8 as ExtensionType>::try_new(data_type, metadata)?;
        Ok(Self(inner))
    }
}

impl DFExtensionType for DFBool8 {
fn storage_type(&self) -> DataType {
DataType::Int8
}

fn serialize_metadata(&self) -> Option<String> {
self.0.serialize_metadata()
}

fn create_array_formatter<'fmt>(
&self,
array: &'fmt dyn Array,
options: &FormatOptions<'fmt>,
) -> Result<Option<ArrayFormatter<'fmt>>> {
if array.data_type() != &DataType::Int8 {
return _internal_err!("Wrong array type for Bool8");
}

let display_index = Bool8ValueDisplayIndex {
array: array.as_any().downcast_ref().unwrap(),
null_str: options.null(),
};
Ok(Some(ArrayFormatter::new(
Box::new(display_index),
options.safe(),
)))
}
}

/// Pretty printer for `arrow.bool8` values stored in an [`Int8Array`].
///
/// Null slots are written as `null_str`; all other slots are written as `true` or `false`.
#[derive(Debug, Clone, Copy)]
struct Bool8ValueDisplayIndex<'a> {
/// The storage array holding the raw `i8` values.
array: &'a Int8Array,
/// The text that is written for null slots.
null_str: &'a str,
}

impl DisplayIndex for Bool8ValueDisplayIndex<'_> {
    /// Writes the value at `idx` to `f`.
    ///
    /// Null slots produce the configured null string. For all other slots, a stored value of
    /// zero is written as `false` and any non-zero value is written as `true`.
    fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult {
        if self.array.is_valid(idx) {
            let truthy = self.array.value(idx) != 0;
            write!(f, "{truthy}")?;
        } else {
            write!(f, "{}", self.null_str)?;
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Zero is printed as `false`, non-zero values (including negative ones) as `true`, and
    /// nulls use the configured null string.
    #[test]
    pub fn test_pretty_bool8() {
        let storage = Int8Array::from_iter([Some(0), Some(1), Some(-20), None]);
        let options = FormatOptions::default().with_null("NULL");

        let bool8 = DFBool8(Bool8 {});
        let formatter = bool8
            .create_array_formatter(&storage, &options)
            .unwrap()
            .unwrap();

        let expected = ["false", "true", "true", "NULL"];
        for (row, want) in expected.into_iter().enumerate() {
            assert_eq!(formatter.value(row).to_string(), want);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::Result;
use crate::types::extension::DFExtensionType;
use arrow::datatypes::DataType;
use arrow_schema::extension::{ExtensionType, FixedShapeTensor};

/// Defines the extension type logic for the canonical `arrow.fixed_shape_tensor` extension type.
///
/// Pairs arrow's [`FixedShapeTensor`] with the exact storage [`DataType`] it was created for.
///
/// See [`DFExtensionType`] for information on DataFusion's extension type mechanism.
#[derive(Debug, Clone)]
pub struct DFFixedShapeTensor {
/// The arrow-rs extension type; it is used to serialize the extension metadata.
inner: FixedShapeTensor,
/// The storage type of the tensor.
///
/// While we could reconstruct the storage type from the inner [`FixedShapeTensor`], we may
/// choose a different name for the field within the [`DataType::FixedSizeList`] which can
/// cause problems down the line (e.g., checking for equality). Storing the original type
/// avoids that mismatch.
storage_type: DataType,
}

impl DFFixedShapeTensor {
    /// Builds a [`DFFixedShapeTensor`] for the given storage type and metadata.
    ///
    /// Validation is delegated to arrow's [`FixedShapeTensor`] implementation of
    /// [`ExtensionType::try_new`]. On success, a copy of `data_type` is kept as the storage
    /// type.
    ///
    /// # Errors
    ///
    /// Returns the (converted) arrow error if [`FixedShapeTensor`] rejects `data_type` or
    /// `metadata`.
    pub fn try_new(
        data_type: &DataType,
        metadata: <FixedShapeTensor as ExtensionType>::Metadata,
    ) -> Result<Self> {
        let inner = <FixedShapeTensor as ExtensionType>::try_new(data_type, metadata)?;
        let storage_type = data_type.clone();
        Ok(Self {
            inner,
            storage_type,
        })
    }
}

impl DFExtensionType for DFFixedShapeTensor {
fn storage_type(&self) -> DataType {
self.storage_type.clone()
}

fn serialize_metadata(&self) -> Option<String> {
self.inner.serialize_metadata()
}
}
Loading
Loading