From 338a9db7706b85b97d76a112413c070417ab2170 Mon Sep 17 00:00:00 2001 From: Ryan Eakman <6326532+eakmanrq@users.noreply.github.com> Date: Sun, 13 Oct 2024 09:31:43 -0700 Subject: [PATCH] fix: allow joining same table twice (#186) --- sqlframe/base/dataframe.py | 42 +++++++++++++++---- tests/integration/test_int_dataframe.py | 56 +++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 9 deletions(-) diff --git a/sqlframe/base/dataframe.py b/sqlframe/base/dataframe.py index 36e64cb..6e9cf79 100644 --- a/sqlframe/base/dataframe.py +++ b/sqlframe/base/dataframe.py @@ -15,7 +15,7 @@ import sqlglot from prettytable import PrettyTable -from sqlglot import Dialect +from sqlglot import Dialect, maybe_parse from sqlglot import expressions as exp from sqlglot import lineage as sqlglot_lineage from sqlglot.helper import ensure_list, flatten, object_to_dict, seq_get @@ -460,16 +460,40 @@ def _cache(self, storage_level: str) -> Self: df.expression.ctes[-1].set("cache_storage_level", storage_level) return df - @classmethod - def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select: + def _add_ctes_to_expression(self, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select: expression = expression.copy() with_expression = expression.args.get("with") if with_expression: existing_ctes = with_expression.expressions - existsing_cte_names = {x.alias_or_name for x in existing_ctes} + existing_cte_counts = {x.alias_or_name: 0 for x in existing_ctes} + replaced_cte_names = {} # type: ignore for cte in ctes: - if cte.alias_or_name not in existsing_cte_names: - existing_ctes.append(cte) + if replaced_cte_names: + cte = cte.transform(replace_id_value, replaced_cte_names) # type: ignore + if cte.alias_or_name in existing_cte_counts: + existing_cte_counts[cte.alias_or_name] += 10 + cte.set( + "this", + cte.this.where( + exp.EQ( + this=exp.Literal.number(existing_cte_counts[cte.alias_or_name]), + expression=exp.Literal.number( + existing_cte_counts[cte.alias_or_name] + ), + ) + ), + ) + new_cte_alias = self._create_hash_from_expression(cte.this) + replaced_cte_names[cte.args["alias"].this] = maybe_parse( + new_cte_alias, dialect=self.session.input_dialect, into=exp.Identifier + ) + cte.set( + "alias", + maybe_parse( + new_cte_alias, dialect=self.session.input_dialect, into=exp.TableAlias + ), + ) + existing_ctes.append(cte) else: existing_ctes = ctes expression.set("with", exp.With(expressions=existing_ctes)) @@ -843,11 +867,11 @@ def join( logger.warning("Got no value for on. This appears to change the join to a cross join.") how = "cross" other_df = other_df._convert_leaf_to_cte() + join_expression = self._add_ctes_to_expression(self.expression, other_df.expression.ctes) # We will determine actual "join on" expression later so we don't provide it at first - join_expression = self.expression.join( - other_df.latest_cte_name, join_type=how.replace("_", " ") + join_expression = join_expression.join( + join_expression.ctes[-1].alias, join_type=how.replace("_", " ") ) - join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes) self_columns = self._get_outer_select_columns(join_expression) other_columns = self._get_outer_select_columns(other_df.expression) join_columns = self._ensure_and_normalize_cols(on) diff --git a/tests/integration/test_int_dataframe.py b/tests/integration/test_int_dataframe.py index c1783d5..54d66bd 100644 --- a/tests/integration/test_int_dataframe.py +++ b/tests/integration/test_int_dataframe.py @@ -2180,3 +2180,59 @@ def test_chained_join_common_key( dfs = dfs.join(dfs_height, how="left", on="name").join(dfs_location, how="left", on="name") compare_frames(df, dfs, compare_schema=False) + + +# https://github.com/eakmanrq/sqlframe/issues/185 +def test_chaining_joins_with_selects( + pyspark_employee: PySparkDataFrame, + pyspark_store: PySparkDataFrame, + pyspark_district: PySparkDataFrame, + get_df: t.Callable[[str], _BaseDataFrame], + compare_frames: t.Callable, + is_spark: t.Callable, +): + if is_spark(): + pytest.skip( + "This test is not supported in Spark. This is related to how duplicate columns are handled in Spark" + ) + df = ( + pyspark_employee.alias("employee") + .join( + pyspark_store.filter(F.col("store_name") != "test").alias("store"), + on=F.col("employee.employee_id") == F.col("store.store_id"), + ) + .join( + pyspark_district.alias("district"), + on=F.col("store.store_id") == F.col("district.district_id"), + ) + .join( + pyspark_district.alias("district2"), + on=(F.col("store.store_id") + 1) == F.col("district2.district_id"), + how="left", + ) + .select("*") + ) + + employee = get_df("employee") + store = get_df("store") + district = get_df("district") + + dfs = ( + employee.alias("employee") + .join( + store.filter(SF.col("store_name") != "test").alias("store"), + on=SF.col("employee.employee_id") == SF.col("store.store_id"), + ) + .join( + district.alias("district"), + on=SF.col("store.store_id") == SF.col("district.district_id"), + ) + .join( + district.alias("district2"), + on=(SF.col("store.store_id") + 1) == SF.col("district2.district_id"), + how="left", + ) + .select("*") + ) + + compare_frames(df, dfs, compare_schema=False)