Skip to content

Commit

Permalink
fix: allow joining same table twice (#186)
Browse files Browse the repository at this point in the history
  • Loading branch information
eakmanrq authored Oct 13, 2024
1 parent db2f1bb commit 338a9db
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 9 deletions.
42 changes: 33 additions & 9 deletions sqlframe/base/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import sqlglot
from prettytable import PrettyTable
from sqlglot import Dialect
from sqlglot import Dialect, maybe_parse
from sqlglot import expressions as exp
from sqlglot import lineage as sqlglot_lineage
from sqlglot.helper import ensure_list, flatten, object_to_dict, seq_get
Expand Down Expand Up @@ -460,16 +460,40 @@ def _cache(self, storage_level: str) -> Self:
df.expression.ctes[-1].set("cache_storage_level", storage_level)
return df

def _add_ctes_to_expression(self, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
    """Merge ``ctes`` into ``expression``'s WITH clause, renaming duplicates.

    Supports joining the same table (and therefore the same CTE alias) more
    than once: when an incoming CTE's alias collides with one already present,
    the incoming CTE body is made textually unique, given a fresh hash-based
    alias, and all later references to the old alias are rewritten.

    Args:
        expression: The SELECT to receive the CTEs. It is copied, never
            mutated in place.
        ctes: CTEs to merge in (typically from another dataframe's expression).

    Returns:
        A copy of ``expression`` whose WITH clause contains both the original
        CTEs and the merged (possibly renamed) ones.
    """
    expression = expression.copy()
    with_expression = expression.args.get("with")
    if with_expression:
        existing_ctes = with_expression.expressions
        # How many times each pre-existing alias has collided so far; each
        # duplicate gets a distinct uniquifying predicate below.
        existing_cte_counts = {x.alias_or_name: 0 for x in existing_ctes}
        # Old alias identifier -> new alias identifier for CTEs renamed in
        # this loop, so later CTEs that reference them can be rewritten.
        replaced_cte_names = {}  # type: ignore
        for cte in ctes:
            if replaced_cte_names:
                # Point this CTE at the renamed versions of any CTEs that
                # were deduplicated earlier in the loop.
                cte = cte.transform(replace_id_value, replaced_cte_names)  # type: ignore
            if cte.alias_or_name in existing_cte_counts:
                # Alias collision: append an always-true ``N = N`` predicate
                # so the CTE body is textually unique, which guarantees the
                # hash-based alias computed below differs from the existing
                # CTE's alias. The number itself is meaningless; it only has
                # to differ per collision.
                existing_cte_counts[cte.alias_or_name] += 1
                occurrence = existing_cte_counts[cte.alias_or_name]
                cte.set(
                    "this",
                    cte.this.where(
                        exp.EQ(
                            this=exp.Literal.number(occurrence),
                            expression=exp.Literal.number(occurrence),
                        )
                    ),
                )
                new_cte_alias = self._create_hash_from_expression(cte.this)
                replaced_cte_names[cte.args["alias"].this] = maybe_parse(
                    new_cte_alias, dialect=self.session.input_dialect, into=exp.Identifier
                )
                cte.set(
                    "alias",
                    maybe_parse(
                        new_cte_alias, dialect=self.session.input_dialect, into=exp.TableAlias
                    ),
                )
            existing_ctes.append(cte)
    else:
        existing_ctes = ctes
    expression.set("with", exp.With(expressions=existing_ctes))
    return expression
Expand Down Expand Up @@ -843,11 +867,11 @@ def join(
logger.warning("Got no value for on. This appears to change the join to a cross join.")
how = "cross"
other_df = other_df._convert_leaf_to_cte()
join_expression = self._add_ctes_to_expression(self.expression, other_df.expression.ctes)
# We will determine actual "join on" expression later so we don't provide it at first
join_expression = self.expression.join(
other_df.latest_cte_name, join_type=how.replace("_", " ")
join_expression = join_expression.join(
join_expression.ctes[-1].alias, join_type=how.replace("_", " ")
)
join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
self_columns = self._get_outer_select_columns(join_expression)
other_columns = self._get_outer_select_columns(other_df.expression)
join_columns = self._ensure_and_normalize_cols(on)
Expand Down
56 changes: 56 additions & 0 deletions tests/integration/test_int_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2180,3 +2180,59 @@ def test_chained_join_common_key(
dfs = dfs.join(dfs_height, how="left", on="name").join(dfs_location, how="left", on="name")

compare_frames(df, dfs, compare_schema=False)


# https://github.com/eakmanrq/sqlframe/issues/185
def test_chaining_joins_with_selects(
    pyspark_employee: PySparkDataFrame,
    pyspark_store: PySparkDataFrame,
    pyspark_district: PySparkDataFrame,
    get_df: t.Callable[[str], _BaseDataFrame],
    compare_frames: t.Callable,
    is_spark: t.Callable,
):
    """Joining the same table twice under different aliases must yield the
    same rows in sqlframe as the equivalent PySpark pipeline."""
    if is_spark():
        pytest.skip(
            "This test is not supported in Spark. This is related to how duplicate columns are handled in Spark"
        )

    # Reference result computed with PySpark.
    expected = pyspark_employee.alias("employee")
    expected = expected.join(
        pyspark_store.filter(F.col("store_name") != "test").alias("store"),
        on=F.col("employee.employee_id") == F.col("store.store_id"),
    )
    expected = expected.join(
        pyspark_district.alias("district"),
        on=F.col("store.store_id") == F.col("district.district_id"),
    )
    expected = expected.join(
        pyspark_district.alias("district2"),
        on=(F.col("store.store_id") + 1) == F.col("district2.district_id"),
        how="left",
    ).select("*")

    employee = get_df("employee")
    store = get_df("store")
    district = get_df("district")

    # Same pipeline through sqlframe; note "district" is joined twice.
    actual = employee.alias("employee")
    actual = actual.join(
        store.filter(SF.col("store_name") != "test").alias("store"),
        on=SF.col("employee.employee_id") == SF.col("store.store_id"),
    )
    actual = actual.join(
        district.alias("district"),
        on=SF.col("store.store_id") == SF.col("district.district_id"),
    )
    actual = actual.join(
        district.alias("district2"),
        on=(SF.col("store.store_id") + 1) == SF.col("district2.district_id"),
        how="left",
    ).select("*")

    compare_frames(expected, actual, compare_schema=False)

0 comments on commit 338a9db

Please sign in to comment.