Skip to content

Commit

Permalink
Cleanup list casting and support nested lists (#5113)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Nov 26, 2023
1 parent e1bafdf commit 3ba24d1
Showing 1 changed file with 65 additions and 103 deletions.
168 changes: 65 additions & 103 deletions arrow-cast/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -783,28 +783,10 @@ pub fn cast_with_options(
"Casting from type {from_type:?} to dictionary type {to_type:?} not supported",
))),
},
(List(_), List(ref to)) => cast_list_inner::<i32>(array, to, to_type, cast_options),
(LargeList(_), LargeList(ref to)) => {
cast_list_inner::<i64>(array, to, to_type, cast_options)
}
(List(list_from), LargeList(list_to)) => {
if list_to.data_type() != list_from.data_type() {
Err(ArrowError::CastError(
"cannot cast list to large-list with different child data".into(),
))
} else {
cast_list_container::<i32, i64>(array, cast_options)
}
}
(LargeList(list_from), List(list_to)) => {
if list_to.data_type() != list_from.data_type() {
Err(ArrowError::CastError(
"cannot cast large-list to list with different child data".into(),
))
} else {
cast_list_container::<i64, i32>(array, cast_options)
}
}
(List(_), List(to)) => cast_list_values::<i32>(array, to, cast_options),
(LargeList(_), LargeList(to)) => cast_list_values::<i64>(array, to, cast_options),
(List(_), LargeList(list_to)) => cast_list::<i32, i64>(array, list_to, cast_options),
(LargeList(_), List(list_to)) => cast_list::<i64, i32>(array, list_to, cast_options),
(List(_), FixedSizeList(field, size)) => {
let array = array.as_list::<i32>();
cast_list_to_fixed_size_list::<i32>(array, field, *size, cast_options)
Expand Down Expand Up @@ -3046,28 +3028,6 @@ fn cast_values_to_list<O: OffsetSizeTrait>(
Ok(Arc::new(list))
}

/// Helper function that takes an Generic list container and casts the inner datatype.
fn cast_list_inner<OffsetSize: OffsetSizeTrait>(
array: &dyn Array,
to: &Field,
to_type: &DataType,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError> {
let data = array.to_data();
let underlying_array = make_array(data.child_data()[0].clone());
let cast_array = cast_with_options(underlying_array.as_ref(), to.data_type(), cast_options)?;
let builder = data
.into_builder()
.data_type(to_type.clone())
.child_data(vec![cast_array.into_data()]);

// Safety
// Data was valid before
let array_data = unsafe { builder.build_unchecked() };
let list = GenericListArray::<OffsetSize>::from(array_data);
Ok(Arc::new(list) as ArrayRef)
}

/// A specified helper to cast from `GenericBinaryArray` to `GenericStringArray` when they have same
/// offset size so re-encoding offset is unnecessary.
fn cast_binary_to_string<O: OffsetSizeTrait>(
Expand Down Expand Up @@ -3221,7 +3181,7 @@ where

fn cast_list_to_fixed_size_list<OffsetSize>(
array: &GenericListArray<OffsetSize>,
field: &Arc<Field>,
field: &FieldRef,
size: i32,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError>
Expand Down Expand Up @@ -3289,70 +3249,52 @@ where
Ok(Arc::new(array))
}

/// Cast the container type of List/Largelist array but not the inner types.
/// This function can leave the value data intact and only has to cast the offset dtypes.
fn cast_list_container<OffsetSizeFrom, OffsetSizeTo>(
/// Helper function that takes an Generic list container and casts the inner datatype.
fn cast_list_values<O: OffsetSizeTrait>(
array: &dyn Array,
_cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError>
where
OffsetSizeFrom: OffsetSizeTrait + ToPrimitive,
OffsetSizeTo: OffsetSizeTrait + NumCast,
{
let list = array.as_list::<OffsetSizeFrom>();
// the value data stored by the list
let values = list.values();
to: &FieldRef,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError> {
let list = array.as_list::<O>();
let values = cast_with_options(list.values(), to.data_type(), cast_options)?;
Ok(Arc::new(GenericListArray::<O>::new(
to.clone(),
list.offsets().clone(),
values,
list.nulls().cloned(),
)))
}

let out_dtype = match array.data_type() {
DataType::List(value_type) => {
assert_eq!(
std::mem::size_of::<OffsetSizeFrom>(),
std::mem::size_of::<i32>()
);
assert_eq!(
std::mem::size_of::<OffsetSizeTo>(),
std::mem::size_of::<i64>()
);
DataType::LargeList(value_type.clone())
}
DataType::LargeList(value_type) => {
assert_eq!(
std::mem::size_of::<OffsetSizeFrom>(),
std::mem::size_of::<i64>()
);
assert_eq!(
std::mem::size_of::<OffsetSizeTo>(),
std::mem::size_of::<i32>()
);
if values.len() > i32::MAX as usize {
return Err(ArrowError::ComputeError(
"LargeList too large to cast to List".into(),
));
}
DataType::List(value_type.clone())
}
// implementation error
_ => unreachable!(),
};
/// Cast the container type of List/Largelist array along with the inner datatype
fn cast_list<I: OffsetSizeTrait, O: OffsetSizeTrait>(
array: &dyn Array,
field: &FieldRef,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError> {
let list = array.as_list::<I>();
let values = list.values();
let offsets = list.offsets();
let nulls = list.nulls().cloned();

let iter = list.value_offsets().iter().map(|idx| {
let idx: OffsetSizeTo = NumCast::from(*idx).unwrap();
idx
});
if !O::IS_LARGE && values.len() > i32::MAX as usize {
return Err(ArrowError::ComputeError(
"LargeList too large to cast to List".into(),
));
}

// SAFETY
// A slice produces a trusted length iterator
let offset_buffer = unsafe { Buffer::from_trusted_len_iter(iter) };
// Recursively cast values
let values = cast_with_options(values, field.data_type(), cast_options)?;
let offsets: Vec<_> = offsets.iter().map(|x| O::usize_as(x.as_usize())).collect();

// wrap up
let builder = ArrayData::builder(out_dtype)
.len(list.len())
.add_buffer(offset_buffer)
.add_child_data(values.to_data())
.nulls(list.nulls().cloned());
// Safety: valid offsets and checked for overflow
let offsets = unsafe { OffsetBuffer::new_unchecked(offsets.into()) };

let array_data = unsafe { builder.build_unchecked() };
Ok(Arc::new(GenericListArray::<OffsetSizeTo>::from(array_data)))
Ok(Arc::new(GenericListArray::<O>::new(
field.clone(),
offsets,
values,
nulls,
)))
}

#[cfg(test)]
Expand Down Expand Up @@ -9154,6 +9096,26 @@ mod tests {
assert_eq!(formatted.value(1).to_string(), "[[4], [null], [6]]");
}

#[test]
fn test_nested_list_cast() {
let mut builder = ListBuilder::new(ListBuilder::new(Int32Builder::new()));
builder.append_value([Some([Some(1), Some(2), None]), None]);
builder.append_value([None, Some([]), None]);
builder.append_null();
builder.append_value([Some([Some(2), Some(3)])]);
let start = builder.finish();

let mut builder = LargeListBuilder::new(LargeListBuilder::new(Int8Builder::new()));
builder.append_value([Some([Some(1), Some(2), None]), None]);
builder.append_value([None, Some([]), None]);
builder.append_null();
builder.append_value([Some([Some(2), Some(3)])]);
let expected = builder.finish();

let actual = cast(&start, expected.data_type()).unwrap();
assert_eq!(actual.as_ref(), &expected);
}

const CAST_OPTIONS: CastOptions<'static> = CastOptions {
safe: true,
format_options: FormatOptions::new(),
Expand Down

0 comments on commit 3ba24d1

Please sign in to comment.