Skip to content

Commit

Permalink
fix: expand star in select expressions (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
eakmanrq authored Jun 27, 2024
1 parent a344eda commit 04165e6
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 21 deletions.
25 changes: 23 additions & 2 deletions sqlframe/base/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def _ensure_and_normalize_cols(

cols = self._ensure_list_of_columns(cols)
normalize(self.session, expression or self.expression, cols)
return cols
return list(flatten([self._expand_star(col) for col in cols]))

def _ensure_and_normalize_col(self, col):
from sqlframe.base.column import Column
Expand Down Expand Up @@ -514,6 +514,27 @@ def _get_select_expressions(
select_expressions.append(expression_select_pair) # type: ignore
return select_expressions

def _expand_star(self, col: Column) -> t.List[Column]:
from sqlframe.base.column import Column

if isinstance(col.column_expression, exp.Star):
return self._get_outer_select_columns(self.expression)
elif (
isinstance(col.column_expression, exp.Column)
and isinstance(col.column_expression.this, exp.Star)
and col.column_expression.args.get("table")
):
for cte in self.expression.ctes:
if cte.alias_or_name == col.column_expression.args["table"].this:
return [
Column.ensure_col(exp.column(x.column_alias_or_name, cte.alias_or_name))
for x in self._get_outer_select_columns(cte)
]
raise ValueError(
f"Could not find table to expand star: {col.column_expression.args['table']}"
)
return [col]

@t.overload
def sql(
self,
Expand Down Expand Up @@ -1555,7 +1576,7 @@ def show(
result = self.session._fetch_rows(sql)
table = PrettyTable()
if row := seq_get(result, 0):
table.field_names = list(row.asDict().keys())
table.field_names = row._unique_field_names
for row in result:
table.add_row(list(row))
print(table)
Expand Down
10 changes: 10 additions & 0 deletions sqlframe/base/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,13 @@ def __repr__(self) -> str:
)
else:
return "<Row(%s)>" % ", ".join(repr(field) for field in self)

# SQLFrame Specific
@property
def _unique_field_names(self) -> t.List[str]:
fields = []
for i, field in enumerate(self.__fields__):
if field in fields:
field = field + "_" + str(i)
fields.append(field)
return fields
1 change: 1 addition & 0 deletions sqlframe/bigquery/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
make_date_from_date_func as make_date,
to_date_from_timestamp as to_date,
last_day_with_cast as last_day,
sha1_force_sha1_and_to_hex as sha,
sha1_force_sha1_and_to_hex as sha1,
hash_from_farm_fingerprint as hash,
base64_from_blob as base64,
Expand Down
44 changes: 25 additions & 19 deletions tests/integration/engines/test_engine_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,36 +24,42 @@ def test_collect(get_engine_df: t.Callable[[str], _BaseDataFrame], get_func):


def test_show(
get_engine_df: t.Callable[[str], _BaseDataFrame], capsys, caplog, is_snowflake: t.Callable
get_engine_df: t.Callable[[str], _BaseDataFrame],
get_func,
capsys,
caplog,
is_snowflake: t.Callable,
):
employee = get_engine_df("employee")
lit = get_func("lit", employee.session)
employee = employee.select("*", lit(1).alias("one"))
employee.show()
captured = capsys.readouterr()
if is_snowflake():
assert (
captured.out
== """+-------------+--------+-----------+-----+----------+
| EMPLOYEE_ID | FNAME | LNAME | AGE | STORE_ID |
+-------------+--------+-----------+-----+----------+
| 1 | Jack | Shephard | 37 | 1 |
| 2 | John | Locke | 65 | 1 |
| 3 | Kate | Austen | 37 | 2 |
| 4 | Claire | Littleton | 27 | 2 |
| 5 | Hugo | Reyes | 29 | 100 |
+-------------+--------+-----------+-----+----------+\n"""
== """+-------------+--------+-----------+-----+----------+-----+
| EMPLOYEE_ID | FNAME | LNAME | AGE | STORE_ID | ONE |
+-------------+--------+-----------+-----+----------+-----+
| 1 | Jack | Shephard | 37 | 1 | 1 |
| 2 | John | Locke | 65 | 1 | 1 |
| 3 | Kate | Austen | 37 | 2 | 1 |
| 4 | Claire | Littleton | 27 | 2 | 1 |
| 5 | Hugo | Reyes | 29 | 100 | 1 |
+-------------+--------+-----------+-----+----------+-----+\n"""
)
else:
assert (
captured.out
== """+-------------+--------+-----------+-----+----------+
| employee_id | fname | lname | age | store_id |
+-------------+--------+-----------+-----+----------+
| 1 | Jack | Shephard | 37 | 1 |
| 2 | John | Locke | 65 | 1 |
| 3 | Kate | Austen | 37 | 2 |
| 4 | Claire | Littleton | 27 | 2 |
| 5 | Hugo | Reyes | 29 | 100 |
+-------------+--------+-----------+-----+----------+\n"""
== """+-------------+--------+-----------+-----+----------+-----+
| employee_id | fname | lname | age | store_id | one |
+-------------+--------+-----------+-----+----------+-----+
| 1 | Jack | Shephard | 37 | 1 | 1 |
| 2 | John | Locke | 65 | 1 | 1 |
| 3 | Kate | Austen | 37 | 2 | 1 |
| 4 | Claire | Littleton | 27 | 2 | 1 |
| 5 | Hugo | Reyes | 29 | 100 | 1 |
+-------------+--------+-----------+-----+----------+-----+\n"""
)
assert "Truncate is ignored so full results will be displayed" not in caplog.text
employee.show(truncate=True)
Expand Down
10 changes: 10 additions & 0 deletions tests/integration/test_int_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,16 @@ def test_simple_select_from_table(
compare_frames(df, dfs)


def test_select_star_from_table(
pyspark_employee: PySparkDataFrame,
get_df: t.Callable[[str], _BaseDataFrame],
compare_frames: t.Callable,
):
df = pyspark_employee
dfs = get_df("employee").session.read.table("employee")
compare_frames(df, dfs)


def test_simple_select_df_attribute(
pyspark_employee: PySparkDataFrame,
get_df: t.Callable[[str], _BaseDataFrame],
Expand Down
14 changes: 14 additions & 0 deletions tests/unit/standalone/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,17 @@ def test_missing_method(standalone_employee: StandaloneDataFrame):
UnsupportedOperationError, match="Tried to call a column which is unexpected.*"
):
standalone_employee.missing_method("blah")


def test_expand_star(standalone_employee: StandaloneDataFrame):
assert (
standalone_employee.select("*").sql(pretty=False, optimize=False)
== "WITH t51718876 AS (SELECT CAST(employee_id AS INT) AS employee_id, CAST(fname AS STRING) AS fname, CAST(lname AS STRING) AS lname, CAST(age AS INT) AS age, CAST(store_id AS INT) AS store_id FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS a1(employee_id, fname, lname, age, store_id)) SELECT employee_id, fname, lname, age, store_id FROM t51718876"
)


def test_expand_star_table_alias(standalone_employee: StandaloneDataFrame):
assert (
standalone_employee.alias("blah").select("blah.*").sql(pretty=False, optimize=False)
== "WITH t51718876 AS (SELECT CAST(employee_id AS INT) AS employee_id, CAST(fname AS STRING) AS fname, CAST(lname AS STRING) AS lname, CAST(age AS INT) AS age, CAST(store_id AS INT) AS store_id FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS a1(employee_id, fname, lname, age, store_id)), t37842204 AS (SELECT employee_id, fname, lname, age, store_id FROM t51718876) SELECT t37842204.employee_id, t37842204.fname, t37842204.lname, t37842204.age, t37842204.store_id FROM t37842204"
)

0 comments on commit 04165e6

Please sign in to comment.