Skip to content

Commit

Permalink
Refactoring following a review
Browse files Browse the repository at this point in the history
  • Loading branch information
ahsimb committed Oct 18, 2023
1 parent 7c97b03 commit aa3f2ed
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 65 deletions.
104 changes: 66 additions & 38 deletions exasol_udf_mock_python/mock_context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Tuple, Iterator, Iterable, Any, Optional, Union
from functools import wraps

import pandas as pd

Expand All @@ -14,8 +15,11 @@ class MockContext(UDFContext):
This class allows iterating over groups. The functionality of the UDF Context are applicable
for the current input group.
Call `_next_group` to iterate over groups. The `_output_groups` property provides the emit
Call `next_group` to iterate over groups. The `output_groups` property provides the emit
output for all groups iterated so far including the output for the current group.
Calling any function of the UDFContext interface when the group iterator has passed the end
or before the first call to the `next_group` is illegal and will cause a RuntimeException.
"""

def __init__(self, input_groups: Iterator[Group], metadata: MockMetaData):
Expand All @@ -28,20 +32,20 @@ def __init__(self, input_groups: Iterator[Group], metadata: MockMetaData):
self._input_groups = input_groups
self._metadata = metadata
""" Mock context for the current group """
self._current:Optional[StandaloneMockContext] = None
self._current_context: Optional[StandaloneMockContext] = None
""" Output for all groups """
self._previous_groups: List[Group] = []
self._previous_output: List[Group] = []

def _next_group(self) -> bool:
def next_group(self) -> bool:
"""
Moves group iterator to the next group.
Returns False if the iterator gets beyond the last group. Returns True otherwise.
"""

# Save output of the current group
if self._current is not None:
self._previous_groups.append(Group(self._current.output))
self._current = None
if self._current_context is not None:
self._previous_output.append(Group(self._current_context.output))
self._current_context = None

# Try get to the next input group
try:
Expand All @@ -52,48 +56,83 @@ def _next_group(self) -> bool:
raise RuntimeError("Empty input groups are not allowed")

# Create Mock Context for the new input group
self._current = StandaloneMockContext(input_group, self._metadata)
self._current_context = StandaloneMockContext(input_group, self._metadata)
return True

@property
def _output_groups(self):
def output_groups(self):
"""
Output of all groups including the current one.
"""
if self._current is None:
return self._previous_groups
if self._current_context is None:
return self._previous_output
else:
groups = list(self._previous_groups)
groups.append(Group(self._current.output))
groups = list(self._previous_output)
groups.append(Group(self._current_context.output))
return groups

@staticmethod
def _check_context(f):
@wraps(f)
def wrapper(self, *args, **kwargs):
if self._current_context is None:
raise RuntimeError('Calling UDFContext interface when the current group context '
'is invalid is disallowed')
return f(self, *args, **kwargs)

return wrapper

@_check_context
def __getattr__(self, name):
return None if self._current is None else getattr(self._current, name)
return getattr(self._current_context, name)

@_check_context
def get_dataframe(self, num_rows: Union[str, int], start_col: int = 0) -> Optional[pd.DataFrame]:
return None if self._current is None else self._current.get_dataframe(num_rows, start_col)
return self._current_context.get_dataframe(num_rows, start_col)

@_check_context
def next(self, reset: bool = False) -> bool:
return False if self._current is None else self._current.next(reset)
return self._current_context.next(reset)

@_check_context
def size(self) -> int:
return 0 if self._current is None else self._current.size()
return self._current_context.size()

@_check_context
def reset(self) -> None:
if self._current is not None:
self._current.reset()
self._current_context.reset()

def emit(self, *args):
if self._current is not None:
self._current.emit(*args)
@_check_context
def emit(self, *args) -> None:
self._current_context.emit(*args)


def get_scalar_input(inp: Any) -> Iterable[Tuple[Any, ...]]:
"""
Figures out if the SCALAR parameters are provided as a scalar value or a tuple
and also if there is a wrapping container around.
Unless the parameters are already in a wrapping container returns parameters as a tuple wrapped
into a one-item list, e.g [(param1[, param2, ...)]. Otherwise, returns the original input.
:param inp: Input parameters.
"""

if isinstance(inp, Iterable) and not isinstance(inp, str):
row1 = next(iter(inp))
if isinstance(row1, Iterable) and not isinstance(row1, str):
return inp
else:
return [inp]
else:
return [(inp,)]


class StandaloneMockContext(UDFContext):
"""
Implementation of generic UDF Mock Context interface a SCALAR UDF or a SET UDF with no groups.
For Emit UDFs the output in the form of the list of tuples can be
access by reading the `output` property.
accessed by reading the `output` property.
"""

def __init__(self, inp: Any, metadata: MockMetaData):
Expand All @@ -107,19 +146,8 @@ def __init__(self, inp: Any, metadata: MockMetaData):
:param metadata: The mock metadata object.
"""

if metadata.input_type.upper() == 'SCALAR':
# Figure out if the SCALAR parameters are provided as a scalar value or a tuple
# and also if there is a wrapping container around. In any case, this should be
# converted to a form [(param1[, param2, ...)]
if isinstance(inp, Iterable) and not isinstance(inp, str):
row1 = next(iter(inp))
if isinstance(row1, Iterable) and not isinstance(row1, str):
self._input = inp
else:
self._input = [inp]
else:
self._input = [(inp,)]
self._input = get_scalar_input(inp)
else:
self._input = inp
self._metadata = metadata
Expand Down Expand Up @@ -176,7 +204,7 @@ def next(self, reset: bool = False):
try:
new_data = next(self._iter)
self._data = new_data
self._validate_tuples(self._data, self._metadata.input_columns)
self.validate_emit(self._data, self._metadata.input_columns)
return True
except StopIteration as e:
self._data = None
Expand All @@ -195,11 +223,11 @@ def emit(self, *args):
else:
tuples = [args]
for row in tuples:
self._validate_tuples(row, self._metadata.output_columns)
self.validate_emit(row, self._metadata.output_columns)
self._output.extend(tuples)

@staticmethod
def _validate_tuples(row: Tuple, columns: List[Column]):
def validate_emit(row: Tuple, columns: List[Column]):
if len(row) != len(columns):
raise Exception(f"row {row} has not the same number of values as columns are defined")
for i, column in enumerate(columns):
Expand Down
4 changes: 2 additions & 2 deletions exasol_udf_mock_python/udf_mock_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


def _loop_groups(ctx:MockContext, exa:MockExaEnvironment, runfunc:Callable):
while ctx._next_group():
while ctx.next_group():
_wrapped_run(ctx, exa, runfunc)


Expand Down Expand Up @@ -77,4 +77,4 @@ def run(self,
finally:
if "cleanup" in exec_globals:
self._exec_cleanup(exec_globals)
return ctx._output_groups
return ctx.output_groups
51 changes: 34 additions & 17 deletions tests/test_mock_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,45 +15,62 @@ def context_set_emits(meta_set_emits):


def test_scroll(context_set_emits):
assert context_set_emits._current is None
assert not context_set_emits._output_groups
assert context_set_emits._next_group()
assert not context_set_emits.output_groups
assert context_set_emits.next_group()
assert context_set_emits.t2 == 'cat'
assert context_set_emits.next()
assert context_set_emits.t2 == 'dog'
assert not context_set_emits.next()
assert context_set_emits._next_group()
assert context_set_emits.next_group()
assert context_set_emits.t2 == 'ant'
assert context_set_emits.next()
assert context_set_emits.t2 == 'bee'
assert context_set_emits.next()
assert context_set_emits.t2 == 'beetle'
assert not context_set_emits.next()
assert not context_set_emits._next_group()
assert context_set_emits._current is None
assert not context_set_emits.next_group()


def test_output_groups(context_set_emits):
context_set_emits._next_group()
context_set_emits.next_group()
context_set_emits.emit(1, 'cat')
context_set_emits.emit(2, 'dog')
context_set_emits._next_group()
context_set_emits.next_group()
context_set_emits.emit(3, 'ant')
context_set_emits.emit(4, 'bee')
context_set_emits.emit(5, 'beetle')
context_set_emits._next_group()
assert len(context_set_emits._output_groups) == 2
assert context_set_emits._output_groups[0] == Group([(1, 'cat'), (2, 'dog')])
assert context_set_emits._output_groups[1] == Group([(3, 'ant'), (4, 'bee'), (5, 'beetle')])
context_set_emits.next_group()
assert len(context_set_emits.output_groups) == 2
assert context_set_emits.output_groups[0] == Group([(1, 'cat'), (2, 'dog')])
assert context_set_emits.output_groups[1] == Group([(3, 'ant'), (4, 'bee'), (5, 'beetle')])


def test_output_groups_partial(context_set_emits):
context_set_emits._next_group()
context_set_emits.next_group()
context_set_emits.emit(1, 'cat')
context_set_emits.emit(2, 'dog')
context_set_emits._next_group()
context_set_emits.next_group()
context_set_emits.emit(3, 'ant')
context_set_emits.emit(4, 'bee')
assert len(context_set_emits._output_groups) == 2
assert context_set_emits._output_groups[0] == Group([(1, 'cat'), (2, 'dog')])
assert context_set_emits._output_groups[1] == Group([(3, 'ant'), (4, 'bee')])
assert len(context_set_emits.output_groups) == 2
assert context_set_emits.output_groups[0] == Group([(1, 'cat'), (2, 'dog')])
assert context_set_emits.output_groups[1] == Group([(3, 'ant'), (4, 'bee')])


def test_no_context_exception(context_set_emits):

for _ in range(3):
context_set_emits.next_group()

with pytest.raises(RuntimeError):
_ = context_set_emits.t2
with pytest.raises(RuntimeError):
_ = context_set_emits.get_dataframe()
with pytest.raises(RuntimeError):
context_set_emits.next()
with pytest.raises(RuntimeError):
_ = context_set_emits.size()
with pytest.raises(RuntimeError):
context_set_emits.reset()
with pytest.raises(RuntimeError):
context_set_emits.emit(1, 'cat')
16 changes: 8 additions & 8 deletions tests/test_mock_context_standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ def test_next(context_set_emits):


def test_next_end(context_set_emits):
assert context_set_emits.next()
context_set_emits.next()
assert not context_set_emits.next()


def test_reset(context_set_emits):
assert context_set_emits.next()
context_set_emits.next()
context_set_emits.reset()
assert context_set_emits.t1 == 5
assert context_set_emits.t2 == 'abc'
Expand All @@ -85,17 +85,17 @@ def test_size(context_set_emits):
assert context_set_emits.size() == 2


def test_validate_tuples_good(meta_set_emits):
StandaloneMockContext._validate_tuples((10, 'fish'), meta_set_emits.output_columns)
def test_validate_emit_good(meta_set_emits):
StandaloneMockContext.validate_emit((10, 'fish'), meta_set_emits.output_columns)


def test_validate_tuples_bad(meta_set_emits):
def test_validate_emit_bad(meta_set_emits):
with pytest.raises(Exception):
StandaloneMockContext._validate_tuples((10,), meta_set_emits.output_columns)
StandaloneMockContext.validate_emit((10,), meta_set_emits.output_columns)
with pytest.raises(Exception):
StandaloneMockContext._validate_tuples((10, 'fish', 4.5), meta_set_emits.output_columns)
StandaloneMockContext.validate_emit((10, 'fish', 4.5), meta_set_emits.output_columns)
with pytest.raises(Exception):
StandaloneMockContext._validate_tuples((10., 'fish'), meta_set_emits.output_columns)
StandaloneMockContext.validate_emit((10., 'fish'), meta_set_emits.output_columns)


def test_emit_df(context_set_emits):
Expand Down

0 comments on commit aa3f2ed

Please sign in to comment.