Skip to content

Commit

Permalink
fix: Incorrect .join(..., how="left").head(N) if `N <= left_df.heig…
Browse files Browse the repository at this point in the history
…ht()` and there are duplicate matches (#19422)
  • Loading branch information
nameexhaustion authored Oct 24, 2024
1 parent b9084b7 commit 3fe25c5
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 7 deletions.
7 changes: 4 additions & 3 deletions crates/polars-ops/src/frame/join/dispatch_left_right.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,14 @@ fn materialize_left_join(
if let Some((offset, len)) = args.slice {
left_idx = slice_slice(left_idx, offset, len);
}
left._create_left_df_from_slice(left_idx, true, true)
left._create_left_df_from_slice(left_idx, true, args.slice.is_some(), true)
},
ChunkJoinIds::Right(left_idx) => unsafe {
let mut left_idx = &*left_idx;
if let Some((offset, len)) = args.slice {
left_idx = slice_slice(left_idx, offset, len);
}
left.create_left_df_chunked(left_idx, true)
left.create_left_df_chunked(left_idx, true, args.slice.is_some())
},
};

Expand Down Expand Up @@ -133,7 +133,8 @@ fn materialize_left_join(
if let Some((offset, len)) = args.slice {
left_idx = slice_slice(left_idx, offset, len);
}
let materialize_left = || unsafe { left._create_left_df_from_slice(&left_idx, true, true) };
let materialize_left =
|| unsafe { left._create_left_df_from_slice(&left_idx, true, args.slice.is_some(), true) };

let mut right_idx = &*right_idx;
if let Some((offset, len)) = args.slice {
Expand Down
20 changes: 17 additions & 3 deletions crates/polars-ops/src/frame/join/hash_join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,18 @@ pub trait JoinDispatch: IntoDf {
/// # Safety
/// Join tuples must be in bounds
#[cfg(feature = "chunked_ids")]
unsafe fn create_left_df_chunked(&self, chunk_ids: &[ChunkId], left_join: bool) -> DataFrame {
unsafe fn create_left_df_chunked(
&self,
chunk_ids: &[ChunkId],
left_join: bool,
was_sliced: bool,
) -> DataFrame {
let df_self = self.to_df();
if left_join && chunk_ids.len() == df_self.height() {

let left_join_no_duplicate_matches =
left_join && !was_sliced && chunk_ids.len() == df_self.height();

if left_join_no_duplicate_matches {
df_self.clone()
} else {
// left join keys are in ascending order
Expand All @@ -76,10 +85,15 @@ pub trait JoinDispatch: IntoDf {
&self,
join_tuples: &[IdxSize],
left_join: bool,
was_sliced: bool,
sorted_tuple_idx: bool,
) -> DataFrame {
let df_self = self.to_df();
if left_join && join_tuples.len() == df_self.height() {

let left_join_no_duplicate_matches =
left_join && !was_sliced && join_tuples.len() == df_self.height();

if left_join_no_duplicate_matches {
df_self.clone()
} else {
// left join tuples are always in ascending order
Expand Down
9 changes: 8 additions & 1 deletion crates/polars-ops/src/frame/join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,14 @@ trait DataFrameJoinOpsPrivate: IntoDf {

let (df_left, df_right) = POOL.join(
// SAFETY: join indices are known to be in bounds
|| unsafe { left_df._create_left_df_from_slice(join_tuples_left, false, sorted) },
|| unsafe {
left_df._create_left_df_from_slice(
join_tuples_left,
false,
args.slice.is_some(),
sorted,
)
},
|| unsafe {
if let Some(drop_names) = drop_names {
other.drop_many(drop_names)
Expand Down
1 change: 1 addition & 0 deletions py-polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ all = [
"parquet",
"ipc",
"polars-python/all",
"performant",
]

default = ["all", "nightly"]
17 changes: 17 additions & 0 deletions py-polars/tests/unit/operations/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,3 +1082,20 @@ def test_cross_join_no_on_keys(on_args: dict[str, str]) -> None:
msg = "cross join should not pass join keys"
with pytest.raises(ValueError, match=msg):
df1.join(df2, how="cross", **on_args) # type: ignore[arg-type]


@pytest.mark.parametrize("set_sorted", [True, False])
def test_left_join_slice_pushdown_19405(set_sorted: bool) -> None:
left = pl.LazyFrame({"k": [1, 2, 3, 4, 0]})
right = pl.LazyFrame({"k": [1, 1, 1, 1, 0]})

if set_sorted:
# The data isn't actually sorted on purpose to ensure we default to a
# hash join unless we set the sorted flag here, in case there is new
# code in the future that automatically identifies sortedness during
# Series construction from Python.
left = left.set_sorted("k")
right = right.set_sorted("k")

q = left.join(right, on="k", how="left").head(5)
assert_frame_equal(q.collect(), pl.DataFrame({"k": [1, 1, 1, 1, 2]}))

0 comments on commit 3fe25c5

Please sign in to comment.