Skip to content

Commit

Permalink
Support casting of Float16 with other numeric types
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Nov 28, 2023
1 parent 8a0b5cb commit 6e487de
Showing 1 changed file with 38 additions and 4 deletions.
42 changes: 38 additions & 4 deletions arrow-cast/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {

// start numeric casts
(
UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32 | Float64,
UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32 | Float64,
UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float16 | Float32 | Float64,
UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float16 | Float32 | Float64,
) => true,
// end numeric casts

Expand All @@ -220,8 +220,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Time64(_), Time32(to_unit)) => {
matches!(to_unit, Second | Millisecond)
}
(Timestamp(_, _), _) if to_type.is_numeric() && to_type != &Float16 => true,
(_, Timestamp(_, _)) if from_type.is_numeric() && from_type != &Float16 => true,
(Timestamp(_, _), _) if to_type.is_numeric() => true,
(_, Timestamp(_, _)) if from_type.is_numeric() => true,
(Date64, Timestamp(_, None)) => true,
(Date32, Timestamp(_, None)) => true,
(
Expand Down Expand Up @@ -1367,6 +1367,7 @@ pub fn cast_with_options(
(UInt8, Int16) => cast_numeric_arrays::<UInt8Type, Int16Type>(array, cast_options),
(UInt8, Int32) => cast_numeric_arrays::<UInt8Type, Int32Type>(array, cast_options),
(UInt8, Int64) => cast_numeric_arrays::<UInt8Type, Int64Type>(array, cast_options),
(UInt8, Float16) => cast_numeric_arrays::<UInt8Type, Float16Type>(array, cast_options),
(UInt8, Float32) => cast_numeric_arrays::<UInt8Type, Float32Type>(array, cast_options),
(UInt8, Float64) => cast_numeric_arrays::<UInt8Type, Float64Type>(array, cast_options),

Expand All @@ -1377,6 +1378,7 @@ pub fn cast_with_options(
(UInt16, Int16) => cast_numeric_arrays::<UInt16Type, Int16Type>(array, cast_options),
(UInt16, Int32) => cast_numeric_arrays::<UInt16Type, Int32Type>(array, cast_options),
(UInt16, Int64) => cast_numeric_arrays::<UInt16Type, Int64Type>(array, cast_options),
(UInt16, Float16) => cast_numeric_arrays::<UInt16Type, Float16Type>(array, cast_options),
(UInt16, Float32) => cast_numeric_arrays::<UInt16Type, Float32Type>(array, cast_options),
(UInt16, Float64) => cast_numeric_arrays::<UInt16Type, Float64Type>(array, cast_options),

Expand All @@ -1387,6 +1389,7 @@ pub fn cast_with_options(
(UInt32, Int16) => cast_numeric_arrays::<UInt32Type, Int16Type>(array, cast_options),
(UInt32, Int32) => cast_numeric_arrays::<UInt32Type, Int32Type>(array, cast_options),
(UInt32, Int64) => cast_numeric_arrays::<UInt32Type, Int64Type>(array, cast_options),
(UInt32, Float16) => cast_numeric_arrays::<UInt32Type, Float16Type>(array, cast_options),
(UInt32, Float32) => cast_numeric_arrays::<UInt32Type, Float32Type>(array, cast_options),
(UInt32, Float64) => cast_numeric_arrays::<UInt32Type, Float64Type>(array, cast_options),

Expand All @@ -1397,6 +1400,7 @@ pub fn cast_with_options(
(UInt64, Int16) => cast_numeric_arrays::<UInt64Type, Int16Type>(array, cast_options),
(UInt64, Int32) => cast_numeric_arrays::<UInt64Type, Int32Type>(array, cast_options),
(UInt64, Int64) => cast_numeric_arrays::<UInt64Type, Int64Type>(array, cast_options),
(UInt64, Float16) => cast_numeric_arrays::<UInt64Type, Float16Type>(array, cast_options),
(UInt64, Float32) => cast_numeric_arrays::<UInt64Type, Float32Type>(array, cast_options),
(UInt64, Float64) => cast_numeric_arrays::<UInt64Type, Float64Type>(array, cast_options),

Expand All @@ -1407,6 +1411,7 @@ pub fn cast_with_options(
(Int8, Int16) => cast_numeric_arrays::<Int8Type, Int16Type>(array, cast_options),
(Int8, Int32) => cast_numeric_arrays::<Int8Type, Int32Type>(array, cast_options),
(Int8, Int64) => cast_numeric_arrays::<Int8Type, Int64Type>(array, cast_options),
(Int8, Float16) => cast_numeric_arrays::<Int8Type, Float16Type>(array, cast_options),
(Int8, Float32) => cast_numeric_arrays::<Int8Type, Float32Type>(array, cast_options),
(Int8, Float64) => cast_numeric_arrays::<Int8Type, Float64Type>(array, cast_options),

Expand All @@ -1417,6 +1422,7 @@ pub fn cast_with_options(
(Int16, Int8) => cast_numeric_arrays::<Int16Type, Int8Type>(array, cast_options),
(Int16, Int32) => cast_numeric_arrays::<Int16Type, Int32Type>(array, cast_options),
(Int16, Int64) => cast_numeric_arrays::<Int16Type, Int64Type>(array, cast_options),
(Int16, Float16) => cast_numeric_arrays::<Int16Type, Float16Type>(array, cast_options),
(Int16, Float32) => cast_numeric_arrays::<Int16Type, Float32Type>(array, cast_options),
(Int16, Float64) => cast_numeric_arrays::<Int16Type, Float64Type>(array, cast_options),

Expand All @@ -1427,6 +1433,7 @@ pub fn cast_with_options(
(Int32, Int8) => cast_numeric_arrays::<Int32Type, Int8Type>(array, cast_options),
(Int32, Int16) => cast_numeric_arrays::<Int32Type, Int16Type>(array, cast_options),
(Int32, Int64) => cast_numeric_arrays::<Int32Type, Int64Type>(array, cast_options),
(Int32, Float16) => cast_numeric_arrays::<Int32Type, Float16Type>(array, cast_options),
(Int32, Float32) => cast_numeric_arrays::<Int32Type, Float32Type>(array, cast_options),
(Int32, Float64) => cast_numeric_arrays::<Int32Type, Float64Type>(array, cast_options),

Expand All @@ -1437,9 +1444,21 @@ pub fn cast_with_options(
(Int64, Int8) => cast_numeric_arrays::<Int64Type, Int8Type>(array, cast_options),
(Int64, Int16) => cast_numeric_arrays::<Int64Type, Int16Type>(array, cast_options),
(Int64, Int32) => cast_numeric_arrays::<Int64Type, Int32Type>(array, cast_options),
(Int64, Float16) => cast_numeric_arrays::<Int64Type, Float16Type>(array, cast_options),
(Int64, Float32) => cast_numeric_arrays::<Int64Type, Float32Type>(array, cast_options),
(Int64, Float64) => cast_numeric_arrays::<Int64Type, Float64Type>(array, cast_options),

(Float16, UInt8) => cast_numeric_arrays::<Float16Type, UInt8Type>(array, cast_options),
(Float16, UInt16) => cast_numeric_arrays::<Float16Type, UInt16Type>(array, cast_options),
(Float16, UInt32) => cast_numeric_arrays::<Float16Type, UInt32Type>(array, cast_options),
(Float16, UInt64) => cast_numeric_arrays::<Float16Type, UInt64Type>(array, cast_options),
(Float16, Int8) => cast_numeric_arrays::<Float16Type, Int8Type>(array, cast_options),
(Float16, Int16) => cast_numeric_arrays::<Float16Type, Int16Type>(array, cast_options),
(Float16, Int32) => cast_numeric_arrays::<Float16Type, Int32Type>(array, cast_options),
(Float16, Int64) => cast_numeric_arrays::<Float16Type, Int64Type>(array, cast_options),
(Float16, Float32) => cast_numeric_arrays::<Float16Type, Float32Type>(array, cast_options),
(Float16, Float64) => cast_numeric_arrays::<Float16Type, Float64Type>(array, cast_options),

(Float32, UInt8) => cast_numeric_arrays::<Float32Type, UInt8Type>(array, cast_options),
(Float32, UInt16) => cast_numeric_arrays::<Float32Type, UInt16Type>(array, cast_options),
(Float32, UInt32) => cast_numeric_arrays::<Float32Type, UInt32Type>(array, cast_options),
Expand All @@ -1448,6 +1467,7 @@ pub fn cast_with_options(
(Float32, Int16) => cast_numeric_arrays::<Float32Type, Int16Type>(array, cast_options),
(Float32, Int32) => cast_numeric_arrays::<Float32Type, Int32Type>(array, cast_options),
(Float32, Int64) => cast_numeric_arrays::<Float32Type, Int64Type>(array, cast_options),
(Float32, Float16) => cast_numeric_arrays::<Float32Type, Float16Type>(array, cast_options),
(Float32, Float64) => cast_numeric_arrays::<Float32Type, Float64Type>(array, cast_options),

(Float64, UInt8) => cast_numeric_arrays::<Float64Type, UInt8Type>(array, cast_options),
Expand All @@ -1458,6 +1478,7 @@ pub fn cast_with_options(
(Float64, Int16) => cast_numeric_arrays::<Float64Type, Int16Type>(array, cast_options),
(Float64, Int32) => cast_numeric_arrays::<Float64Type, Int32Type>(array, cast_options),
(Float64, Int64) => cast_numeric_arrays::<Float64Type, Int64Type>(array, cast_options),
(Float64, Float16) => cast_numeric_arrays::<Float64Type, Float16Type>(array, cast_options),
(Float64, Float32) => cast_numeric_arrays::<Float64Type, Float32Type>(array, cast_options),
// end numeric casts

Expand Down Expand Up @@ -3299,6 +3320,7 @@ fn cast_list<I: OffsetSizeTrait, O: OffsetSizeTrait>(
#[cfg(test)]
mod tests {
use arrow_buffer::{Buffer, NullBuffer};
use half::f16;

use super::*;

Expand Down Expand Up @@ -4665,6 +4687,15 @@ mod tests {
let array = Int64Array::from(vec![Some(2), Some(10), None]);
let expected = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap();

let array = Float16Array::from(vec![
Some(f16::from_f32(2.0)),
Some(f16::from_f32(10.6)),
None,
]);
let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap();

assert_eq!(&actual, &expected);

let array = Float32Array::from(vec![Some(2.0), Some(10.6), None]);
let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap();

Expand All @@ -4682,6 +4713,9 @@ mod tests {
.with_timezone("UTC".to_string());
let expected = cast(&array, &DataType::Int64).unwrap();

let actual = cast(&cast(&array, &DataType::Float16).unwrap(), &DataType::Int64).unwrap();
assert_eq!(&actual, &expected);

let actual = cast(&cast(&array, &DataType::Float32).unwrap(), &DataType::Int64).unwrap();
assert_eq!(&actual, &expected);

Expand Down

0 comments on commit 6e487de

Please sign in to comment.