Skip to content

Commit

Permalink
feat: add full typeof support (#105)
Browse files Browse the repository at this point in the history
  • Loading branch information
eakmanrq authored Jun 29, 2024
1 parent 56fef5e commit b8dea22
Show file tree
Hide file tree
Showing 16 changed files with 203 additions and 18 deletions.
1 change: 1 addition & 0 deletions docs/bigquery.md
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,7 @@ See something that you would like to see supported? [Open an issue](https://gith
* [trim](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.trim.html)
* [trunc](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.trunc.html)
* Shorthand expressions not supported. Ex: Use `month` instead of `mon`
* [try_to_timestamp](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.try_to_timestamp.html)
* [typeof](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.typeof.html)
* [ucase](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.ucase.html)
* [unbase64](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unbase64.html)
Expand Down
2 changes: 2 additions & 0 deletions docs/duckdb.md
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ See something that you would like to see supported? [Open an issue](https://gith
* [date_format](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.date_format.html)
* [date_sub](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.date_sub.html)
* [date_trunc](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.date_trunc.html)
* [day](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.day.html)
* [dayofmonth](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.dayofmonth.html)
* [dayofweek](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.dayofweek.html)
* [dayofyear](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.dayofyear.html)
Expand Down Expand Up @@ -411,6 +412,7 @@ See something that you would like to see supported? [Open an issue](https://gith
* [trim](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.trim.html)
* [trunc](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.trunc.html)
* [try_element_at](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.try_element_at.html)
* [try_to_timestamp](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.try_to_timestamp.html)
* [typeof](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.typeof.html)
* [ucase](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.ucase.html)
* [unbase64](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unbase64.html)
Expand Down
2 changes: 2 additions & 0 deletions docs/postgres.md
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,8 @@ See something that you would like to see supported? [Open an issue](https://gith
* [trunc](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.trunc.html)
* [try_element_at](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.try_element_at.html)
* Negative index returns null and cannot lookup elements in maps
* [try_to_timestamp](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.try_to_timestamp.html)
* [typeof](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.typeof.html)
* [ucase](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.ucase.html)
* [unbase64](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unbase64.html)
* [unix_timestamp](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unix_timestamp.html)
Expand Down
2 changes: 2 additions & 0 deletions docs/snowflake.md
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,8 @@ See something that you would like to see supported? [Open an issue](https://gith
* [translate](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.translate.html)
* [trim](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.trim.html)
* [trunc](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.trunc.html)
* [try_to_timestamp](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.try_to_timestamp.html)
* [typeof](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.typeof.html)
* [ucase](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.ucase.html)
* [unbase64](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unbase64.html)
* [unhex](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unhex.html)
Expand Down
91 changes: 91 additions & 0 deletions sqlframe/base/function_alternatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,14 @@ def to_date_from_timestamp(col: ColumnOrName, format: t.Optional[str] = None) ->
return to_date(to_timestamp(col, format))


def to_date_time_format(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
    """Alternative `to_date` that always passes an explicit format.

    When no format is supplied, Spark's default *time* format is used
    (engines routed here expect a full timestamp format for TO_DATE).
    """
    from sqlframe.base.functions import to_date

    lit = get_func_from_session("lit")
    fmt = lit(format or spark_default_time_format())
    return to_date(col, format=fmt)


def last_day_with_cast(col: ColumnOrName) -> Column:
from sqlframe.base.functions import last_day

Expand Down Expand Up @@ -1519,3 +1527,86 @@ def to_unix_timestamp_include_default_format(
else:
format = format_time_from_spark(format)
return to_unix_timestamp(timestamp, format)


def day_with_try_to_timestamp(col: ColumnOrName) -> Column:
    """`day` variant for engines whose DAY() does not accept raw strings.

    String inputs are first parsed with try_to_timestamp, falling back to
    to_date when that yields NULL; non-string inputs pass through untouched.
    """
    from sqlframe.base.functions import day

    coalesce = get_func_from_session("coalesce")
    to_date = get_func_from_session("to_date")
    try_to_timestamp = get_func_from_session("try_to_timestamp")
    when = get_func_from_session("when")
    _is_string = get_func_from_session("_is_string")

    parsed = coalesce(try_to_timestamp(col), to_date(col))
    normalized = when(_is_string(col), parsed).otherwise(col)
    return day(normalized)


def try_to_timestamp_strptime(col: ColumnOrName, format: t.Optional[ColumnOrName] = None) -> Column:
    """`try_to_timestamp` fallback built on the engine's TRY_STRPTIME.

    Defaults to Spark's standard time format when no format is given; the
    Spark-style format string is translated to the engine's dialect.
    """
    lit = get_func_from_session("lit")

    spark_format = lit(format or spark_default_time_format())
    engine_format = format_time_from_spark(spark_format)  # type: ignore
    return Column.invoke_anonymous_function(col, "TRY_STRPTIME", engine_format)


def try_to_timestamp_safe(col: ColumnOrName, format: t.Optional[ColumnOrName] = None) -> Column:
    """`try_to_timestamp` fallback using SAFE.PARSE_TIMESTAMP (BigQuery-style).

    NOTE: the format string is passed as the FIRST argument and the value
    second, matching PARSE_TIMESTAMP's argument order.
    """
    lit = get_func_from_session("lit")

    spark_format = lit(format or spark_default_time_format())
    engine_format = format_time_from_spark(spark_format)  # type: ignore
    return Column.invoke_anonymous_function(
        engine_format,
        "SAFE.PARSE_TIMESTAMP",
        col,  # type: ignore
    )


def try_to_timestamp_pgtemp(col: ColumnOrName, format: t.Optional[ColumnOrName] = None) -> Column:
    """`try_to_timestamp` fallback that calls pg_temp.TRY_TO_TIMESTAMP.

    Relies on a session-created temporary function (see the Postgres session
    setup) that wraps TO_TIMESTAMP and returns NULL on parse failure.
    """
    lit = get_func_from_session("lit")

    spark_format = lit(format or spark_default_time_format())
    engine_format = format_time_from_spark(spark_format)  # type: ignore
    return Column.invoke_anonymous_function(
        col,
        "pg_temp.TRY_TO_TIMESTAMP",
        engine_format,
    )


def typeof_pg_typeof(col: ColumnOrName) -> Column:
    """`typeof` fallback via pg_typeof.

    pg_typeof yields a type OID, so the result is cast through regtype to get
    a readable type name and then to text.
    """
    type_oid = Column.invoke_anonymous_function(col, "pg_typeof")
    return type_oid.cast("regtype").cast("text")


def typeof_from_variant(col: ColumnOrName) -> Column:
    """`typeof` fallback for engines whose TYPEOF only accepts VARIANT input:
    wrap the column in TO_VARIANT first, then call TYPEOF on the result."""
    as_variant = Column.invoke_anonymous_function(col, "TO_VARIANT")
    return Column.invoke_anonymous_function(as_variant, "TYPEOF")


def _is_string_using_typeof_varchar(col: ColumnOrName) -> Column:
    """`_is_string` for engines that report string columns as VARCHAR."""
    lit = get_func_from_session("lit")
    typeof = get_func_from_session("typeof")
    type_name = typeof(col)
    return lit(type_name == lit("VARCHAR"))


def _is_string_using_typeof_char_varying(col: ColumnOrName) -> Column:
    """`_is_string` for engines (Postgres-style) whose string types are
    reported as `text`, `character varying`, or `unknown` (untyped literals).

    Fix: the original OR chain compared against `text` twice; the redundant
    duplicate term has been removed (behavior is unchanged).
    """
    typeof = get_func_from_session("typeof")
    lit = get_func_from_session("lit")
    return lit(
        (typeof(col) == lit("text"))
        | (typeof(col) == lit("character varying"))
        | (typeof(col) == lit("unknown"))
    )


def _is_string_using_typeof_string(col: ColumnOrName) -> Column:
    """`_is_string` for engines that report string columns as STRING (upper-case)."""
    lit = get_func_from_session("lit")
    typeof = get_func_from_session("typeof")
    type_name = typeof(col)
    return lit(type_name == lit("STRING"))


def _is_string_using_typeof_string_lcase(col: ColumnOrName) -> Column:
    """`_is_string` for engines that report string columns as `string` (lower-case)."""
    lit = get_func_from_session("lit")
    typeof = get_func_from_session("typeof")
    type_name = typeof(col)
    return lit(type_name == lit("string"))
35 changes: 26 additions & 9 deletions sqlframe/base/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@

from sqlframe.base.column import Column
from sqlframe.base.decorators import func_metadata as meta
from sqlframe.base.util import format_time_from_spark, spark_default_time_format
from sqlframe.base.util import (
format_time_from_spark,
get_func_from_session,
spark_default_date_format,
spark_default_time_format,
)

if t.TYPE_CHECKING:
from pyspark.sql.session import SparkContext
Expand Down Expand Up @@ -877,7 +882,7 @@ def months_between(

@meta()
def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
format = lit(format or spark_default_time_format())
format = lit(format or spark_default_date_format())
if format is not None:
return Column.invoke_expression_over_column(
col, expression.TsOrDsToDate, format=format_time_from_spark(format)
Expand Down Expand Up @@ -1743,7 +1748,7 @@ def map_zip_with(
return Column.invoke_anonymous_function(col1, "MAP_ZIP_WITH", col2, Column(f_expression))


@meta(unsupported_engines=["postgres", "snowflake"])
@meta()
def typeof(col: ColumnOrName) -> Column:
    """Return the engine-reported type name of `col` via TYPEOF.

    Engines lacking a native TYPEOF override this with an alternative
    implementation (see function_alternatives).
    """
    result = Column.invoke_anonymous_function(col, "TYPEOF")
    return result

Expand Down Expand Up @@ -2162,7 +2167,7 @@ def datepart(field: ColumnOrName, source: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(field, "datepart", source)


@meta(unsupported_engines="*")
@meta(unsupported_engines=["bigquery", "postgres", "snowflake"])
def day(col: ColumnOrName) -> Column:
    """Extract the day-of-month from `col` via the Day expression."""
    result = Column.invoke_expression_over_column(col, expression.Day)
    return result

Expand Down Expand Up @@ -5277,7 +5282,7 @@ def try_element_at(col: ColumnOrName, extraction: ColumnOrName) -> Column:
)


@meta(unsupported_engines="*")
@meta()
def try_to_timestamp(col: ColumnOrName, format: t.Optional[ColumnOrName] = None) -> Column:
"""
Parses the `col` with the `format` to a timestamp. The function always
Expand All @@ -5302,10 +5307,8 @@ def try_to_timestamp(col: ColumnOrName, format: t.Optional[ColumnOrName] = None)
>>> df.select(try_to_timestamp(df.t, lit('yyyy-MM-dd HH:mm:ss')).alias('dt')).collect()
[Row(dt=datetime.datetime(1997, 2, 28, 10, 30))]
"""
if format is not None:
return Column.invoke_anonymous_function(col, "try_to_timestamp", format)
else:
return Column.invoke_anonymous_function(col, "try_to_timestamp")
format = lit(format or spark_default_time_format())
return Column.invoke_anonymous_function(col, "try_to_timestamp", format_time_from_spark(format)) # type: ignore


@meta()
Expand Down Expand Up @@ -5797,6 +5800,20 @@ def years(col: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col, "years")


# SQLFrame specific
@meta()
def _is_string(col: ColumnOrName) -> Column:
    """Internal helper: is `col` a string value?

    Default implementation is variant-based (Snowflake-style): cast to
    VARIANT, then ask whether the variant holds a VARCHAR. Engines without
    VARIANT support override this in function_alternatives.
    """
    as_variant = Column.invoke_anonymous_function(col, "TO_VARIANT")
    return Column.invoke_anonymous_function(as_variant, "IS_VARCHAR")


@meta()
def _is_date(col: ColumnOrName) -> Column:
    """Internal helper: True when `col`'s engine-reported type is DATE.

    The type name is upper-cased first so engines that report `date`
    lower-case compare equal as well.
    """
    typeof = get_func_from_session("typeof")
    upper = get_func_from_session("upper")
    type_name = upper(typeof(col))
    return lit(type_name == lit("DATE"))


@meta()
def _lambda_quoted(value: str) -> t.Optional[bool]:
    """Quoting policy for lambda parameters: never quote the `_` placeholder;
    return None for everything else to defer to the default behavior."""
    if value == "_":
        return False
    return None
Expand Down
4 changes: 4 additions & 0 deletions sqlframe/base/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,3 +365,7 @@ def format_time_from_spark(value: ColumnOrLiteral) -> Column:

def spark_default_time_format() -> str:
    """Spark dialect's default timestamp format, with sqlglot's surrounding
    single quotes stripped off."""
    quoted = Dialect["spark"].TIME_FORMAT
    return quoted.strip("'")


def spark_default_date_format() -> str:
    """Spark dialect's default date format, with sqlglot's surrounding
    single quotes stripped off."""
    quoted = Dialect["spark"].DATE_FORMAT
    return quoted.strip("'")
2 changes: 2 additions & 0 deletions sqlframe/bigquery/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@
array_union_using_array_concat as array_union,
sequence_from_generate_array as sequence,
position_as_strpos as position,
try_to_timestamp_safe as try_to_timestamp,
_is_string_using_typeof_string as _is_string,
)


Expand Down
3 changes: 3 additions & 0 deletions sqlframe/duckdb/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,7 @@
array_max_from_sort as array_max,
sequence_from_generate_series as sequence,
try_element_at_zero_based as try_element_at,
day_with_try_to_timestamp as day,
try_to_timestamp_strptime as try_to_timestamp,
_is_string_using_typeof_varchar as _is_string,
)
3 changes: 3 additions & 0 deletions sqlframe/postgres/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,7 @@
right_cast_len as right,
position_cast_start as position,
try_element_at_zero_based as try_element_at,
try_to_timestamp_pgtemp as try_to_timestamp,
typeof_pg_typeof as typeof,
_is_string_using_typeof_char_varying as _is_string,
)
8 changes: 8 additions & 0 deletions sqlframe/postgres/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ def __init__(self, conn: t.Optional[psycopg2_connection] = None):
if not hasattr(self, "_conn"):
super().__init__(conn)
self._execute("CREATE EXTENSION IF NOT EXISTS fuzzystrmatch")
self._execute("""CREATE OR REPLACE FUNCTION pg_temp.try_to_timestamp(input_text TEXT, format TEXT)
RETURNS TIMESTAMP AS $$
BEGIN
RETURN TO_TIMESTAMP(input_text, format);
EXCEPTION WHEN OTHERS THEN
RETURN NULL;
END;
$$ LANGUAGE plpgsql;""")

def _fetch_rows(
self, sql: t.Union[str, exp.Expression], *, quote_identifiers: bool = True
Expand Down
2 changes: 2 additions & 0 deletions sqlframe/snowflake/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,6 @@
map_concat_using_map_cat as map_concat,
sequence_from_array_generate_range as sequence,
to_number_using_to_double as to_number,
typeof_from_variant as typeof,
to_date_time_format as to_date,
)
1 change: 1 addition & 0 deletions sqlframe/spark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@
percentile_without_disc as percentile,
add_months_by_multiplication as add_months,
arrays_overlap_renamed as arrays_overlap,
_is_string_using_typeof_string_lcase as _is_string,
)
2 changes: 2 additions & 0 deletions tests/integration/engines/postgres/test_postgres_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def test_list_databases(postgres_session: PostgresSession):
Database(name="db1", catalog="tests", description=None, locationUri=""),
Database(name="information_schema", catalog="tests", description=None, locationUri=""),
Database(name="pg_catalog", catalog="tests", description=None, locationUri=""),
Database(name="pg_temp_3", catalog="tests", description=None, locationUri=""),
Database(name="pg_toast", catalog="tests", description=None, locationUri=""),
Database(name="pg_toast_temp_3", catalog="tests", description=None, locationUri=""),
Database(name="public", catalog="tests", description=None, locationUri=""),
]

Expand Down
57 changes: 51 additions & 6 deletions tests/integration/engines/test_int_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,24 @@ def test_typeof(get_session_and_func, get_types, arg, expected):
pytest.skip("BigQuery doesn't support binary")
if expected == "timestamp":
expected = "datetime"
if isinstance(session, PostgresSession):
if expected.startswith("map"):
pytest.skip("Postgres doesn't support map types")
elif expected.startswith("struct"):
pytest.skip("Postgres doesn't support struct types")
elif expected == "binary":
pytest.skip("Postgres doesn't support binary")
if isinstance(session, SnowflakeSession):
if expected == "bigint":
expected = "int"
elif expected == "string":
expected = "varchar"
elif expected.startswith("map") or expected.startswith("struct"):
expected = "object"
elif expected.startswith("array"):
pytest.skip("Snowflake doesn't handle arrays properly in values clause")
elif expected == "timestamp":
expected = "timestampntz"
result = df.select(typeof("col").alias("test")).first()[0]
assert exp.DataType.build(result, dialect=dialect) == exp.DataType.build(
expected, dialect=dialect
Expand Down Expand Up @@ -4849,12 +4867,16 @@ def test_try_to_timestamp(get_session_and_func, get_func):
session, try_to_timestamp = get_session_and_func("try_to_timestamp")
lit = get_func("lit", session)
df = session.createDataFrame([("1997-02-28 10:30:00",)], ["t"])
assert df.select(try_to_timestamp(df.t).alias("dt")).first()[0] == datetime.datetime(
1997, 2, 28, 10, 30
)
assert df.select(try_to_timestamp(df.t, lit("yyyy-MM-dd HH:mm:ss")).alias("dt")).first()[
0
] == datetime.datetime(1997, 2, 28, 10, 30)
result = df.select(try_to_timestamp(df.t).alias("dt")).first()[0]
if isinstance(session, BigQuerySession):
assert result == datetime.datetime(1997, 2, 28, 10, 30, tzinfo=datetime.timezone.utc)
else:
assert result == datetime.datetime(1997, 2, 28, 10, 30)
result = df.select(try_to_timestamp(df.t, lit("yyyy-MM-dd HH:mm:ss")).alias("dt")).first()[0]
if isinstance(session, BigQuerySession):
assert result == datetime.datetime(1997, 2, 28, 10, 30, tzinfo=datetime.timezone.utc)
else:
assert result == datetime.datetime(1997, 2, 28, 10, 30)


def test_ucase(get_session_and_func, get_func):
Expand Down Expand Up @@ -5010,3 +5032,26 @@ def test_xpath_string(get_session_and_func, get_func):
lit = get_func("lit", session)
df = session.createDataFrame([("<a><b>b</b><c>cc</c></a>",)], ["x"])
assert df.select(xpath_string(df.x, lit("a/c")).alias("r")).first()[0] == "cc"


def test_is_string(get_session_and_func, get_func):
    """_is_string: True for a string literal, False for an int literal."""
    session, _is_string = get_session_and_func("_is_string")
    lit = get_func("lit", session)
    rows = session.range(1).select(_is_string(lit("value")), _is_string(lit(1))).collect()
    assert rows == [Row(v1=True, v2=False)]


def test_is_date(get_session_and_func, get_func):
    """_is_date: True for a real DATE value, False for a date-shaped string."""
    session, _is_date = get_session_and_func("_is_date")
    to_date = get_func("to_date", session)
    lit = get_func("lit", session)
    real_date = _is_date(to_date(lit("2021-01-01"), "yyyy-MM-dd"))
    date_like_string = _is_date(lit("2021-01-01"))
    rows = session.range(1).select(real_date, date_like_string).collect()
    assert rows == [Row(v1=True, v2=False)]


# def test_

# typeof = get_func("typeof", session)
# assert session.range(1).select(typeof(to_date(lit("2021-01-01"), 'yyyy-MM-dd'))).collect() == [Row(value=True)]
6 changes: 3 additions & 3 deletions tests/unit/standalone/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4636,9 +4636,9 @@ def test_try_element_at(expression, expected):
@pytest.mark.parametrize(
"expression, expected",
[
(SF.try_to_timestamp("cola"), "TRY_TO_TIMESTAMP(cola)"),
(SF.try_to_timestamp(SF.col("cola")), "TRY_TO_TIMESTAMP(cola)"),
(SF.try_to_timestamp("cola", "colb"), "TRY_TO_TIMESTAMP(cola, colb)"),
(SF.try_to_timestamp("cola"), "TRY_TO_TIMESTAMP(cola, 'yyyy-MM-dd HH:mm:ss')"),
(SF.try_to_timestamp(SF.col("cola")), "TRY_TO_TIMESTAMP(cola, 'yyyy-MM-dd HH:mm:ss')"),
(SF.try_to_timestamp("cola", "blah"), "TRY_TO_TIMESTAMP(cola, 'blah')"),
],
)
def test_try_to_timestamp(expression, expected):
Expand Down

0 comments on commit b8dea22

Please sign in to comment.