From eeffd41a8626ce40970b9d6ad870bd6cac57c900 Mon Sep 17 00:00:00 2001 From: Ryan Eakman <6326532+eakmanrq@users.noreply.github.com> Date: Fri, 28 Jun 2024 19:56:17 -0700 Subject: [PATCH] feat: add array_append support all engines (#106) --- sqlframe/base/function_alternatives.py | 13 +++++++++++++ sqlframe/base/functions.py | 2 +- sqlframe/bigquery/functions.py | 1 + sqlframe/duckdb/functions.py | 1 + tests/integration/engines/test_int_functions.py | 5 +++-- 5 files changed, 19 insertions(+), 3 deletions(-) diff --git a/sqlframe/base/function_alternatives.py b/sqlframe/base/function_alternatives.py index 12faf71..1a861a8 100644 --- a/sqlframe/base/function_alternatives.py +++ b/sqlframe/base/function_alternatives.py @@ -1529,6 +1529,19 @@ def to_unix_timestamp_include_default_format( return to_unix_timestamp(timestamp, format) +def array_append_list_append(col: ColumnOrName, value: ColumnOrLiteral) -> Column: + lit = get_func_from_session("lit") + value = value if isinstance(value, Column) else lit(value) + return Column.invoke_anonymous_function(col, "LIST_APPEND", value) + + +def array_append_using_array_cat(col: ColumnOrName, value: ColumnOrLiteral) -> Column: + lit = get_func_from_session("lit") + array = get_func_from_session("array") + value = value if isinstance(value, Column) else lit(value) + return Column.invoke_anonymous_function(col, "ARRAY_CONCAT", array(value)) + + def day_with_try_to_timestamp(col: ColumnOrName) -> Column: from sqlframe.base.functions import day diff --git a/sqlframe/base/functions.py b/sqlframe/base/functions.py index 37bb94b..a3cb866 100644 --- a/sqlframe/base/functions.py +++ b/sqlframe/base/functions.py @@ -1300,7 +1300,7 @@ def array_agg(col: ColumnOrName) -> Column: return Column.invoke_expression_over_column(col, expression.ArrayAgg) -@meta(unsupported_engines="*") +@meta() def array_append(col: ColumnOrName, value: ColumnOrLiteral) -> Column: value = value if isinstance(value, Column) else lit(value) return Column.invoke_anonymous_function(col, "ARRAY_APPEND", value) diff --git a/sqlframe/bigquery/functions.py b/sqlframe/bigquery/functions.py index 05a4233..5d86d59 100644 --- a/sqlframe/bigquery/functions.py +++ b/sqlframe/bigquery/functions.py @@ -74,6 +74,7 @@ position_as_strpos as position, try_to_timestamp_safe as try_to_timestamp, _is_string_using_typeof_string as _is_string, + array_append_using_array_cat as array_append, ) diff --git a/sqlframe/duckdb/functions.py b/sqlframe/duckdb/functions.py index 5299a03..63aeae3 100644 --- a/sqlframe/duckdb/functions.py +++ b/sqlframe/duckdb/functions.py @@ -49,4 +49,5 @@ day_with_try_to_timestamp as day, try_to_timestamp_strptime as try_to_timestamp, _is_string_using_typeof_varchar as _is_string, + array_append_list_append as array_append, ) diff --git a/tests/integration/engines/test_int_functions.py b/tests/integration/engines/test_int_functions.py index 99807cd..98467a1 100644 --- a/tests/integration/engines/test_int_functions.py +++ b/tests/integration/engines/test_int_functions.py @@ -2067,9 +2067,10 @@ def test_array_agg(get_session_and_func): ] -def test_array_append(get_session_and_func): +def test_array_append(get_session_and_func, get_func): session, array_append = get_session_and_func("array_append") - df = session.createDataFrame([Row(c1=["b", "a", "c"], c2="c")]) + lit = get_func("lit", session) + df = session.range(1).select(lit(["b", "a", "c"]).alias("c1"), lit("c").alias("c2")) assert df.select(array_append(df.c1, df.c2)).collect() == [ Row(value=["b", "a", "c", "c"]), ]