From 9051a6adf36a063283e690bcbb0f6e89bbc4ceb1 Mon Sep 17 00:00:00 2001 From: Ryan Eakman <6326532+eakmanrq@users.noreply.github.com> Date: Wed, 22 May 2024 20:20:06 -0700 Subject: [PATCH] fix: properly support compound filter expressions (#22) --- sqlframe/base/dataframe.py | 4 +++- tests/unit/standalone/test_dataframe.py | 9 +++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/sqlframe/base/dataframe.py b/sqlframe/base/dataframe.py index bbb20b9..62792a7 100644 --- a/sqlframe/base/dataframe.py +++ b/sqlframe/base/dataframe.py @@ -608,7 +608,9 @@ def alias(self, name: str, **kwargs) -> Self: @operation(Operation.WHERE) def where(self, column: t.Union[Column, str, bool], **kwargs) -> Self: if isinstance(column, str): - col = sqlglot.parse_one(column, dialect=self.session.input_dialect) + col = self._ensure_and_normalize_col( + sqlglot.parse_one(column, dialect=self.session.input_dialect) + ) else: col = self._ensure_and_normalize_col(column) return self.copy(expression=self.expression.where(col.expression)) diff --git a/tests/unit/standalone/test_dataframe.py b/tests/unit/standalone/test_dataframe.py index a382caf..9d1ee27 100644 --- a/tests/unit/standalone/test_dataframe.py +++ b/tests/unit/standalone/test_dataframe.py @@ -55,3 +55,12 @@ def test_with_column_duplicate_alias(standalone_employee: StandaloneDataFrame): df.sql(pretty=False) == "SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`age` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" ) + + +def test_where_expr(standalone_employee: StandaloneDataFrame): + df = standalone_employee.where("fname = 'Jack' AND age = 37") + assert df.columns == ["employee_id", "fname", "lname", "age", "store_id"] + assert ( + df.sql(pretty=False) + == "SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`) WHERE `a1`.`age` = 37 AND CAST(`a1`.`fname` AS STRING) = 'Jack'" + )