diff --git a/docs/bigquery.md b/docs/bigquery.md
index 905543d..7311109 100644
--- a/docs/bigquery.md
+++ b/docs/bigquery.md
@@ -446,6 +446,7 @@ See something that you would like to see supported? [Open an issue](https://gith
 * [trim](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.trim.html)
 * [trunc](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.trunc.html)
   * Shorthand expressions not supported. Ex: Use `month` instead of `mon`
+* [try_to_timestamp](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.try_to_timestamp.html)
 * [typeof](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.typeof.html)
 * [ucase](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.ucase.html)
 * [unbase64](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unbase64.html)
diff --git a/docs/duckdb.md b/docs/duckdb.md
index 9a1b257..070fef5 100644
--- a/docs/duckdb.md
+++ b/docs/duckdb.md
@@ -279,6 +279,7 @@ See something that you would like to see supported? [Open an issue](https://gith
 * [date_format](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.date_format.html)
 * [date_sub](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.date_sub.html)
 * [date_trunc](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.date_trunc.html)
+* [day](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.day.html)
 * [dayofmonth](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.dayofmonth.html)
 * [dayofweek](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.dayofweek.html)
 * [dayofyear](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.dayofyear.html)
@@ -411,6 +412,7 @@ See something that you would like to see supported? [Open an issue](https://gith
 * [trim](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.trim.html)
 * [trunc](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.trunc.html)
 * [try_element_at](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.try_element_at.html)
+* [try_to_timestamp](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.try_to_timestamp.html)
 * [typeof](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.typeof.html)
 * [ucase](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.ucase.html)
 * [unbase64](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unbase64.html)
diff --git a/docs/postgres.md b/docs/postgres.md
index 2483566..05ce5cf 100644
--- a/docs/postgres.md
+++ b/docs/postgres.md
@@ -403,6 +403,8 @@ See something that you would like to see supported? [Open an issue](https://gith
 * [trunc](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.trunc.html)
 * [try_element_at](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.try_element_at.html)
   * Negative index returns null and cannot lookup elements in maps
+* [try_to_timestamp](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.try_to_timestamp.html)
+* [typeof](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.typeof.html)
 * [ucase](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.ucase.html)
 * [unbase64](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unbase64.html)
 * [unix_timestamp](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unix_timestamp.html)
diff --git a/docs/snowflake.md b/docs/snowflake.md
index 161704b..393f501 100644
--- a/docs/snowflake.md
+++ b/docs/snowflake.md
@@ -444,6 +444,8 @@ See something that you would like to see supported? [Open an issue](https://gith
 * [translate](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.translate.html)
 * [trim](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.trim.html)
 * [trunc](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.trunc.html)
+* [try_to_timestamp](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.try_to_timestamp.html)
+* [typeof](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.typeof.html)
 * [ucase](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.ucase.html)
 * [unbase64](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unbase64.html)
 * [unhex](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unhex.html)
diff --git a/sqlframe/base/function_alternatives.py b/sqlframe/base/function_alternatives.py
index 95e0a77..12faf71 100644
--- a/sqlframe/base/function_alternatives.py
+++ b/sqlframe/base/function_alternatives.py
@@ -561,6 +561,14 @@ def to_date_from_timestamp(col: ColumnOrName, format: t.Optional[str] = None) ->
     return to_date(to_timestamp(col, format))
 
 
+def to_date_time_format(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
+    from sqlframe.base.functions import to_date
+
+    lit = get_func_from_session("lit")
+    format = lit(format or spark_default_time_format())
+    return to_date(col, format=format)
+
+
 def last_day_with_cast(col: ColumnOrName) -> Column:
     from sqlframe.base.functions import last_day
 
@@ -1519,3 +1527,86 @@ def to_unix_timestamp_include_default_format(
     else:
         format = format_time_from_spark(format)
     return to_unix_timestamp(timestamp, format)
+
+
+def day_with_try_to_timestamp(col: ColumnOrName) -> Column:
+    from sqlframe.base.functions import day
+
+    try_to_timestamp = get_func_from_session("try_to_timestamp")
+    to_date = get_func_from_session("to_date")
+    when = get_func_from_session("when")
+    _is_string = get_func_from_session("_is_string")
+    coalesce = get_func_from_session("coalesce")
+    return day(
+        when(
+            _is_string(col),
+            coalesce(try_to_timestamp(col), to_date(col)),
+        ).otherwise(col)
+    )
+
+
+def try_to_timestamp_strptime(col: ColumnOrName, format: t.Optional[ColumnOrName] = None) -> Column:
+    lit = get_func_from_session("lit")
+
+    format = lit(format or spark_default_time_format())
+    return Column.invoke_anonymous_function(col, "TRY_STRPTIME", format_time_from_spark(format))  # type: ignore
+
+
+def try_to_timestamp_safe(col: ColumnOrName, format: t.Optional[ColumnOrName] = None) -> Column:
+    lit = get_func_from_session("lit")
+
+    format = lit(format or spark_default_time_format())
+    return Column.invoke_anonymous_function(
+        format_time_from_spark(format),  # type: ignore
+        "SAFE.PARSE_TIMESTAMP",
+        col,  # type: ignore
+    )
+
+
+def try_to_timestamp_pgtemp(col: ColumnOrName, format: t.Optional[ColumnOrName] = None) -> Column:
+    lit = get_func_from_session("lit")
+
+    format = lit(format or spark_default_time_format())
+    return Column.invoke_anonymous_function(
+        col,
+        "pg_temp.TRY_TO_TIMESTAMP",
+        format_time_from_spark(format),  # type: ignore
+    )
+
+
+def typeof_pg_typeof(col: ColumnOrName) -> Column:
+    return Column.invoke_anonymous_function(col, "pg_typeof").cast("regtype").cast("text")
+
+
+def typeof_from_variant(col: ColumnOrName) -> Column:
+    col = Column.invoke_anonymous_function(col, "TO_VARIANT")
+    return Column.invoke_anonymous_function(col, "TYPEOF")
+
+
+def _is_string_using_typeof_varchar(col: ColumnOrName) -> Column:
+    typeof = get_func_from_session("typeof")
+    lit = get_func_from_session("lit")
+    return lit(typeof(col) == lit("VARCHAR"))
+
+
+def _is_string_using_typeof_char_varying(col: ColumnOrName) -> Column:
+    typeof = get_func_from_session("typeof")
+    lit = get_func_from_session("lit")
+    return lit(
+        (typeof(col) == lit("text"))
+        | (typeof(col) == lit("character varying"))
+        | (typeof(col) == lit("unknown"))
+        | (typeof(col) == lit("text"))
+    )
+
+
+def _is_string_using_typeof_string(col: ColumnOrName) -> Column:
+    typeof = get_func_from_session("typeof")
+    lit = get_func_from_session("lit")
+    return lit(typeof(col) == lit("STRING"))
+
+
+def _is_string_using_typeof_string_lcase(col: ColumnOrName) -> Column:
+    typeof = get_func_from_session("typeof")
+    lit = get_func_from_session("lit")
+    return lit(typeof(col) == lit("string"))
diff --git a/sqlframe/base/functions.py b/sqlframe/base/functions.py
index 1d4aa76..37bb94b 100644
--- a/sqlframe/base/functions.py
+++ b/sqlframe/base/functions.py
@@ -13,7 +13,12 @@
 
 from sqlframe.base.column import Column
 from sqlframe.base.decorators import func_metadata as meta
-from sqlframe.base.util import format_time_from_spark, spark_default_time_format
+from sqlframe.base.util import (
+    format_time_from_spark,
+    get_func_from_session,
+    spark_default_date_format,
+    spark_default_time_format,
+)
 
 if t.TYPE_CHECKING:
     from pyspark.sql.session import SparkContext
@@ -877,7 +882,7 @@ def months_between(
 
 @meta()
 def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
-    format = lit(format or spark_default_time_format())
+    format = lit(format or spark_default_date_format())
     if format is not None:
         return Column.invoke_expression_over_column(
             col, expression.TsOrDsToDate, format=format_time_from_spark(format)
@@ -1743,7 +1748,7 @@ def map_zip_with(
     return Column.invoke_anonymous_function(col1, "MAP_ZIP_WITH", col2, Column(f_expression))
 
 
-@meta(unsupported_engines=["postgres", "snowflake"])
+@meta()
 def typeof(col: ColumnOrName) -> Column:
     return Column.invoke_anonymous_function(col, "TYPEOF")
 
@@ -2162,7 +2167,7 @@ def datepart(field: ColumnOrName, source: ColumnOrName) -> Column:
     return Column.invoke_anonymous_function(field, "datepart", source)
 
 
-@meta(unsupported_engines="*")
+@meta(unsupported_engines=["bigquery", "postgres", "snowflake"])
 def day(col: ColumnOrName) -> Column:
     return Column.invoke_expression_over_column(col, expression.Day)
 
@@ -5277,7 +5282,7 @@ def try_element_at(col: ColumnOrName, extraction: ColumnOrName) -> Column:
     )
 
 
-@meta(unsupported_engines="*")
+@meta()
 def try_to_timestamp(col: ColumnOrName, format: t.Optional[ColumnOrName] = None) -> Column:
     """
     Parses the `col` with the `format` to a timestamp. The function always
@@ -5302,10 +5307,8 @@ def try_to_timestamp(col: ColumnOrName, format: t.Optional[ColumnOrName] = None)
     >>> df.select(try_to_timestamp(df.t, lit('yyyy-MM-dd HH:mm:ss')).alias('dt')).collect()
     [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))]
     """
-    if format is not None:
-        return Column.invoke_anonymous_function(col, "try_to_timestamp", format)
-    else:
-        return Column.invoke_anonymous_function(col, "try_to_timestamp")
+    format = lit(format or spark_default_time_format())
+    return Column.invoke_anonymous_function(col, "try_to_timestamp", format_time_from_spark(format))  # type: ignore
 
 
 @meta()
@@ -5797,6 +5800,20 @@ def years(col: ColumnOrName) -> Column:
     return Column.invoke_anonymous_function(col, "years")
 
 
+# SQLFrame specific
+@meta()
+def _is_string(col: ColumnOrName) -> Column:
+    col = Column.invoke_anonymous_function(col, "TO_VARIANT")
+    return Column.invoke_anonymous_function(col, "IS_VARCHAR")
+
+
+@meta()
+def _is_date(col: ColumnOrName) -> Column:
+    typeof = get_func_from_session("typeof")
+    upper = get_func_from_session("upper")
+    return lit(upper(typeof(col)) == lit("DATE"))
+
+
 @meta()
 def _lambda_quoted(value: str) -> t.Optional[bool]:
     return False if value == "_" else None
diff --git a/sqlframe/base/util.py b/sqlframe/base/util.py
index 2787df2..787efd1 100644
--- a/sqlframe/base/util.py
+++ b/sqlframe/base/util.py
@@ -365,3 +365,7 @@ def format_time_from_spark(value: ColumnOrLiteral) -> Column:
 
 def spark_default_time_format() -> str:
     return Dialect["spark"].TIME_FORMAT.strip("'")
+
+
+def spark_default_date_format() -> str:
+    return Dialect["spark"].DATE_FORMAT.strip("'")
diff --git a/sqlframe/bigquery/functions.py b/sqlframe/bigquery/functions.py
index c8fc74d..05a4233 100644
--- a/sqlframe/bigquery/functions.py
+++ b/sqlframe/bigquery/functions.py
@@ -72,6 +72,8 @@
     array_union_using_array_concat as array_union,
     sequence_from_generate_array as sequence,
     position_as_strpos as position,
+    try_to_timestamp_safe as try_to_timestamp,
+    _is_string_using_typeof_string as _is_string,
 )
 
 
diff --git a/sqlframe/duckdb/functions.py b/sqlframe/duckdb/functions.py
index f6dd9a2..5299a03 100644
--- a/sqlframe/duckdb/functions.py
+++ b/sqlframe/duckdb/functions.py
@@ -46,4 +46,7 @@
     array_max_from_sort as array_max,
     sequence_from_generate_series as sequence,
     try_element_at_zero_based as try_element_at,
+    day_with_try_to_timestamp as day,
+    try_to_timestamp_strptime as try_to_timestamp,
+    _is_string_using_typeof_varchar as _is_string,
 )
diff --git a/sqlframe/postgres/functions.py b/sqlframe/postgres/functions.py
index 0248086..99dff56 100644
--- a/sqlframe/postgres/functions.py
+++ b/sqlframe/postgres/functions.py
@@ -64,4 +64,7 @@
     right_cast_len as right,
     position_cast_start as position,
     try_element_at_zero_based as try_element_at,
+    try_to_timestamp_pgtemp as try_to_timestamp,
+    typeof_pg_typeof as typeof,
+    _is_string_using_typeof_char_varying as _is_string,
 )
diff --git a/sqlframe/postgres/session.py b/sqlframe/postgres/session.py
index 1ecb0cf..021c7da 100644
--- a/sqlframe/postgres/session.py
+++ b/sqlframe/postgres/session.py
@@ -38,6 +38,14 @@ def __init__(self, conn: t.Optional[psycopg2_connection] = None):
         if not hasattr(self, "_conn"):
             super().__init__(conn)
             self._execute("CREATE EXTENSION IF NOT EXISTS fuzzystrmatch")
+            self._execute("""CREATE OR REPLACE FUNCTION pg_temp.try_to_timestamp(input_text TEXT, format TEXT)
+RETURNS TIMESTAMP AS $$
+BEGIN
+    RETURN TO_TIMESTAMP(input_text, format);
+EXCEPTION WHEN OTHERS THEN
+    RETURN NULL;
+END;
+$$ LANGUAGE plpgsql;""")
 
     def _fetch_rows(
         self, sql: t.Union[str, exp.Expression], *, quote_identifiers: bool = True
diff --git a/sqlframe/snowflake/functions.py b/sqlframe/snowflake/functions.py
index 2907609..fd6d6a0 100644
--- a/sqlframe/snowflake/functions.py
+++ b/sqlframe/snowflake/functions.py
@@ -63,4 +63,6 @@
     map_concat_using_map_cat as map_concat,
     sequence_from_array_generate_range as sequence,
     to_number_using_to_double as to_number,
+    typeof_from_variant as typeof,
+    to_date_time_format as to_date,
 )
diff --git a/sqlframe/spark/functions.py b/sqlframe/spark/functions.py
index d9cc8bf..ad12fdf 100644
--- a/sqlframe/spark/functions.py
+++ b/sqlframe/spark/functions.py
@@ -17,4 +17,5 @@
     percentile_without_disc as percentile,
     add_months_by_multiplication as add_months,
     arrays_overlap_renamed as arrays_overlap,
+    _is_string_using_typeof_string_lcase as _is_string,
 )
diff --git a/tests/integration/engines/postgres/test_postgres_catalog.py b/tests/integration/engines/postgres/test_postgres_catalog.py
index 3bf6116..1225f77 100644
--- a/tests/integration/engines/postgres/test_postgres_catalog.py
+++ b/tests/integration/engines/postgres/test_postgres_catalog.py
@@ -36,7 +36,9 @@ def test_list_databases(postgres_session: PostgresSession):
         Database(name="db1", catalog="tests", description=None, locationUri=""),
         Database(name="information_schema", catalog="tests", description=None, locationUri=""),
         Database(name="pg_catalog", catalog="tests", description=None, locationUri=""),
+        Database(name="pg_temp_3", catalog="tests", description=None, locationUri=""),
         Database(name="pg_toast", catalog="tests", description=None, locationUri=""),
+        Database(name="pg_toast_temp_3", catalog="tests", description=None, locationUri=""),
         Database(name="public", catalog="tests", description=None, locationUri=""),
     ]
 
diff --git a/tests/integration/engines/test_int_functions.py b/tests/integration/engines/test_int_functions.py
index 1c791ec..99807cd 100644
--- a/tests/integration/engines/test_int_functions.py
+++ b/tests/integration/engines/test_int_functions.py
@@ -204,6 +204,24 @@ def test_typeof(get_session_and_func, get_types, arg, expected):
         pytest.skip("BigQuery doesn't support binary")
     if expected == "timestamp":
         expected = "datetime"
+    if isinstance(session, PostgresSession):
+        if expected.startswith("map"):
+            pytest.skip("Postgres doesn't support map types")
+        elif expected.startswith("struct"):
+            pytest.skip("Postgres doesn't support struct types")
+        elif expected == "binary":
+            pytest.skip("Postgres doesn't support binary")
+    if isinstance(session, SnowflakeSession):
+        if expected == "bigint":
+            expected = "int"
+        elif expected == "string":
+            expected = "varchar"
+        elif expected.startswith("map") or expected.startswith("struct"):
+            expected = "object"
+        elif expected.startswith("array"):
+            pytest.skip("Snowflake doesn't handle arrays properly in values clause")
+        elif expected == "timestamp":
+            expected = "timestampntz"
     result = df.select(typeof("col").alias("test")).first()[0]
     assert exp.DataType.build(result, dialect=dialect) == exp.DataType.build(
         expected, dialect=dialect
@@ -4849,12 +4867,16 @@ def test_try_to_timestamp(get_session_and_func, get_func):
     session, try_to_timestamp = get_session_and_func("try_to_timestamp")
     lit = get_func("lit", session)
     df = session.createDataFrame([("1997-02-28 10:30:00",)], ["t"])
-    assert df.select(try_to_timestamp(df.t).alias("dt")).first()[0] == datetime.datetime(
-        1997, 2, 28, 10, 30
-    )
-    assert df.select(try_to_timestamp(df.t, lit("yyyy-MM-dd HH:mm:ss")).alias("dt")).first()[
-        0
-    ] == datetime.datetime(1997, 2, 28, 10, 30)
+    result = df.select(try_to_timestamp(df.t).alias("dt")).first()[0]
+    if isinstance(session, BigQuerySession):
+        assert result == datetime.datetime(1997, 2, 28, 10, 30, tzinfo=datetime.timezone.utc)
+    else:
+        assert result == datetime.datetime(1997, 2, 28, 10, 30)
+    result = df.select(try_to_timestamp(df.t, lit("yyyy-MM-dd HH:mm:ss")).alias("dt")).first()[0]
+    if isinstance(session, BigQuerySession):
+        assert result == datetime.datetime(1997, 2, 28, 10, 30, tzinfo=datetime.timezone.utc)
+    else:
+        assert result == datetime.datetime(1997, 2, 28, 10, 30)
 
 
 def test_ucase(get_session_and_func, get_func):
@@ -5010,3 +5032,26 @@ def test_xpath_string(get_session_and_func, get_func):
     lit = get_func("lit", session)
     df = session.createDataFrame([("<a><b>b</b><c>cc</c></a>",)], ["x"])
     assert df.select(xpath_string(df.x, lit("a/c")).alias("r")).first()[0] == "cc"
+
+
+def test_is_string(get_session_and_func, get_func):
+    session, _is_string = get_session_and_func("_is_string")
+    lit = get_func("lit", session)
+    assert session.range(1).select(_is_string(lit("value")), _is_string(lit(1))).collect() == [
+        Row(v1=True, v2=False)
+    ]
+
+
+def test_is_date(get_session_and_func, get_func):
+    session, _is_date = get_session_and_func("_is_date")
+    to_date = get_func("to_date", session)
+    lit = get_func("lit", session)
+    assert session.range(1).select(
+        _is_date(to_date(lit("2021-01-01"), "yyyy-MM-dd")), _is_date(lit("2021-01-01"))
+    ).collect() == [Row(v1=True, v2=False)]
+
+
+# def test_
+
+# typeof = get_func("typeof", session)
+# assert session.range(1).select(typeof(to_date(lit("2021-01-01"), 'yyyy-MM-dd'))).collect() == [Row(value=True)]
diff --git a/tests/unit/standalone/test_functions.py b/tests/unit/standalone/test_functions.py
index 3c1a69b..ef4cf53 100644
--- a/tests/unit/standalone/test_functions.py
+++ b/tests/unit/standalone/test_functions.py
@@ -4636,9 +4636,9 @@ def test_try_element_at(expression, expected):
 @pytest.mark.parametrize(
     "expression, expected",
     [
-        (SF.try_to_timestamp("cola"), "TRY_TO_TIMESTAMP(cola)"),
-        (SF.try_to_timestamp(SF.col("cola")), "TRY_TO_TIMESTAMP(cola)"),
-        (SF.try_to_timestamp("cola", "colb"), "TRY_TO_TIMESTAMP(cola, colb)"),
+        (SF.try_to_timestamp("cola"), "TRY_TO_TIMESTAMP(cola, 'yyyy-MM-dd HH:mm:ss')"),
+        (SF.try_to_timestamp(SF.col("cola")), "TRY_TO_TIMESTAMP(cola, 'yyyy-MM-dd HH:mm:ss')"),
+        (SF.try_to_timestamp("cola", "blah"), "TRY_TO_TIMESTAMP(cola, 'blah')"),
     ],
 )
 def test_try_to_timestamp(expression, expected):