Skip to content

Commit

Permalink
feat: use spark time format instead of engine (#97)
Browse files Browse the repository at this point in the history
  • Loading branch information
eakmanrq authored Jun 28, 2024
1 parent 86188be commit 6e4159b
Show file tree
Hide file tree
Showing 18 changed files with 98 additions and 135 deletions.
4 changes: 0 additions & 4 deletions docs/bigquery.md
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,6 @@ See something that you would like to see supported? [Open an issue](https://gith
* [date_diff](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.date_diff.html)
* [datediff](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.datediff.html)
* [date_format](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.date_format.html)
* [The format string should be in BigQuery syntax](https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements)
* [date_sub](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.date_sub.html)
* [date_trunc](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.date_trunc.html)
* [dayofmonth](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.dayofmonth.html)
Expand Down Expand Up @@ -442,9 +441,7 @@ See something that you would like to see supported? [Open an issue](https://gith
* [toDegrees](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.toDegrees.html)
* [toRadians](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.toRadians.html)
* [to_date](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.to_date.html)
* [The format string should be in BigQuery syntax](https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements)
* [to_timestamp](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.to_timestamp.html)
* [The format string should be in BigQuery syntax](https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements)
* [translate](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.translate.html)
* [trim](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.trim.html)
* [trunc](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.trunc.html)
Expand All @@ -454,7 +451,6 @@ See something that you would like to see supported? [Open an issue](https://gith
* [unbase64](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unbase64.html)
* [unhex](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unhex.html)
* [unix_timestamp](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unix_timestamp.html)
* [The format string should be in BigQuery syntax](https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements)
* [upper](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.upper.html)
* [var_pop](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.var_pop.html)
* [var_samp](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.var_samp.html)
Expand Down
7 changes: 1 addition & 6 deletions docs/duckdb.md
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,7 @@ See something that you would like to see supported? [Open an issue](https://gith
* [dateadd](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.dateadd.html)
* [date_diff](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.date_diff.html)
* [datediff](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.datediff.html)
* [date_format](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.date_format.html)
* [The format string should be in DuckDB syntax](https://duckdb.org/docs/sql/functions/dateformat.html#format-specifiers)
* [date_format](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.date_format.html)
* [date_sub](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.date_sub.html)
* [date_trunc](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.date_trunc.html)
* [dayofmonth](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.dayofmonth.html)
Expand Down Expand Up @@ -405,11 +404,8 @@ See something that you would like to see supported? [Open an issue](https://gith
* [toDegrees](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.toDegrees.html)
* [toRadians](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.toRadians.html)
* [to_date](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.to_date.html)
* [The format string should be in DuckDB syntax](https://duckdb.org/docs/sql/functions/dateformat.html#format-specifiers)
* [to_timestamp](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.to_timestamp.html)
* [The format string should be in DuckDB syntax](https://duckdb.org/docs/sql/functions/dateformat.html#format-specifiers)
* [to_unix_timestamp](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.to_unix_timestamp.html)
 * [The format string should be in DuckDB syntax](https://duckdb.org/docs/sql/functions/dateformat.html#format-specifiers)
* The values must match the format string (null will not be returned if they do not)
* [translate](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.translate.html)
* [trim](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.trim.html)
Expand All @@ -420,7 +416,6 @@ See something that you would like to see supported? [Open an issue](https://gith
* [unbase64](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unbase64.html)
* [unhex](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unhex.html)
* [unix_timestamp](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unix_timestamp.html)
* [The format string should be in DuckDB syntax](https://duckdb.org/docs/sql/functions/dateformat.html#format-specifiers)
* [upper](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.upper.html)
* [var_pop](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.var_pop.html)
* [var_samp](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.var_samp.html)
Expand Down
3 changes: 0 additions & 3 deletions docs/postgres.md
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,6 @@ See something that you would like to see supported? [Open an issue](https://gith
* [date_diff](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.date_diff.html)
* [datediff](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.datediff.html)
* [date_format](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.date_format.html)
* [The format string should be in Postgres syntax](https://www.postgresql.org/docs/current/functions-formatting.html#FUNCTIONS-FORMATTING-DATETIME-TABLE)
* [date_sub](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.date_sub.html)
* [date_trunc](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.date_trunc.html)
* Rounded whole number is returned
Expand Down Expand Up @@ -397,10 +396,8 @@ See something that you would like to see supported? [Open an issue](https://gith
* [toDegrees](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.toDegrees.html)
* [toRadians](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.toRadians.html)
* [to_date](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.to_date.html)
* [The format string should be in Postgres syntax](https://www.postgresql.org/docs/current/functions-formatting.html#FUNCTIONS-FORMATTING-DATETIME-TABLE)
* [to_number](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.to_number.html)
* [to_timestamp](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.to_timestamp.html)
* [The format string should be in Postgres syntax](https://www.postgresql.org/docs/current/functions-formatting.html#FUNCTIONS-FORMATTING-DATETIME-TABLE)
* [translate](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.translate.html)
* [trim](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.trim.html)
* [trunc](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.trunc.html)
Expand Down
4 changes: 0 additions & 4 deletions docs/snowflake.md
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,6 @@ See something that you would like to see supported? [Open an issue](https://gith
* [date_diff](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.date_diff.html)
* [datediff](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.datediff.html)
* [date_format](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.date_format.html)
* [The format string should be in Snowflake syntax](https://docs.snowflake.com/en/sql-reference/functions-conversion#label-date-time-format-conversion)
* [date_sub](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.date_sub.html)
* [date_trunc](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.date_trunc.html)
* [dayofmonth](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.dayofmonth.html)
Expand Down Expand Up @@ -440,18 +439,15 @@ See something that you would like to see supported? [Open an issue](https://gith
* [toDegrees](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.toDegrees.html)
* [toRadians](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.toRadians.html)
* [to_date](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.to_date.html)
* [The format string should be in Snowflake syntax](https://docs.snowflake.com/en/sql-reference/functions-conversion#label-date-time-format-conversion)
* [to_number](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.to_number.html)
* [to_timestamp](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.to_timestamp.html)
* [The format string should be in Snowflake syntax](https://docs.snowflake.com/en/sql-reference/functions-conversion#label-date-time-format-conversion)
* [translate](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.translate.html)
* [trim](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.trim.html)
* [trunc](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.trunc.html)
* [ucase](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.ucase.html)
* [unbase64](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unbase64.html)
* [unhex](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unhex.html)
* [unix_timestamp](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unix_timestamp.html)
* [The format string should be in Snowflake syntax](https://docs.snowflake.com/en/sql-reference/functions-conversion#label-date-time-format-conversion)
* [upper](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.upper.html)
* [var_pop](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.var_pop.html)
* [var_samp](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.var_samp.html)
Expand Down
23 changes: 12 additions & 11 deletions sqlframe/base/function_alternatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,16 @@
import typing as t

from sqlglot import exp as expression
from sqlglot.dialects.dialect import build_formatted_time
from sqlglot.helper import ensure_list
from sqlglot.helper import flatten as _flatten

from sqlframe.base.column import Column
from sqlframe.base.util import get_func_from_session
from sqlframe.base.util import (
format_time_from_spark,
get_func_from_session,
spark_default_time_format,
)

if t.TYPE_CHECKING:
from sqlframe.base._typing import ColumnOrLiteral, ColumnOrName
Expand Down Expand Up @@ -715,14 +720,10 @@ def months_between_cast_as_date_cast_roundoff(


def from_unixtime_from_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
from sqlframe.base.session import _BaseSession

session: _BaseSession = _BaseSession()
lit = get_func_from_session("lit")
col_func = get_func_from_session("col")

if format is None:
format = session.DEFAULT_TIME_FORMAT
format = lit(format or spark_default_time_format())
return Column.invoke_expression_over_column(
Column(
expression.Anonymous(
Expand All @@ -731,7 +732,7 @@ def from_unixtime_from_timestamp(col: ColumnOrName, format: t.Optional[str] = No
)
),
expression.TimeToStr,
format=lit(format),
format=format_time_from_spark(format), # type: ignore
)


Expand Down Expand Up @@ -1511,10 +1512,10 @@ def to_unix_timestamp_include_default_format(
format: t.Optional[ColumnOrName] = None,
) -> Column:
from sqlframe.base.functions import to_unix_timestamp

lit = get_func_from_session("lit")
from sqlframe.base.session import _BaseSession

if not format:
format = lit("%Y-%m-%d %H:%M:%S")

format = _BaseSession().output_dialect.TIME_FORMAT
else:
format = format_time_from_spark(format)
return to_unix_timestamp(timestamp, format)
39 changes: 24 additions & 15 deletions sqlframe/base/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
import logging
import typing as t

from sqlglot import Dialect
from sqlglot import exp as expression
from sqlglot.helper import ensure_list
from sqlglot.helper import flatten as _flatten

from sqlframe.base.column import Column
from sqlframe.base.decorators import func_metadata as meta
from sqlframe.base.util import format_time_from_spark, spark_default_time_format

if t.TYPE_CHECKING:
from pyspark.sql.session import SparkContext
Expand Down Expand Up @@ -695,7 +697,7 @@ def date_format(col: ColumnOrName, format: str) -> Column:
return Column.invoke_expression_over_column(
Column(expression.TimeStrToTime(this=Column.ensure_col(col).expression)),
expression.TimeToStr,
format=lit(format),
format=format_time_from_spark(format),
)


Expand Down Expand Up @@ -875,17 +877,21 @@ def months_between(

@meta()
def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
format = lit(format or spark_default_time_format())
if format is not None:
return Column.invoke_expression_over_column(
col, expression.TsOrDsToDate, format=lit(format)
col, expression.TsOrDsToDate, format=format_time_from_spark(format)
)
return Column.invoke_expression_over_column(col, expression.TsOrDsToDate)


@meta()
def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
format = lit(format or spark_default_time_format())
if format is not None:
return Column.invoke_expression_over_column(col, expression.StrToTime, format=lit(format))
return Column.invoke_expression_over_column(
col, expression.StrToTime, format=format_time_from_spark(format)
)

return Column.ensure_col(col).cast("timestamp")

Expand Down Expand Up @@ -916,23 +922,23 @@ def last_day(col: ColumnOrName) -> Column:

@meta()
def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
from sqlframe.base.session import _BaseSession

if format is None:
format = _BaseSession().DEFAULT_TIME_FORMAT
return Column.invoke_expression_over_column(col, expression.UnixToStr, format=lit(format))
format = lit(format or spark_default_time_format())
return Column.invoke_expression_over_column(
col,
expression.UnixToStr,
format=format_time_from_spark(format), # type: ignore
)


@meta()
def unix_timestamp(
timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None
) -> Column:
from sqlframe.base.session import _BaseSession

if format is None:
format = _BaseSession().DEFAULT_TIME_FORMAT
format = lit(format or spark_default_time_format())
return Column.invoke_expression_over_column(
timestamp, expression.StrToUnix, format=lit(format)
timestamp,
expression.StrToUnix,
format=format_time_from_spark(format), # type: ignore
).cast("bigint")


Expand Down Expand Up @@ -5106,8 +5112,11 @@ def to_unix_timestamp(
[Row(r=None)]
>>> spark.conf.unset("spark.sql.session.timeZone")
"""
format = lit(spark_default_time_format()) if format is None else format
if format is not None:
return Column.invoke_expression_over_column(timestamp, expression.StrToUnix, format=format)
return Column.invoke_expression_over_column(
timestamp, expression.StrToUnix, format=format_time_from_spark(format)
)
else:
return Column.invoke_expression_over_column(timestamp, expression.StrToUnix)

Expand Down Expand Up @@ -5324,7 +5333,7 @@ def ucase(str: ColumnOrName) -> Column:
return Column.invoke_expression_over_column(str, expression.Upper)


@meta()
@meta(unsupported_engines=["bigquery", "snowflake"])
def unix_date(col: ColumnOrName) -> Column:
"""Returns the number of days since 1970-01-01.
Expand Down
5 changes: 4 additions & 1 deletion sqlframe/base/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, CONN]):
_df: t.Type[DF]

SANITIZE_COLUMN_NAMES = False
DEFAULT_TIME_FORMAT = "yyyy-MM-dd HH:mm:ss"

def __init__(
self,
Expand Down Expand Up @@ -114,6 +113,10 @@ def _conn(self) -> CONN:
def _cur(self) -> DBAPICursorWithPandas:
return self._conn.cursor()

@property
def default_time_format(self) -> str:
return self.output_dialect.TIME_FORMAT.strip("'")

def _sanitize_column_name(self, name: str) -> str:
if self.SANITIZE_COLUMN_NAMES:
return name.replace("(", "_").replace(")", "_")
Expand Down
25 changes: 24 additions & 1 deletion sqlframe/base/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
from pyspark.sql.dataframe import SparkSession as PySparkSession

from sqlframe.base import types
from sqlframe.base._typing import OptionalPrimitiveType, SchemaInput
from sqlframe.base._typing import (
ColumnOrLiteral,
OptionalPrimitiveType,
SchemaInput,
)
from sqlframe.base.column import Column
from sqlframe.base.session import _BaseSession
from sqlframe.base.types import StructType

Expand Down Expand Up @@ -342,3 +347,21 @@ def sqlglot_to_spark(sqlglot_dtype: exp.DataType) -> types.DataType:
]
)
raise NotImplementedError(f"Unsupported data type: {sqlglot_dtype}")


def format_time_from_spark(value: ColumnOrLiteral) -> Column:
from sqlframe.base.column import Column
from sqlframe.base.session import _BaseSession

lit = get_func_from_session("lit")
value = lit(value) if not isinstance(value, Column) else value
formatted_time = Dialect["spark"].format_time(value.expression)
return Column(
_BaseSession()
.output_dialect.generator()
.format_time(exp.StrToTime(this=exp.Null(), format=formatted_time))
)


def spark_default_time_format() -> str:
return Dialect["spark"].TIME_FORMAT.strip("'")
Loading

0 comments on commit 6e4159b

Please sign in to comment.