feat: add 3.5 try_* support (#68)
eakmanrq authored Jun 12, 2024
1 parent ea6733e commit bf38844
Showing 5 changed files with 374 additions and 0 deletions.
9 changes: 9 additions & 0 deletions docs/spark.md
@@ -310,6 +310,7 @@ df.show(5)
* [lpad](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.lpad.html)
* [ltrim](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.ltrim.html)
* [make_date](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.make_date.html)
* [make_interval](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.make_interval.html)
* [map_concat](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.map_concat.html)
* [map_entries](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.map_entries.html)
* [map_filter](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.map_filter.html)
@@ -408,6 +409,14 @@ df.show(5)
* [translate](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.translate.html)
* [trim](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.trim.html)
* [trunc](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.trunc.html)
* [try_add](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.try_add.html)
* [try_avg](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.try_avg.html)
* [try_divide](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.try_divide.html)
* [try_multiply](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.try_multiply.html)
* [try_subtract](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.try_subtract.html)
* [try_sum](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.try_sum.html)
* [try_to_binary](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.try_to_binary.html)
* [try_to_number](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.try_to_number.html)
* [typeof](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.typeof.html)
* [unbase64](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unbase64.html)
* [unhex](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unhex.html)
65 changes: 65 additions & 0 deletions sqlframe/base/functions.py
@@ -1682,6 +1682,71 @@ def stack(*cols: ColumnOrName) -> Column:
)


@meta(unsupported_engines="*")
def make_interval(
    years: t.Optional[ColumnOrName] = None,
    months: t.Optional[ColumnOrName] = None,
    weeks: t.Optional[ColumnOrName] = None,
    days: t.Optional[ColumnOrName] = None,
    hours: t.Optional[ColumnOrName] = None,
    mins: t.Optional[ColumnOrName] = None,
    secs: t.Optional[ColumnOrName] = None,
) -> Column:
    # Trim trailing unset components so the generated MAKE_INTERVAL call only
    # passes arguments up to the last one the caller actually provided.
    values = [years, months, weeks, days, hours, mins, secs]
    for value in reversed(values.copy()):
        if value is not None:
            break
        values = values[:-1]
    else:
        # The loop never broke, meaning every component was None.
        raise ValueError("At least one value must be provided")
    # Interior gaps (e.g. years and days without months) become NULL literals.
    columns = [Column.ensure_col(x) if x is not None else lit(None) for x in values]
    return Column.invoke_anonymous_function(columns[0], "MAKE_INTERVAL", *columns[1:])


@meta(unsupported_engines="*")
def try_add(left: ColumnOrName, right: ColumnOrName) -> Column:
    return Column.invoke_anonymous_function(left, "TRY_ADD", right)


@meta(unsupported_engines="*")
def try_avg(col: ColumnOrName) -> Column:
    return Column.invoke_anonymous_function(col, "TRY_AVG")


@meta(unsupported_engines="*")
def try_divide(left: ColumnOrName, right: ColumnOrName) -> Column:
    return Column.invoke_anonymous_function(left, "TRY_DIVIDE", right)


@meta(unsupported_engines="*")
def try_multiply(left: ColumnOrName, right: ColumnOrName) -> Column:
    return Column.invoke_anonymous_function(left, "TRY_MULTIPLY", right)


@meta(unsupported_engines="*")
def try_subtract(left: ColumnOrName, right: ColumnOrName) -> Column:
    return Column.invoke_anonymous_function(left, "TRY_SUBTRACT", right)


@meta(unsupported_engines="*")
def try_sum(col: ColumnOrName) -> Column:
    return Column.invoke_anonymous_function(col, "TRY_SUM")


@meta(unsupported_engines="*")
def try_to_binary(col: ColumnOrName, format: t.Optional[ColumnOrName] = None) -> Column:
    if format is not None:
        return Column.invoke_anonymous_function(col, "TRY_TO_BINARY", format)
    return Column.invoke_anonymous_function(col, "TRY_TO_BINARY")


@meta(unsupported_engines="*")
def try_to_number(col: ColumnOrName, format: t.Optional[ColumnOrName] = None) -> Column:
    if format is not None:
        return Column.invoke_anonymous_function(col, "TRY_TO_NUMBER", format)
    return Column.invoke_anonymous_function(col, "TRY_TO_NUMBER")


@meta()
def _lambda_quoted(value: str) -> t.Optional[bool]:
    return False if value == "_" else None
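These wrappers mirror the PySpark 3.5 functions of the same names: each try_* variant returns NULL instead of raising when the operation fails (overflow, division by zero, malformed input). A minimal behavioral sketch of that contract, shown against PySpark 3.5 itself rather than sqlframe, purely for illustration:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(6000, 15), (1990, 0)], ["a", "b"])

# try_divide yields NULL (None) for the zero divisor instead of raising,
# which is the behavior the TRY_DIVIDE wrapper above delegates to.
rows = df.select(F.try_divide(df.a, df.b).alias("r")).collect()
assert [r.r for r in rows] == [400.0, None]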
9 changes: 9 additions & 0 deletions sqlframe/spark/functions.pyi
@@ -132,6 +132,7 @@ from sqlframe.base.functions import (
    lpad as lpad,
    ltrim as ltrim,
    make_date as make_date,
    make_interval as make_interval,
    map_concat as map_concat,
    map_entries as map_entries,
    map_filter as map_filter,
@@ -228,6 +229,14 @@ from sqlframe.base.functions import (
    translate as translate,
    trim as trim,
    trunc as trunc,
    try_add as try_add,
    try_avg as try_avg,
    try_divide as try_divide,
    try_multiply as try_multiply,
    try_subtract as try_subtract,
    try_sum as try_sum,
    try_to_binary as try_to_binary,
    try_to_number as try_to_number,
    typeof as typeof,
    unbase64 as unbase64,
    unhex as unhex,
188 changes: 188 additions & 0 deletions tests/integration/engines/test_int_functions.py
@@ -4,6 +4,7 @@
import math
import typing as t
from collections import Counter
from decimal import Decimal

import pytest
from pyspark.sql import SparkSession as PySparkSession
@@ -2883,3 +2884,190 @@ def test_stack(get_session_and_func, get_func):
        Row(key=1, value=2),
        Row(key=3, value=None),
    ]


def test_make_interval(get_session_and_func, get_func):
    session, make_interval = get_session_and_func("make_interval")
    df = session.createDataFrame(
        [[100, 11, 1, 1, 12, 30, 1.001001]], ["year", "month", "week", "day", "hour", "min", "sec"]
    )
    assert (
        df.select(
            make_interval(df.year, df.month, df.week, df.day, df.hour, df.min, df.sec)
            .cast("string")
            .alias("r")
        ).first()[0]
        == "100 years 11 months 8 days 12 hours 30 minutes 1.001001 seconds"
    )
    assert (
        df.select(
            make_interval(df.year, df.month, df.week, df.day, df.hour, df.min)
            .cast("string")
            .alias("r")
        ).first()[0]
        == "100 years 11 months 8 days 12 hours 30 minutes"
    )
    assert (
        df.select(
            make_interval(df.year, df.month, df.week, df.day, df.hour).cast("string").alias("r")
        ).first()[0]
        == "100 years 11 months 8 days 12 hours"
    )
    assert (
        df.select(
            make_interval(df.year, df.month, df.week, df.day).cast("string").alias("r")
        ).first()[0]
        == "100 years 11 months 8 days"
    )
    assert (
        df.select(make_interval(df.year, df.month, df.week).cast("string").alias("r")).first()[0]
        == "100 years 11 months 7 days"
    )
    assert (
        df.select(make_interval(df.year, df.month).cast("string").alias("r")).first()[0]
        == "100 years 11 months"
    )
    assert df.select(make_interval(df.year).cast("string").alias("r")).first()[0] == "100 years"


def test_try_add(get_session_and_func, get_func, get_types):
    session, try_add = get_session_and_func("try_add")
    to_date = get_func("to_date", session)
    make_interval = get_func("make_interval", session)
    lit = get_func("lit", session)
    types = get_types(session)
    df = session.createDataFrame([(1982, 15), (1990, 2)], ["birth", "age"])
    assert df.select(try_add(df.birth, df.age).alias("r")).collect() == [
        Row(r=1997),
        Row(r=1992),
    ]
    schema = types.StructType(
        [
            types.StructField("i", types.IntegerType(), True),
            types.StructField("d", types.StringType(), True),
        ]
    )
    df = session.createDataFrame([(1, "2015-09-30")], schema)
    df = df.select(df.i, to_date(df.d).alias("d"))
    assert df.select(try_add(df.d, df.i).alias("r")).collect() == [
        Row(r=datetime.date(2015, 10, 1))
    ]
    assert df.select(try_add(df.d, make_interval(df.i)).alias("r")).collect() == [
        Row(r=datetime.date(2016, 9, 30))
    ]
    assert df.select(
        try_add(df.d, make_interval(lit(0), lit(0), lit(0), df.i)).alias("r")
    ).collect() == [Row(r=datetime.date(2015, 10, 1))]
    assert df.select(
        try_add(make_interval(df.i), make_interval(df.i)).cast("string").alias("r")
    ).collect() == [Row(r="2 years")]


def test_try_avg(get_session_and_func, get_func):
    session, try_avg = get_session_and_func("try_avg")
    df = session.createDataFrame([(1982, 15), (1990, 2)], ["birth", "age"])
    assert df.select(try_avg("age")).first()[0] == 8.5


def test_try_divide(get_session_and_func, get_func):
    session, try_divide = get_session_and_func("try_divide")
    make_interval = get_func("make_interval", session)
    lit = get_func("lit", session)
    df = session.createDataFrame([(6000, 15), (1990, 2)], ["a", "b"])
    assert df.select(try_divide(df.a, df.b).alias("r")).collect() == [
        Row(r=400.0),
        Row(r=995.0),
    ]
    df = session.createDataFrame([(1, 2)], ["year", "month"])
    assert (
        df.select(try_divide(make_interval(df.year), df.month).cast("string").alias("r")).first()[0]
        == "6 months"
    )
    assert (
        df.select(
            try_divide(make_interval(df.year, df.month), lit(2)).cast("string").alias("r")
        ).first()[0]
        == "7 months"
    )
    assert (
        df.select(
            try_divide(make_interval(df.year, df.month), lit(0)).cast("string").alias("r")
        ).first()[0]
        is None
    )


def test_try_multiply(get_session_and_func, get_func):
    session, try_multiply = get_session_and_func("try_multiply")
    make_interval = get_func("make_interval", session)
    lit = get_func("lit", session)
    df = session.createDataFrame([(6000, 15), (1990, 2)], ["a", "b"])
    assert df.select(try_multiply(df.a, df.b).alias("r")).collect() == [
        Row(r=90000),
        Row(r=3980),
    ]
    df = session.createDataFrame([(2, 3)], ["a", "b"])
    assert (
        df.select(try_multiply(make_interval(df.a), df.b).cast("string").alias("r")).first()[0]
        == "6 years"
    )


def test_try_subtract(get_session_and_func, get_func, get_types):
    session, try_subtract = get_session_and_func("try_subtract")
    make_interval = get_func("make_interval", session)
    types = get_types(session)
    lit = get_func("lit", session)
    to_date = get_func("to_date", session)
    df = session.createDataFrame([(6000, 15), (1990, 2)], ["a", "b"])
    assert df.select(try_subtract(df.a, df.b).alias("r")).collect() == [
        Row(r=5985),
        Row(r=1988),
    ]
    schema = types.StructType(
        [
            types.StructField("i", types.IntegerType(), True),
            types.StructField("d", types.StringType(), True),
        ]
    )
    df = session.createDataFrame([(1, "2015-09-30")], schema)
    df = df.select(df.i, to_date(df.d).alias("d"))
    assert df.select(try_subtract(df.d, df.i).alias("r")).first()[0] == datetime.date(2015, 9, 29)
    assert df.select(try_subtract(df.d, make_interval(df.i)).alias("r")).first()[
        0
    ] == datetime.date(2014, 9, 30)
    assert df.select(
        try_subtract(df.d, make_interval(lit(0), lit(0), lit(0), df.i)).alias("r")
    ).first()[0] == datetime.date(2015, 9, 29)
    assert (
        df.select(
            try_subtract(make_interval(df.i), make_interval(df.i)).cast("string").alias("r")
        ).first()[0]
        == "0 seconds"
    )


def test_try_sum(get_session_and_func, get_func):
    session, try_sum = get_session_and_func("try_sum")
    assert session.range(10).select(try_sum("id")).first()[0] == 45


def test_try_to_binary(get_session_and_func, get_func):
    session, try_to_binary = get_session_and_func("try_to_binary")
    lit = get_func("lit", session)
    df = session.createDataFrame([("abc",)], ["e"])
    assert df.select(try_to_binary(df.e, lit("utf-8")).alias("r")).first()[0] == bytearray(b"abc")
    df = session.createDataFrame([("414243",)], ["e"])
    assert df.select(try_to_binary(df.e).alias("r")).first()[0] == bytearray(b"ABC")


def test_try_to_number(get_session_and_func, get_func):
    session, try_to_number = get_session_and_func("try_to_number")
    lit = get_func("lit", session)
    df = session.createDataFrame([("$78.12",)], ["e"])
    actual = df.select(try_to_number(df.e, lit("$99.99")).alias("r")).first()[0]
    if isinstance(session, SparkSession):
        expected = 78.12
    else:
        expected = Decimal("78.12")
    assert actual == expected
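Since sqlframe generates SQL through sqlglot, the anonymous function names used by these wrappers should pass through SQL generation verbatim. A small sketch of that round-trip (assumes sqlglot is installed; illustrative only, not part of this commit):

import sqlglot

# An anonymous function call such as TRY_DIVIDE should be preserved as-is
# when parsed and re-rendered in the Spark dialect.
print(sqlglot.transpile("SELECT TRY_DIVIDE(a, b) FROM t", read="spark", write="spark")[0])
# SELECT TRY_DIVIDE(a, b) FROM t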