Skip to content

Commit

Permalink
fix: Address inadvertent quadratic behaviour in expand_columns (#19469
Browse files Browse the repository at this point in the history
)
  • Loading branch information
alexander-beedie authored Oct 27, 2024
1 parent b222796 commit 60d0721
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 69 deletions.
4 changes: 2 additions & 2 deletions crates/polars-plan/src/dsl/selector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub enum Selector {
Add(Box<Selector>, Box<Selector>),
Sub(Box<Selector>, Box<Selector>),
ExclusiveOr(Box<Selector>, Box<Selector>),
InterSect(Box<Selector>, Box<Selector>),
Intersect(Box<Selector>, Box<Selector>),
Root(Box<Expr>),
}

Expand All @@ -34,7 +34,7 @@ impl BitAnd for Selector {

#[allow(clippy::suspicious_arithmetic_impl)]
fn bitand(self, rhs: Self) -> Self::Output {
Selector::InterSect(Box::new(self), Box::new(rhs))
Selector::Intersect(Box::new(self), Box::new(rhs))
}
}

Expand Down
100 changes: 33 additions & 67 deletions crates/polars-plan/src/plans/conversion/expr_expansion.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
//! this contains code used for rewriting projections, expanding wildcards, regex selection etc.
use std::ops::BitXor;

use super::*;

Expand Down Expand Up @@ -176,26 +175,28 @@ fn expand_columns(
schema: &Schema,
exclude: &PlHashSet<PlSmallStr>,
) -> PolarsResult<()> {
let mut is_valid = true;
if !expr.into_iter().all(|e| match e {
// check for invalid expansions such as `col([a, b]) + col([c, d])`
Expr::Columns(ref members) => members.as_ref() == names,
_ => true,
}) {
polars_bail!(ComputeError: "expanding more than one `col` is not allowed");
}
for name in names {
if !exclude.contains(name) {
let new_expr = expr.clone();
let (new_expr, new_expr_valid) = replace_columns_with_column(new_expr, names, name);
is_valid &= new_expr_valid;
// we may have regex col in columns.
#[allow(clippy::collapsible_else_if)]
let new_expr = expr.clone().map_expr(|e| match e {
Expr::Columns(_) => Expr::Column((*name).clone()),
Expr::Exclude(input, _) => Arc::unwrap_or_clone(input),
e => e,
});

#[cfg(feature = "regex")]
{
replace_regex(&new_expr, result, schema, exclude)?;
}
replace_regex(&new_expr, result, schema, exclude)?;

#[cfg(not(feature = "regex"))]
{
let new_expr = rewrite_special_aliases(new_expr)?;
result.push(new_expr)
}
result.push(rewrite_special_aliases(new_expr)?);
}
}
polars_ensure!(is_valid, ComputeError: "expanding more than one `col` is not allowed");
Ok(())
}

Expand Down Expand Up @@ -246,30 +247,6 @@ fn replace_dtype_or_index_with_column(
})
}

/// This replaces the columns Expr with a Column Expr. It also removes the Exclude Expr from the
/// expression chain.
pub(super) fn replace_columns_with_column(
mut expr: Expr,
names: &[PlSmallStr],
column_name: &PlSmallStr,
) -> (Expr, bool) {
let mut is_valid = true;
expr = expr.map_expr(|e| match e {
Expr::Columns(members) => {
// `col([a, b]) + col([c, d])`
if members.as_ref() == names {
Expr::Column(column_name.clone())
} else {
is_valid = false;
Expr::Columns(members)
}
},
Expr::Exclude(input, _) => Arc::unwrap_or_clone(input),
e => e,
});
(expr, is_valid)
}

fn dtypes_match(d1: &DataType, d2: &DataType) -> bool {
match (d1, d2) {
// note: allow Datetime "*" wildcard for timezones...
Expand Down Expand Up @@ -562,7 +539,7 @@ fn expand_function_inputs(
})
}

#[derive(Copy, Clone)]
#[derive(Copy, Clone, Debug)]
struct ExpansionFlags {
multiple_columns: bool,
has_nth: bool,
Expand Down Expand Up @@ -819,42 +796,31 @@ fn replace_selector_inner(
members.extend(scratch.drain(..))
},
Selector::Add(lhs, rhs) => {
let mut tmp_members: PlIndexSet<Expr> = Default::default();
replace_selector_inner(*lhs, members, scratch, schema, keys)?;
let mut rhs_members: PlIndexSet<Expr> = Default::default();
replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?;
members.extend(rhs_members)
replace_selector_inner(*rhs, &mut tmp_members, scratch, schema, keys)?;
members.extend(tmp_members)
},
Selector::ExclusiveOr(lhs, rhs) => {
let mut lhs_members = Default::default();
replace_selector_inner(*lhs, &mut lhs_members, scratch, schema, keys)?;
let mut tmp_members = Default::default();
replace_selector_inner(*lhs, &mut tmp_members, scratch, schema, keys)?;
replace_selector_inner(*rhs, members, scratch, schema, keys)?;

let mut rhs_members = Default::default();
replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?;

let xor_members = lhs_members.bitxor(&rhs_members);
*members = xor_members;
*members = tmp_members.symmetric_difference(members).cloned().collect();
},
Selector::InterSect(lhs, rhs) => {
replace_selector_inner(*lhs, members, scratch, schema, keys)?;
Selector::Intersect(lhs, rhs) => {
let mut tmp_members = Default::default();
replace_selector_inner(*lhs, &mut tmp_members, scratch, schema, keys)?;
replace_selector_inner(*rhs, members, scratch, schema, keys)?;

let mut rhs_members = Default::default();
replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?;

*members = members.intersection(&rhs_members).cloned().collect()
*members = tmp_members.intersection(members).cloned().collect();
},
Selector::Sub(lhs, rhs) => {
replace_selector_inner(*lhs, members, scratch, schema, keys)?;
let mut tmp_members = Default::default();
replace_selector_inner(*lhs, &mut tmp_members, scratch, schema, keys)?;
replace_selector_inner(*rhs, members, scratch, schema, keys)?;

let mut rhs_members = Default::default();
replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?;

let mut new_members = PlIndexSet::with_capacity(members.len());
for e in members.drain(..) {
if !rhs_members.contains(&e) {
new_members.insert(e);
}
}
*members = new_members;
*members = tmp_members.difference(members).cloned().collect();
},
}
Ok(())
Expand Down

0 comments on commit 60d0721

Please sign in to comment.