Skip to content

Commit

Permalink
Merge pull request #415 from machow/fix-np-ufuncs
Browse files Browse the repository at this point in the history
feat: symbolic support numpy ufuncs
  • Loading branch information
machow authored Apr 30, 2022
2 parents 3a52f5f + 4187de3 commit 9f83136
Show file tree
Hide file tree
Showing 10 changed files with 218 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jobs:
run: |
make test-travis
env:
SB_TEST_PGPORT: 5432
SB_TEST_PGPORT: 5433
PYTEST_FLAGS: ${{ matrix.pytest_flags }}

# optional step for running bigquery tests ----
Expand Down
5 changes: 3 additions & 2 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ version: '3.1'
services:

db_mysql:
image: mysql
image: mysql/mysql-server
restart: always
environment:
MYSQL_ROOT_HOST: "%"
MYSQL_ROOT_PASSWORD: ""
MYSQL_ALLOW_EMPTY_PASSWORD: 1
MYSQL_DATABASE: "public"
Expand All @@ -21,4 +22,4 @@ services:
POSTGRES_PASSWORD: ""
POSTGRES_HOST_AUTH_METHOD: "trust"
ports:
- 5432:5432
- 5433:5432
2 changes: 1 addition & 1 deletion siuba/ops/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _register_series_default(generic):

generic.register(pd.Series, partial(_default_pd_series, generic.operation))

def _default_pd_series(__op, self, args = tuple(), kwargs = {}):
def _default_pd_series(__op, self, *args, **kwargs):
# Once we drop python 3.7 dependency, could make __op position only
if __op.accessor is not None:
method = getattr(getattr(self, __op.accessor), __op.name)
Expand Down
66 changes: 65 additions & 1 deletion siuba/siu/symbolic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import singledispatch

from .calls import BINARY_OPS, UNARY_OPS, Call, BinaryOp, BinaryRightOp, MetaArg, UnaryOp, SliceOp, FuncArg
from .format import Formatter

Expand All @@ -9,6 +11,12 @@ def __init__(self, source = None, ready_to_call = False):
self.__source = MetaArg("_") if source is None else source
self.__ready_to_call = ready_to_call

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
"""Handle numpy universal functions. E.g. np.sqrt(_)."""
return array_ufunc(self, ufunc, method, *inputs, **kwargs)

def __array_function__(self, func, types, *args, **kwargs):
return array_function(self, func, types, *args, **kwargs)

# allowed methods ----

Expand Down Expand Up @@ -108,7 +116,63 @@ def explain(symbol):
return str(symbol)


# Do some gnarly method setting -----------------------------------------------
# Special numpy ufunc dispatcher
# =============================================================================
# note that this is essentially what dispatchers.symbolic_dispatch does...
# details on numpy array dispatch: https://github.com/numpy/numpy/issues/21387

@singledispatch
def array_function(self, func, types, *args, **kwargs):
return func(*args, **kwargs)


@array_function.register(Call)
def _array_function_call(self, func, types, *args, **kwargs):
return Call("__call__", FuncArg(array_function), self, func, *args, **kwargs)


@array_function.register(Symbolic)
def _array_function_sym(self, func, types, *args, **kwargs):
f_concrete = array_function.dispatch(Call)

call = f_concrete(
strip_symbolic(self),
func,
types,
*map(strip_symbolic, args),
**{k: strip_symbolic(v) for k, v in kwargs.items()}
)

return Symbolic(call)


@singledispatch
def array_ufunc(self, ufunc, method, *inputs, **kwargs):
return getattr(ufunc, method)(*inputs, **kwargs)

@array_ufunc.register(Call)
def _array_ufunc_call(self, ufunc, method, *inputs, **kwargs):

return Call("__call__", FuncArg(array_ufunc), self, ufunc, method, *inputs, **kwargs)


@array_ufunc.register(Symbolic)
def _array_ufunc_sym(self, ufunc, method, *inputs, **kwargs):
f_concrete = array_ufunc.dispatch(Call)

call = f_concrete(
strip_symbolic(self),
ufunc,
method,
*map(strip_symbolic, inputs),
**{k: strip_symbolic(v) for k, v in kwargs.items()}
)

return Symbolic(call)


# Do some gnarly method setting on Symbolic -----------------------------------
# =============================================================================

def create_binary_op(op_name, left_op = True):
def _binary_op(self, x):
Expand Down
2 changes: 1 addition & 1 deletion siuba/siu/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class FunctionLookupBound:
def __init__(self, msg):
self.msg = msg

def __call__(self):
def __call__(self, *args, **kwargs):
raise NotImplementedError(self.msg)


Expand Down
16 changes: 15 additions & 1 deletion siuba/sql/dialects/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@
annotate,
RankOver,
CumlOver,
SqlTranslator
SqlTranslator,
FunctionLookupBound
)


Expand Down Expand Up @@ -122,6 +123,19 @@ def sql_func_capitalize(_, col):
return sql.functions.concat(first_char, rest)


# Numpy ufuncs ----------------------------------------------------------------
# symbolic objects have a generic dispatch for when _.__array_ufunc__ is called,
# in order to support things like np.sqrt(_.x). In theory this wouldn't be crazy
# to support, but most ufuncs have existing pandas methods already.

from siuba.siu.symbolic import array_ufunc, array_function

_f_err = FunctionLookupBound("Numpy function sql translation (e.g. np.sqrt) not supported.")

