diff --git a/sqlframe/base/dataframe.py b/sqlframe/base/dataframe.py
index faa57b6..7008fdf 100644
--- a/sqlframe/base/dataframe.py
+++ b/sqlframe/base/dataframe.py
@@ -79,6 +79,23 @@
     "SHUFFLE_REPLICATE_NL",
 }
 
+JOIN_TYPE_MAPPING = {
+    "inner": "inner",
+    "cross": "cross",
+    "outer": "full_outer",
+    "full": "full_outer",
+    "fullouter": "full_outer",
+    "left": "left_outer",
+    "leftouter": "left_outer",
+    "right": "right_outer",
+    "rightouter": "right_outer",
+    "semi": "left_semi",
+    "leftsemi": "left_semi",
+    "left_semi": "left_semi",
+    "anti": "left_anti",
+    "leftanti": "left_anti",
+    "left_anti": "left_anti",
+}
 
 DF = t.TypeVar("DF", bound="BaseDataFrame")
 
@@ -944,16 +961,20 @@ def join(
     ) -> Self:
         from sqlframe.base.functions import coalesce
 
-        if on is None:
+        if (on is None) and ("cross" not in how):
             logger.warning("Got no value for on. This appears to change the join to a cross join.")
             how = "cross"
+        if (on is not None) and ("cross" in how):
+            # Sparsely documented, but Spark treats a cross join with a predicate as an inner join
+            # https://learn.microsoft.com/en-us/dotnet/api/microsoft.spark.sql.dataframe.join
+            logger.warning("Got cross join with an 'on' value. This will result in an inner join.")
+            how = "inner"
 
         other_df = other_df._convert_leaf_to_cte()
         join_expression = self._add_ctes_to_expression(self.expression, other_df.expression.ctes)
         # We will determine actual "join on" expression later so we don't provide it at first
-        join_expression = join_expression.join(
-            join_expression.ctes[-1].alias, join_type=how.replace("_", " ")
-        )
+        join_type = JOIN_TYPE_MAPPING.get(how, how).replace("_", " ")
+        join_expression = join_expression.join(join_expression.ctes[-1].alias, join_type=join_type)
         self_columns = self._get_outer_select_columns(join_expression)
         other_columns = self._get_outer_select_columns(other_df.expression)
         join_columns = self._ensure_and_normalize_cols(on)
@@ -961,7 +982,12 @@ def join(
 
         # Determines the join clause and select columns to be used passed on what type of columns were provided for
         # the join. The columns returned changes based on how the on expression is provided.
-        if how != "cross":
+        select_columns = (
+            self_columns
+            if join_type in ["left anti", "left semi"]
+            else self_columns + other_columns
+        )
+        if join_type != "cross":
             if isinstance(join_columns[0].expression, exp.Column):
                 """
                 Unique characteristics of join on column names only:
@@ -992,7 +1018,7 @@ def join(
                         if not isinstance(column.expression.this, exp.Star)
                         else column.sql()
                     )
-                    for column in self_columns + other_columns
+                    for column in select_columns
                 ]
                 select_column_names = [
                     column_name
@@ -1010,13 +1036,11 @@ def join(
                 * The left join dataframe columns go first and right come after. No sort preference is given to join columns
                 """
                 join_clause = self._normalize_join_clause(join_columns, join_expression)
-                select_column_names = [
-                    column.alias_or_name for column in self_columns + other_columns
-                ]
+                select_column_names = [column.alias_or_name for column in select_columns]
 
         # Update the on expression with the actual join clause to replace the dummy one from before
         else:
-            select_column_names = [column.alias_or_name for column in self_columns + other_columns]
+            select_column_names = [column.alias_or_name for column in select_columns]
             join_clause = None
         join_expression.args["joins"][-1].set("on", join_clause.expression if join_clause else None)
         new_df = self.copy(expression=join_expression)
diff --git a/tests/integration/test_int_dataframe.py b/tests/integration/test_int_dataframe.py
index 9aa9659..1659250 100644
--- a/tests/integration/test_int_dataframe.py
+++ b/tests/integration/test_int_dataframe.py
@@ -435,6 +435,43 @@ def test_join_inner(
     compare_frames(df_joined, dfs_joined, sort=True)
 
 
+@pytest.mark.parametrize(
+    "how",
+    [
+        "inner",
+        "cross",
+        "outer",
+        "full",
+        "fullouter",
+        "full_outer",
+        "left",
+        "leftouter",
+        "left_outer",
+        "right",
+        "rightouter",
+        "right_outer",
+        "semi",
+        "leftsemi",
+        "left_semi",
+        "anti",
+        "leftanti",
+        "left_anti",
+    ],
+)
+def test_join_various_how(
+    pyspark_employee: PySparkDataFrame,
+    pyspark_store: PySparkDataFrame,
+    get_df: t.Callable[[str], BaseDataFrame],
+    compare_frames: t.Callable,
+    how: str,
+):
+    employee = get_df("employee")
+    store = get_df("store")
+    df_joined = pyspark_employee.join(pyspark_store, on=["store_id"], how=how)
+    dfs_joined = employee.join(store, on=["store_id"], how=how)
+    compare_frames(df_joined, dfs_joined, sort=True)
+
+
 def test_join_inner_no_select(
     pyspark_employee: PySparkDataFrame,
     pyspark_store: PySparkDataFrame,
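
Reviewer note (not part of the patch): the sketch below restates the `how` normalization this diff introduces, extracted into a standalone snippet so it can be run in isolation. The helper name `normalize_how` is hypothetical, invented here for illustration; the real patch inlines this logic in `BaseDataFrame.join`. The mapping and the two cross-join special cases are copied from the diff; anything else is an assumption about intent.

    # Standalone sketch of the join-type normalization in BaseDataFrame.join above.
    # `normalize_how` is a hypothetical name used only for this illustration.
    JOIN_TYPE_MAPPING = {
        "inner": "inner",
        "cross": "cross",
        "outer": "full_outer",
        "full": "full_outer",
        "fullouter": "full_outer",
        "left": "left_outer",
        "leftouter": "left_outer",
        "right": "right_outer",
        "rightouter": "right_outer",
        "semi": "left_semi",
        "leftsemi": "left_semi",
        "left_semi": "left_semi",
        "anti": "left_anti",
        "leftanti": "left_anti",
        "left_anti": "left_anti",
    }

    def normalize_how(on, how):
        # No predicate and not already cross: Spark degrades the join to a cross join.
        if (on is None) and ("cross" not in how):
            how = "cross"
        # Predicate supplied together with cross: Spark treats it as an inner join.
        if (on is not None) and ("cross" in how):
            how = "inner"
        # Unrecognized aliases (e.g. "full_outer") fall through .get() unchanged;
        # underscores become spaces to form the join type string handed to sqlglot.
        return JOIN_TYPE_MAPPING.get(how, how).replace("_", " ")

    assert normalize_how(["store_id"], "leftanti") == "left anti"
    assert normalize_how(["store_id"], "cross") == "inner"
    assert normalize_how(None, "inner") == "cross"

Note also that the `select_columns` change picks only the left DataFrame's columns when the normalized join type is `left semi` or `left anti`, which matches Spark's semantics: the right side of a semi/anti join never appears in the output. The parametrized test exercises every alias in `JOIN_TYPE_MAPPING` plus the underscore forms that fall through the mapping unchanged.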