Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Incorrect .join(..., how="left").head(N) if N <= left_df.height() and there are duplicate matches #19422

Merged
merged 4 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions crates/polars-ops/src/frame/join/dispatch_left_right.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,14 @@ fn materialize_left_join(
if let Some((offset, len)) = args.slice {
left_idx = slice_slice(left_idx, offset, len);
}
left._create_left_df_from_slice(left_idx, true, true)
left._create_left_df_from_slice(left_idx, true, args.slice.is_some(), true)
},
ChunkJoinIds::Right(left_idx) => unsafe {
let mut left_idx = &*left_idx;
if let Some((offset, len)) = args.slice {
left_idx = slice_slice(left_idx, offset, len);
}
left.create_left_df_chunked(left_idx, true)
left.create_left_df_chunked(left_idx, true, args.slice.is_some())
},
};

Expand Down Expand Up @@ -133,7 +133,8 @@ fn materialize_left_join(
if let Some((offset, len)) = args.slice {
left_idx = slice_slice(left_idx, offset, len);
}
let materialize_left = || unsafe { left._create_left_df_from_slice(&left_idx, true, true) };
let materialize_left =
|| unsafe { left._create_left_df_from_slice(&left_idx, true, args.slice.is_some(), true) };

let mut right_idx = &*right_idx;
if let Some((offset, len)) = args.slice {
Expand Down
20 changes: 17 additions & 3 deletions crates/polars-ops/src/frame/join/hash_join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,18 @@ pub trait JoinDispatch: IntoDf {
/// # Safety
/// Join tuples must be in bounds
#[cfg(feature = "chunked_ids")]
unsafe fn create_left_df_chunked(&self, chunk_ids: &[ChunkId], left_join: bool) -> DataFrame {
unsafe fn create_left_df_chunked(
&self,
chunk_ids: &[ChunkId],
left_join: bool,
was_sliced: bool,
) -> DataFrame {
let df_self = self.to_df();
if left_join && chunk_ids.len() == df_self.height() {

let left_join_no_duplicate_matches =
left_join && !was_sliced && chunk_ids.len() == df_self.height();

if left_join_no_duplicate_matches {
df_self.clone()
} else {
// left join keys are in ascending order
Expand All @@ -76,10 +85,15 @@ pub trait JoinDispatch: IntoDf {
&self,
join_tuples: &[IdxSize],
left_join: bool,
was_sliced: bool,
sorted_tuple_idx: bool,
) -> DataFrame {
let df_self = self.to_df();
if left_join && join_tuples.len() == df_self.height() {

let left_join_no_duplicate_matches =
left_join && !was_sliced && join_tuples.len() == df_self.height();

if left_join_no_duplicate_matches {
df_self.clone()
} else {
// left join tuples are always in ascending order
Expand Down
9 changes: 8 additions & 1 deletion crates/polars-ops/src/frame/join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,14 @@ trait DataFrameJoinOpsPrivate: IntoDf {

let (df_left, df_right) = POOL.join(
// SAFETY: join indices are known to be in bounds
|| unsafe { left_df._create_left_df_from_slice(join_tuples_left, false, sorted) },
|| unsafe {
left_df._create_left_df_from_slice(
join_tuples_left,
false,
args.slice.is_some(),
sorted,
)
},
|| unsafe {
if let Some(drop_names) = drop_names {
other.drop_many(drop_names)
Expand Down
1 change: 1 addition & 0 deletions py-polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ all = [
"parquet",
"ipc",
"polars-python/all",
"performant",
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code-path for left join on sorted keys is also affected, but it is gated behind the performant feature flag so it doesn't get compiled in debug builds.

I think maybe instead of adding it here, we can add a test step to the Python release workflow that runs through the test suite with the optimized release build?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assumed it was always included in pypolars. It should always be activated.

]

default = ["all", "nightly"]
17 changes: 17 additions & 0 deletions py-polars/tests/unit/operations/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,3 +1082,20 @@ def test_cross_join_no_on_keys(on_args: dict[str, str]) -> None:
msg = "cross join should not pass join keys"
with pytest.raises(ValueError, match=msg):
df1.join(df2, how="cross", **on_args) # type: ignore[arg-type]


@pytest.mark.parametrize("set_sorted", [True, False])
def test_left_join_slice_pushdown_19405(set_sorted: bool) -> None:
left = pl.LazyFrame({"k": [1, 2, 3, 4, 0]})
right = pl.LazyFrame({"k": [1, 1, 1, 1, 0]})

if set_sorted:
# The data isn't actually sorted on purpose to ensure we default to a
# hash join unless we set the sorted flag here, in case there is new
# code in the future that automatically identifies sortedness during
# Series construction from Python.
left = left.set_sorted("k")
right = right.set_sorted("k")

q = left.join(right, on="k", how="left").head(5)
assert_frame_equal(q.collect(), pl.DataFrame({"k": [1, 1, 1, 1, 2]}))