Skip to content

Commit

Permalink
fix: remove alias from where (#239)
Browse files Browse the repository at this point in the history
* Remove the Alias wrapper from the expression passed to the `where` method.

* Add an integration test for filtering on an aliased expression.
  • Loading branch information
zerodarkzone authored Jan 11, 2025
1 parent 9dc3baa commit 250d86e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
2 changes: 2 additions & 0 deletions sqlframe/base/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,8 @@ def where(self, column: t.Union[Column, str, bool], **kwargs) -> Self:
)
else:
col = self._ensure_and_normalize_col(column)
if isinstance(col.expression, exp.Alias):
col.expression = col.expression.this
return self.copy(expression=self.expression.where(col.expression))

filter = where
Expand Down
14 changes: 14 additions & 0 deletions tests/integration/test_int_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2278,3 +2278,17 @@ def test_self_join(
)

compare_frames(df_joined, dfs_joined, compare_schema=False)


# https://github.com/eakmanrq/sqlframe/issues/232
def test_filter_alias(
    pyspark_employee: PySparkDataFrame,
    get_df: t.Callable[[str], BaseDataFrame],
    compare_frames: t.Callable,
):
    """Check that filtering on an aliased predicate gives matching frames.

    The same aliased ``age > 40`` condition is passed to ``where`` on the
    PySpark employee fixture and on the ``employee`` frame obtained from
    ``get_df``; the two results are then compared with schema comparison
    disabled.
    """
    alias_name = "age_gt_40"

    # Reference result: filter the PySpark fixture first, as the baseline.
    expected = pyspark_employee.where((F.col("age") > 40).alias(alias_name))

    # Result under test: the same aliased predicate on the frame from get_df.
    actual = get_df("employee").where((SF.col("age") > 40).alias(alias_name))

    compare_frames(expected, actual, compare_schema=False)

0 comments on commit 250d86e

Please sign in to comment.