Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Make output dtype known for list.to_struct when fields are passed #19439

Merged
merged 12 commits into from
Oct 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions crates/polars-io/src/cloud/credential_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,9 +343,6 @@ impl serde::Serialize for PlCredentialProvider {
{
use serde::ser::Error;

// TODO:
// * Add magic bytes here to indicate a python function
// * Check the Python version on deserialize
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

drive-by - outdated todo

#[cfg(feature = "python")]
if let PlCredentialProvider::Python(v) = self {
return v.serialize(serializer);
Expand Down
238 changes: 188 additions & 50 deletions crates/polars-ops/src/chunked_array/list/to_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,82 +5,220 @@ use polars_utils::pl_str::PlSmallStr;

use super::*;

#[derive(Copy, Clone, Debug)]
/// Arguments describing how a list column is turned into a struct column:
/// either with an explicit set of field names, or with a width that is
/// inferred from the data.
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ListToStructArgs {
/// The output struct has exactly one field per given name, in the given order.
FixedWidth(Arc<[PlSmallStr]>),
/// The number of output fields is inferred from the list lengths.
InferWidth {
/// How the width is derived from the list lengths.
infer_field_strategy: ListToStructWidthStrategy,
/// Optional function mapping a field index to its name. When `None`,
/// `_default_struct_name_gen` is used instead.
get_index_name: Option<NameGenerator>,
/// If this is 0, it means unbounded.
///
/// A non-zero value caps the inferred width and allows the output dtype
/// to be computed without looking at the data.
max_fields: usize,
},
}

/// Strategy used to infer the number of struct fields from the list lengths.
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ListToStructWidthStrategy {
/// Use the length of the first list whose length is greater than zero.
FirstNonNull,
/// Use the largest list length found across all chunks.
MaxWidth,
}

fn det_n_fields(ca: &ListChunked, n_fields: ListToStructWidthStrategy) -> usize {
match n_fields {
ListToStructWidthStrategy::MaxWidth => {
let mut max = 0;

ca.downcast_iter().for_each(|arr| {
let offsets = arr.offsets().as_slice();
let mut last = offsets[0];
for o in &offsets[1..] {
let len = (*o - last) as usize;
max = std::cmp::max(max, len);
last = *o;
impl ListToStructArgs {
pub fn get_output_dtype(&self, input_dtype: &DataType) -> PolarsResult<DataType> {
let DataType::List(inner_dtype) = input_dtype else {
polars_bail!(
InvalidOperation:
"attempted list to_struct on non-list dtype: {}",
input_dtype
);
};
let inner_dtype = inner_dtype.as_ref();

match self {
Self::FixedWidth(names) => Ok(DataType::Struct(
names
.iter()
.map(|x| Field::new(x.clone(), inner_dtype.clone()))
.collect::<Vec<_>>(),
)),
Self::InferWidth {
get_index_name,
max_fields,
..
} if *max_fields > 0 => {
let get_index_name_func = get_index_name.as_ref().map_or(
&_default_struct_name_gen as &dyn Fn(usize) -> PlSmallStr,
|x| x.0.as_ref(),
);
Ok(DataType::Struct(
(0..*max_fields)
.map(|i| Field::new(get_index_name_func(i), inner_dtype.clone()))
.collect::<Vec<_>>(),
))
},
Self::InferWidth { .. } => Ok(DataType::Unknown(UnknownKind::Any)),
}
}

fn det_n_fields(&self, ca: &ListChunked) -> usize {
match self {
Self::FixedWidth(v) => v.len(),
Self::InferWidth {
infer_field_strategy,
max_fields,
..
} => {
let inferred = match infer_field_strategy {
ListToStructWidthStrategy::MaxWidth => {
let mut max = 0;

ca.downcast_iter().for_each(|arr| {
let offsets = arr.offsets().as_slice();
let mut last = offsets[0];
for o in &offsets[1..] {
let len = (*o - last) as usize;
max = std::cmp::max(max, len);
last = *o;
}
});
max
},
ListToStructWidthStrategy::FirstNonNull => {
let mut len = 0;
for arr in ca.downcast_iter() {
let offsets = arr.offsets().as_slice();
let mut last = offsets[0];
for o in &offsets[1..] {
len = (*o - last) as usize;
if len > 0 {
break;
}
last = *o;
}
if len > 0 {
break;
}
}
len
},
};

if *max_fields > 0 {
inferred.min(*max_fields)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

upper_bound was previously ignored when width was being inferred during execution - it was only used during IR resolving for getting the output names

} else {
inferred
}
});
max
},
ListToStructWidthStrategy::FirstNonNull => {
let mut len = 0;
for arr in ca.downcast_iter() {
let offsets = arr.offsets().as_slice();
let mut last = offsets[0];
for o in &offsets[1..] {
len = (*o - last) as usize;
if len > 0 {
break;
}
last = *o;
},
}
}

fn set_output_names(&self, columns: &mut [Series]) {
match self {
Self::FixedWidth(v) => {
assert_eq!(columns.len(), v.len());

for (c, name) in columns.iter_mut().zip(v.iter()) {
c.rename(name.clone());
}
if len > 0 {
break;
},
Self::InferWidth { get_index_name, .. } => {
let get_index_name_func = get_index_name.as_ref().map_or(
&_default_struct_name_gen as &dyn Fn(usize) -> PlSmallStr,
|x| x.0.as_ref(),
);

for (i, c) in columns.iter_mut().enumerate() {
c.rename(get_index_name_func(i));
}
}
len
},
},
}
}
}

/// Shared, thread-safe callable that maps a struct field index to the name of
/// that field.
#[derive(Clone)]
pub struct NameGenerator(pub Arc<dyn Fn(usize) -> PlSmallStr + Send + Sync>);

impl NameGenerator {
/// Wraps `func` in an `Arc` to build a `NameGenerator`.
pub fn from_func(func: impl Fn(usize) -> PlSmallStr + Send + Sync + 'static) -> Self {
Self(Arc::new(func))
}
}

impl std::fmt::Debug for NameGenerator {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"list::to_struct::NameGenerator function at 0x{:016x}",
self.0.as_ref() as *const _ as *const () as usize
)
}
}

impl Eq for NameGenerator {}

// Closures cannot be compared structurally, so equality is pointer identity:
// two generators are equal when they share the same `Arc` allocation.
impl PartialEq for NameGenerator {
fn eq(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.0, &other.0)
}
}

pub type NameGenerator = Arc<dyn Fn(usize) -> PlSmallStr + Send + Sync>;
/// Hashes the address of the shared function object, in line with the
/// pointer-based equality of `NameGenerator`.
impl std::hash::Hash for NameGenerator {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let addr = Arc::as_ptr(&self.0) as *const () as usize;
        state.write_usize(addr);
    }
}

/// Default field-name generator: names the field at position `idx` as
/// `field_{idx}`.
pub fn _default_struct_name_gen(idx: usize) -> PlSmallStr {
format_pl_smallstr!("field_{idx}")
}

pub trait ToStruct: AsList {
fn to_struct(
&self,
n_fields: ListToStructWidthStrategy,
name_generator: Option<NameGenerator>,
) -> PolarsResult<StructChunked> {
fn to_struct(&self, args: &ListToStructArgs) -> PolarsResult<StructChunked> {
let ca = self.as_list();
let n_fields = det_n_fields(ca, n_fields);
let n_fields = args.det_n_fields(ca);

let name_generator = name_generator
.as_deref()
.unwrap_or(&_default_struct_name_gen);

let fields = POOL.install(|| {
let mut fields = POOL.install(|| {
(0..n_fields)
.into_par_iter()
.map(|i| {
ca.lst_get(i as i64, true).map(|mut s| {
s.rename(name_generator(i));
s
})
})
.map(|i| ca.lst_get(i as i64, true))
.collect::<PolarsResult<Vec<_>>>()
})?;

args.set_output_names(&mut fields);

StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter())
}
}

// `ListChunked` relies entirely on the provided methods of `ToStruct`.
impl ToStruct for ListChunked {}

#[cfg(feature = "serde")]
mod _serde_impl {
    use super::*;

    /// Serializing a name generator is not supported; this always returns an
    /// error that points the user to passing explicit field names.
    impl serde::Serialize for NameGenerator {
        fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            let message = "cannot serialize name generator function for to_struct, \
                 consider passing a list of field names instead.";
            Err(<S::Error as serde::ser::Error>::custom(message))
        }
    }

    /// Deserializing a name generator is not supported; this always returns an
    /// error.
    impl<'de> serde::Deserialize<'de> for NameGenerator {
        fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            let message = "invalid data: attempted to deserialize list::to_struct::NameGenerator";
            Err(<D::Error as serde::de::Error>::custom(message))
        }
    }
}
15 changes: 14 additions & 1 deletion crates/polars-plan/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use polars_ops::chunked_array::list::*;
use super::*;
use crate::{map, map_as_slice, wrap};

#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)]
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ListFunction {
Concat,
Expand Down Expand Up @@ -56,6 +56,8 @@ pub enum ListFunction {
Join(bool),
#[cfg(feature = "dtype-array")]
ToArray(usize),
#[cfg(feature = "list_to_struct")]
ToStruct(ListToStructArgs),
}

impl ListFunction {
Expand Down Expand Up @@ -103,6 +105,8 @@ impl ListFunction {
#[cfg(feature = "dtype-array")]
ToArray(width) => mapper.try_map_dtype(|dt| map_list_dtype_to_array_dtype(dt, *width)),
NUnique => mapper.with_dtype(IDX_DTYPE),
#[cfg(feature = "list_to_struct")]
ToStruct(args) => mapper.try_map_dtype(|x| args.get_output_dtype(x)),
}
}
}
Expand Down Expand Up @@ -174,6 +178,8 @@ impl Display for ListFunction {
Join(_) => "join",
#[cfg(feature = "dtype-array")]
ToArray(_) => "to_array",
#[cfg(feature = "list_to_struct")]
ToStruct(_) => "to_struct",
};
write!(f, "list.{name}")
}
Expand Down Expand Up @@ -235,6 +241,8 @@ impl From<ListFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
#[cfg(feature = "dtype-array")]
ToArray(width) => map!(to_array, width),
NUnique => map!(n_unique),
#[cfg(feature = "list_to_struct")]
ToStruct(args) => map!(to_struct, &args),
}
}
}
Expand Down Expand Up @@ -650,6 +658,11 @@ pub(super) fn to_array(s: &Column, width: usize) -> PolarsResult<Column> {
s.cast(&array_dtype)
}

