Skip to content

Commit

Permalink
Fixed initial review findings
Browse files Browse the repository at this point in the history
  • Loading branch information
ckunki committed Nov 15, 2024
1 parent bb9d341 commit 1e99613
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 87 deletions.
Original file line number Diff line number Diff line change
@@ -1,68 +1,4 @@
from typing import Set, List

from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext

from exasol_machine_learning_library.execution.execution_graph import ExecutionGraph
from exasol_machine_learning_library.execution.sql_stage_graph_execution.sql_stage_input_output import \
SQLStageInputOutput
from exasol_machine_learning_library.execution.stage_graph.stage import UDFStage, SQLStage
from exasol_machine_learning_library.execution.stage_graph.sql_stage_train_query_handler import \
SQLStageTrainQueryHandler
from exasol_machine_learning_library.execution.stage_graph.sql_stage import SQLStage

UDFStageGraph = ExecutionGraph[UDFStage]
SQLStageGraph = ExecutionGraph[SQLStage]


class UDFRunnerPlaceholderSQLStage(SQLStage):
def __init__(self, udf_stages_component: Set[UDFStage]):
self._udf_stage_component = udf_stages_component

@property
def conntected_udf_stages(self) -> Set[UDFStage]:
return set(self._udf_stage_component)

def __repr__(self) -> str:
return f"{self.__class__.__name__}@{id(self)}"

def create_train_query_handler(self, stage_inputs: List[SQLStageInputOutput],
query_handler_context: ScopeQueryHandlerContext) -> SQLStageTrainQueryHandler:
raise NotImplemented("The method create_train_query_handler should never be called on this stage, "
"because it is a placeholder for graph rewriting.")


class UDFRunnerSQLStage(SQLStage):
def __init__(self, udf_stage_graph: UDFStageGraph):
self._udf_stage_graph = udf_stage_graph

@property
def udf_stage_graph(self) -> UDFStageGraph:
return self._udf_stage_graph

def __repr__(self) -> str:
return f"{self.__class__.__name__}@{id(self)}"

def create_train_query_handler(self, stage_inputs: List[SQLStageInputOutput],
query_handler_context: ScopeQueryHandlerContext) -> SQLStageTrainQueryHandler:
raise NotImplemented("The create_train_query_handler method needs a implementation.")


class ColumnSelectorPlaceholderUDFStage(UDFStage):
def __init__(self, input_sql_stage: List[SQLStage]):
self._input_sql_stages = input_sql_stage

@property
def input_sql_stages(self) -> List[SQLStage]:
return self._input_sql_stages

def __repr__(self) -> str:
return f"{self.__class__.__name__}@{id(self)}"


class SourceUDFStage(UDFStage):
def __repr__(self) -> str:
return f"{self.__class__.__name__}@{id(self)}"


class SinkUDFStage(UDFStage):
def __repr__(self) -> str:
return f"{self.__class__.__name__}@{id(self)}"
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
SQLStageInputOutput
from exasol_machine_learning_library.execution.stage_graph.sql_stage_train_query_handler import \
SQLStageTrainQueryHandlerInput
from exasol_machine_learning_library.execution.stage_graph.stage import SQLStage
from exasol_machine_learning_library.execution.stage_graph.sql_stage import SQLStage


class ResultHandlerReturnValue(enum.Enum):
Expand Down
15 changes: 15 additions & 0 deletions exasol_machine_learning_library/execution/stage_graph/sql_stage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import abc

from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext
from exasol_machine_learning_library.execution.stage_graph.stage import Stage
from exasol_machine_learning_library.execution.stage_graph.sql_stage_train_query_handler import \
SQLStageTrainQueryHandler, SQLStageTrainQueryHandlerInput

class SQLStage(Stage):
@abc.abstractmethod
def create_train_query_handler(
self,
stage_input: SQLStageTrainQueryHandlerInput,
query_handler_context: ScopeQueryHandlerContext,
) -> SQLStageTrainQueryHandler:
pass
18 changes: 0 additions & 18 deletions exasol_machine_learning_library/execution/stage_graph/stage.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
import abc
from abc import ABC
from typing import Generic, TypeVar

from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext

from exasol_machine_learning_library.execution.stage_graph.sql_stage_train_query_handler import \
SQLStageTrainQueryHandler, SQLStageTrainQueryHandlerInput
from exasol_machine_learning_library.execution.trainable_estimators import Parameter, Result

ParameterType = TypeVar("ParameterType", bound=Parameter)
Expand All @@ -14,16 +9,3 @@

class Stage(ABC, Generic[ParameterType, ResultType]):
pass


class SQLStage(Stage):
@abc.abstractmethod
def create_train_query_handler(self,
stage_input: SQLStageTrainQueryHandlerInput,
query_handler_context: ScopeQueryHandlerContext) \
-> SQLStageTrainQueryHandler:
pass


class UDFStage(Stage):
pass
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from exasol_machine_learning_library.execution.execution_graph import ExecutionGraph
from exasol_machine_learning_library.execution.stage_graph.stage import SQLStage, UDFStage, Stage
from exasol_machine_learning_library.execution.stage_graph.stage import Stage

StageGraph = ExecutionGraph[Stage]
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
SQLStageGraphExecutionQueryHandlerState
from exasol_machine_learning_library.execution.sql_stage_graph_execution.sql_stage_input_output import \
SQLStageInputOutput
from exasol_machine_learning_library.execution.stage_graph.stage import SQLStage
from exasol_machine_learning_library.execution.stage_graph.sql_stage import SQLStage
from exasol_machine_learning_library.execution.stage_graph.sql_stage_train_query_handler import \
SQLStageTrainQueryHandler
from tests.unit_tests.sql_stage_graph.mock_cast import mock_cast
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
SQLStageInputOutput
from exasol_machine_learning_library.execution.stage_graph.sql_stage_train_query_handler import \
SQLStageTrainQueryHandler, SQLStageTrainQueryHandlerInput
from exasol_machine_learning_library.execution.stage_graph.stage import SQLStage
from exasol_machine_learning_library.execution.stage_graph.sql_stage import SQLStage


class StartOnlyForwardInputTestSQLStageTrainQueryHandler(SQLStageTrainQueryHandler):
Expand Down

0 comments on commit 1e99613

Please sign in to comment.