array_ufunc.register(SqlColumn, _f_err)
array_function.register(SqlColumn, _f_err)


# Misc implementations --------------------------------------------------------

def sql_func_astype(_, col, _type):
Expand Down
2 changes: 1 addition & 1 deletion siuba/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def data_frame(*args, _index = None, **kwargs):
"dialect": "postgresql",
"driver": "",
"dbname": ["SB_TEST_PGDATABASE", "postgres"],
"port": ["SB_TEST_PGPORT", "5432"],
"port": ["SB_TEST_PGPORT", "5433"],
"user": ["SB_TEST_PGUSER", "postgres"],
"password": ["SB_TEST_PGPASSWORD", ""],
"host": ["SB_TEST_PGHOST", "localhost"],
Expand Down
1 change: 1 addition & 0 deletions siuba/tests/test_dply_series_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ def test_pandas_grouped_frame_fast_summarize(agg_entry):

# Edge Cases ==================================================================

@pytest.mark.postgresql
def test_frame_set_aggregates_postgresql():
# TODO: probably shouldn't be creating backend here
backend = SqlBackend("postgresql")
Expand Down
36 changes: 33 additions & 3 deletions siuba/tests/test_siu.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,40 @@
def _():
return Symbolic()

def test_source_attr(_):
sym = _.source

# Symbolic class ==============================================================

def test_symbolic_source_attr(_):
sym = _.__source
assert isinstance(sym, Symbolic)
assert explain(sym) == "_.__source"


def test_symbolic_numpy_ufunc(_):
from siuba.siu.symbolic import array_ufunc
import numpy as np

# should have form...
# █─'__call__'
# ├─█─'__custom_func__'
# │ └─<function array_ufunc at ...>
# ├─_
# ├─<ufunc 'sqrt'>
# ├─'__call__'
# └─_

sym = np.sqrt(_)
expr = strip_symbolic(sym)

assert isinstance(sym, Symbolic)
assert explain(sym) == "_.source"

# check we are doing a call over a custom dispatch function ----
assert expr.func == "__call__"

dispatcher = expr.args[0]
assert isinstance(dispatcher, FuncArg)
assert dispatcher.args[0] is array_ufunc # could check .dispatch() method


def test_op_vars_slice(_):
assert strip_symbolic(_.a[_.b:_.c]).op_vars() == {'a', 'b', 'c'}
Expand Down
97 changes: 97 additions & 0 deletions siuba/tests/test_siu_symbolic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import numpy as np
import pytest

from siuba.siu import strip_symbolic, FunctionLookupError, Symbolic, MetaArg, Call


# Note that currently tests are split across the test_siu.py, and this module.

@pytest.fixture
def _():
return Symbolic()

def test_siu_symbolic_np_array_ufunc_call(_):
sym = np.add(_, 1)
expr = strip_symbolic(sym)

# structure:
# █─'__call__'
# ├─█─'__custom_func__'
# │ └─<function array_ufunc at 0x103aa3820>
# ├─_
# ├─<ufunc 'add'>
# ├─'__call__'
# ├─_
# └─1

assert len(expr.args) == 6
assert expr.args[1] is strip_symbolic(_) # original dispatch obj
assert expr.args[2] is np.add # ufunc object
assert expr.args[3] == "__call__" # its method to use
assert expr.args[4] is strip_symbolic(_) # lhs input
assert expr.args[5] == 1 # rhs input


def test_siu_symbolic_np_array_ufunc_inputs_lhs(_):
lhs = np.array([1,2])
rhs = np.array([3,4])
res = lhs + rhs

# symbol on lhs ----

sym = np.add(_, rhs)
expr = strip_symbolic(sym)

assert np.array_equal(expr(lhs), res)


def test_siu_symbolic_np_array_ufunc_inputs_rhs(_):
lhs = np.array([1,2])
rhs = np.array([3,4])
res = lhs + rhs

# symbol on rhs ----

sym2 = np.add(lhs, _)
expr2 = strip_symbolic(sym2)

assert np.array_equal(expr2(rhs), res)


@pytest.mark.xfail
def test_siu_symbolic_np_array_function(_):
# Note that np.sum is not a ufunc, but sort of reduces on a ufunc under the
# hood, so fails when called on a symbol
sym = np.sum(_)
expr = strip_symbolic(sym)

assert expr(np.array([1,2])) == 3


@pytest.mark.parametrize("func", [
np.absolute, # a ufunc
np.sum # dispatched by __array_function__
])
def test_siu_symbolic_array_ufunc_sql_raises(_, func):
from siuba.sql.utils import mock_sqlalchemy_engine
from siuba.sql import LazyTbl
from siuba.sql import SqlFunctionLookupError

lazy_tbl = LazyTbl(mock_sqlalchemy_engine("postgresql"), "somedata", ["x", "y"])
with pytest.raises(SqlFunctionLookupError) as exc_info:
lazy_tbl.shape_call(strip_symbolic(func(_.x)))

assert "Numpy function sql translation" in exc_info.value.args[0]
assert "not supported" in exc_info.value.args[0]

def test_siu_symbolic_array_ufunc_pandas(_):
import pandas as pd
lhs = pd.Series([1,2])

sym = np.add(_, 1)
expr = strip_symbolic(sym)

src = expr(lhs)
assert isinstance(src, pd.Series)
assert src.equals(lhs + 1)

0 comments on commit 9f83136

Please sign in to comment.