/// Converts the list column `s` into a struct column according to `args`.
///
/// Fails if `s` is not a list column or if the conversion itself fails.
#[cfg(feature = "list_to_struct")]
pub(super) fn to_struct(s: &Column, args: &ListToStructArgs) -> PolarsResult<Column> {
    let lists = s.list()?;
    let structs = lists.to_struct(args)?;
    Ok(structs.into_series().into())
}

/// Applies `lst_n_unique` to the list column `s` and returns the result as a
/// column.
///
/// Fails if `s` is not a list column or if `lst_n_unique` fails.
pub(super) fn n_unique(s: &Column) -> PolarsResult<Column> {
    let lists = s.list()?;
    let counts = lists.lst_n_unique()?;
    Ok(counts.into_column())
}
48 changes: 2 additions & 46 deletions crates/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
#[cfg(feature = "list_to_struct")]
use std::sync::RwLock;

use polars_core::prelude::*;
#[cfg(feature = "diff")]
use polars_core::series::ops::NullBehavior;
Expand Down Expand Up @@ -281,50 +278,9 @@ impl ListNameSpace {
/// an `upper_bound` of struct fields that will be set.
/// If this is set incorrectly, downstream operations may fail. For instance, an `all().sum()` expression
/// will look in the current schema to determine which columns to select.
pub fn to_struct(
self,
n_fields: ListToStructWidthStrategy,
name_generator: Option<NameGenerator>,
upper_bound: usize,
) -> Expr {
// heap allocate the output type and fill it later
let out_dtype = Arc::new(RwLock::new(None::<DataType>));

pub fn to_struct(self, args: ListToStructArgs) -> Expr {
self.0
.map(
move |s| {
s.list()?
.to_struct(n_fields, name_generator.clone())
.map(|s| Some(s.into_column()))
},
// we don't yet know the fields
GetOutput::map_dtype(move |dt: &DataType| {
polars_ensure!(matches!(dt, DataType::List(_)), SchemaMismatch: "expected 'List' as input to 'list.to_struct' got {}", dt);
let out = out_dtype.read().unwrap();
match out.as_ref() {
// dtype already set
Some(dt) => Ok(dt.clone()),
// dtype still unknown, set it
None => {
drop(out);
let mut lock = out_dtype.write().unwrap();

let inner = dt.inner_dtype().unwrap();
let fields = (0..upper_bound)
.map(|i| {
let name = _default_struct_name_gen(i);
Field::new(name, inner.clone())
})
.collect();
let dt = DataType::Struct(fields);

*lock = Some(dt.clone());
Ok(dt)
},
}
}),
)
.with_fmt("list.to_struct")
.map_private(FunctionExpr::ListExpr(ListFunction::ToStruct(args)))
}

#[cfg(feature = "is_in")]
Expand Down
Loading