From 250d86ea56c0e7fa375f2a84daef0a646411292a Mon Sep 17 00:00:00 2001
From: Juan D Barreto
Date: Fri, 10 Jan 2025 18:21:29 -0800
Subject: [PATCH] fix: remove alias from where (#239)

* Remove Alias from expression inside where method.

* Add test
---
 sqlframe/base/dataframe.py              |  2 ++
 tests/integration/test_int_dataframe.py | 14 ++++++++++++++
 2 files changed, 16 insertions(+)

diff --git a/sqlframe/base/dataframe.py b/sqlframe/base/dataframe.py
index 9181ad3..2640ac4 100644
--- a/sqlframe/base/dataframe.py
+++ b/sqlframe/base/dataframe.py
@@ -806,6 +806,8 @@ def where(self, column: t.Union[Column, str, bool], **kwargs) -> Self:
             )
         else:
             col = self._ensure_and_normalize_col(column)
+        if isinstance(col.expression, exp.Alias):
+            col.expression = col.expression.this
         return self.copy(expression=self.expression.where(col.expression))
 
     filter = where
diff --git a/tests/integration/test_int_dataframe.py b/tests/integration/test_int_dataframe.py
index bb89303..89d0144 100644
--- a/tests/integration/test_int_dataframe.py
+++ b/tests/integration/test_int_dataframe.py
@@ -2278,3 +2278,17 @@ def test_self_join(
     )
 
     compare_frames(df_joined, dfs_joined, compare_schema=False)
+
+
+# https://github.com/eakmanrq/sqlframe/issues/232
+def test_filter_alias(
+    pyspark_employee: PySparkDataFrame,
+    get_df: t.Callable[[str], BaseDataFrame],
+    compare_frames: t.Callable,
+):
+    df_filtered = pyspark_employee.where((F.col("age") > 40).alias("age_gt_40"))
+
+    employee = get_df("employee")
+    dfs_filtered = employee.where((SF.col("age") > 40).alias("age_gt_40"))
+
+    compare_frames(df_filtered, dfs_filtered, compare_schema=False)