fix mypy

Paul-B98 committed Oct 22, 2024
1 parent 109b0d8 commit 9fd2b45
Showing 10 changed files with 73 additions and 53 deletions.
19 changes: 11 additions & 8 deletions icu_pipeline/concept.py
@@ -1,10 +1,13 @@
-import pandas as pd
-from .job import Job
-from icu_pipeline.source import DataSource, SourceConfig
-from icu_pipeline.source import AbstractSourceMapper, getDataSourceMapper
-from icu_pipeline.unit import BaseConverter, ConverterConfig
+from typing import Any
 
-from conceptbase.config import ConceptConfig, ConceptCoding
+from pandera.typing import DataFrame
+
+from conceptbase.config import ConceptCoding, ConceptConfig
 from icu_pipeline.graph import Node
+from icu_pipeline.source import AbstractSourceMapper, DataSource, SourceConfig, getDataSourceMapper
+from icu_pipeline.unit import BaseConverter, ConverterConfig
+
+from .job import Job
 
 
 class Concept(Node):
@@ -74,11 +77,11 @@ def __eq__(self, value: object) -> bool:
             return value == self._concept_config.name
         return super().__eq__(value)
 
-    def fetch_sources(self, job: Job, *args, **kwargs):
+    def fetch_sources(self, job: Job, *args: list[Any], **kwargs: dict[Any, Any]) -> dict[str, DataFrame]:
         # Don't do anything
         pass
 
-    def get_data(self, job) -> dict[str, pd.DataFrame]:
+    def get_data(self, job: Job, *args: list[Any], **kwargs: dict[Any, Any]) -> DataFrame:
         """Map the concept to data from the sources."""
         assert (
             job.database in self._data_sources
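Nearly every hunk in this commit follows the same recipe: previously untyped signatures gain explicit parameter and return annotations so that mypy, under strict settings such as `disallow_untyped_defs`, accepts them. A minimal standalone sketch of the before/after shape — the `Job` stub below is a hypothetical stand-in for `icu_pipeline.job.Job`:

```python
from typing import Any

from pandera.typing import DataFrame


class Job:
    """Hypothetical stand-in for icu_pipeline.job.Job."""

    database: str = "mimic"


def fetch_sources(job: Job, *args: list[Any], **kwargs: dict[Any, Any]) -> dict[str, DataFrame]:
    # With every parameter and the return type annotated, mypy's
    # disallow_untyped_defs check is satisfied; the old form
    # `def fetch_sources(self, job, *args, **kwargs):` is reported as untyped.
    return {}
```

Note that mypy reads `*args: list[Any]` as "each extra positional argument is a `list[Any]`" (the variadic tuple becomes `tuple[list[Any], ...]`); the looser catch-all spelling would be `*args: Any`. The commit standardizes on the `list[Any]`/`dict[Any, Any]` form throughout.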
4 changes: 2 additions & 2 deletions icu_pipeline/graph/__init__.py
@@ -13,8 +13,8 @@ class GraphType(StrEnum):
             from icu_pipeline.graph.in_memory import InMemoryNode as Node
             from icu_pipeline.graph.in_memory import InMemoryPipe as Pipe
         case GraphType.Multiprocessing:
-            from icu_pipeline.graph.parallel import MultiprocessingNode as Node
-            from icu_pipeline.graph.parallel import MultiprocessingPipe as Pipe
+            from icu_pipeline.graph.parallel import MultiprocessingNode as Node  # type: ignore[assignment]
+            from icu_pipeline.graph.parallel import MultiprocessingPipe as Pipe  # type: ignore[assignment]
         case _:
             raise EnvironmentError(f"Unknown GraphType '{t}'. Available Modules: {[tt.value for tt in GraphType]}")

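The `# type: ignore[assignment]` comments are needed because both `match` arms bind the same names: mypy types `Node` and `Pipe` from the first binding it sees and treats the later binding to a different class as an incompatible assignment, even though only one arm ever executes. A self-contained sketch of the same error class, with hypothetical stand-in classes:

```python
class InMemoryNode:
    """Hypothetical stand-in for icu_pipeline.graph.in_memory.InMemoryNode."""


class MultiprocessingNode:
    """Hypothetical stand-in for icu_pipeline.graph.parallel.MultiprocessingNode."""


use_parallel = False

if not use_parallel:
    Node = InMemoryNode  # mypy pins Node to type[InMemoryNode] from this first binding
else:
    # Without the ignore, mypy reports an [assignment] error: the expression has
    # type "type[MultiprocessingNode]" while Node already has type "type[InMemoryNode]".
    Node = MultiprocessingNode  # type: ignore[assignment]
```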
35 changes: 21 additions & 14 deletions icu_pipeline/graph/base.py
@@ -1,10 +1,16 @@
-import pandas as pd
+from typing import TYPE_CHECKING, Any
+
+from pandera.typing import DataFrame
 
 from icu_pipeline.job import Job
 
+if TYPE_CHECKING:
+    from icu_pipeline.concept import Concept
+
 
 class BaseNode:
     ID = 0
-    REQUIRED_CONCEPTS = []
+    REQUIRED_CONCEPTS: list["Concept"] = []
 
     def __init__(self, concept_id: str) -> None:
         self._node_id = BaseNode.ID
@@ -24,10 +30,10 @@ def __eq__(self, value: object) -> bool:
     def __str__(self) -> str:
         return f"{type(self).__name__}({self._node_id})"
 
-    def fetch_sources(self, job: Job, *args, **kwargs) -> dict[str, pd.DataFrame]:
+    def fetch_sources(self, job: Job, *args: list[Any], **kwargs: dict[Any, Any]) -> dict[str, DataFrame]:
         raise NotImplementedError
 
-    def get_data(self, job: Job, *args, **kwargs) -> pd.DataFrame:
+    def get_data(self, job: Job, *args: list[Any], **kwargs: dict[Any, Any]) -> DataFrame:
         raise NotImplementedError


@@ -36,10 +42,10 @@ def __init__(self, source: BaseNode, sink: BaseNode) -> None:
         self._source = source
         self._sink = sink
 
-    def write(self, job: Job, data: pd.DataFrame, *args, **kwargs) -> None:
+    def write(self, job: Job, data: DataFrame, *args: list[Any], **kwargs: dict[Any, Any]) -> None:
         raise NotImplementedError
 
-    def read(self, job: Job, *args, **kwargs) -> pd.DataFrame:
+    def read(self, job: Job, *args: list[Any], **kwargs: dict[Any, Any]) -> DataFrame:
         raise NotImplementedError


@@ -49,15 +55,15 @@ def __init__(self) -> None:
         self._edges: list[BasePipe] = []
 
     @property
-    def sources(self):
+    def sources(self) -> list[BaseNode]:
         out = []
         for n in self._nodes:
             if len(n._sources) == 0 and n not in out:
                 out.append(n)
         return out
 
     @property
-    def sinks(self):
+    def sinks(self) -> list[BaseNode]:
         out = []
         for n in self._nodes:
             if len(n._sinks) == 0 and n not in out:
@@ -70,7 +76,7 @@ def getNode(self, node: BaseNode | str) -> BaseNode | None:
                 return n
         return None
 
-    def addPipe(self, source: BaseNode, sink: BaseNode):
+    def addPipe(self, source: BaseNode, sink: BaseNode) -> None:
         from icu_pipeline.graph import Pipe
 
         # Check if the Nodes are already part of the Graph
@@ -79,7 +85,7 @@ def addPipe(self, source: BaseNode, sink: BaseNode):
         if sink not in self._nodes:
             self._nodes.append(sink)
 
-        new_pipe = Pipe(source, sink)
+        new_pipe = Pipe(source, sink)  # type: ignore[arg-type]
         # Append the Pipe to the Graph
         self._edges.append(new_pipe)
         # Append the Pipe to the Nodes
@@ -88,7 +94,7 @@ def addPipe(self, source: BaseNode, sink: BaseNode):
 
         self.check_circularity()
 
-    def check_circularity(self):
+    def check_circularity(self) -> bool:
         """Only works for directed acyclic graphs (DAGs).
         Return 'True' if the check is passed"""
         # Assume everythings fine
@@ -99,7 +105,7 @@ def check_circularity(self):
         if len(self.sources) == 0 and len(self._nodes) > 0:
             result = False
 
-        def _is_circular(node: BaseNode, nodes: list[BaseNode]):
+        def _is_circular(node: BaseNode, nodes: list[BaseNode]) -> bool:
             # Check if I'm part of the visited Nodes
             if node in nodes:
                 return True
@@ -117,12 +123,13 @@ def _is_circular(node: BaseNode, nodes: list[BaseNode]):
             result = False
 
         assert result, "Graph has Circular dependencies!"
+        return result
 
-    def check_hanging_leaves(self):
+    def check_hanging_leaves(self) -> bool:
         """Are there any Nodes that are not connected to the sink?"""
         raise NotImplementedError
 
-    def check_missing_connections(self):
+    def check_missing_connections(self) -> bool:
         """Are there any dependencies not met?"""
         raise NotImplementedError

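The `TYPE_CHECKING` guard added at the top of `base.py` is the standard way to annotate with a type whose import would be circular at runtime: `concept.py` imports `Node` from the graph package, and the `REQUIRED_CONCEPTS` annotation here needs `Concept` in the other direction. The guarded import is only evaluated by the type checker, and the annotation is written as a string so nothing is resolved at import time. A minimal sketch of the pattern:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by mypy, never at runtime, so the runtime
    # import cycle between base.py and concept.py is avoided.
    from icu_pipeline.concept import Concept


class BaseNode:
    # Forward reference as a string; mypy resolves it through the guarded import.
    REQUIRED_CONCEPTS: list["Concept"] = []
```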
13 changes: 8 additions & 5 deletions icu_pipeline/graph/in_memory.py
@@ -1,10 +1,13 @@
-from pandas.core.api import DataFrame as DataFrame
-from icu_pipeline.graph.base import BasePipe, BaseNode
+from typing import Any
+
+from pandera.typing import DataFrame
+
+from icu_pipeline.graph.base import BaseNode, BasePipe
 from icu_pipeline.job import Job
 
 
 class InMemoryNode(BaseNode):
-    def fetch_sources(self, job: Job, *args, **kwargs) -> dict[str, DataFrame]:
+    def fetch_sources(self, job: Job, *args: list[Any], **kwargs: dict[Any, Any]) -> dict[str, DataFrame]:
         out: dict[str, DataFrame] = {}
         for c, s in self._sources.items():
             out[c] = s.read(job, *args, **kwargs)
@@ -15,10 +18,10 @@ class InMemoryPipe(BasePipe):
     def __init__(self, source: InMemoryNode, sink: InMemoryNode) -> None:
         super().__init__(source, sink)
 
-    def read(self, job: Job, *args, **kwargs):
+    def read(self, job: Job, *args: list[Any], **kwargs: dict[Any, Any]) -> DataFrame:
         # Forward the get_data method
         return self._source.get_data(job)
 
-    def write(self, job: Job, data, *args, **kwargs):
+    def write(self, job: Job, data: DataFrame, *args: list[Any], **kwargs: dict[Any, Any]) -> None:
         # Nothing special
         return data
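Both graph backends now annotate with `pandera.typing.DataFrame` instead of the pandas class. Used bare, as in this commit, it behaves as a typed alias for `pd.DataFrame`, but it can also be parametrized with a schema, which is presumably why it was chosen. A hedged sketch of that option — `ObservationSchema` is hypothetical, not a schema from this repository:

```python
import pandas as pd
import pandera as pa
from pandera.typing import DataFrame, Series


class ObservationSchema(pa.DataFrameModel):
    """Hypothetical schema; the pipeline's real FHIR schemas live elsewhere."""

    subject_id: Series[int]
    value: Series[float]


@pa.check_types  # validates the returned frame against the schema at runtime
def load_observations() -> DataFrame[ObservationSchema]:
    return pd.DataFrame({"subject_id": [1, 2], "value": [36.6, 37.2]})
```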
9 changes: 5 additions & 4 deletions icu_pipeline/graph/parallel.py
@@ -1,4 +1,5 @@
 import multiprocessing
+from typing import Any
 
 from pandera.typing import DataFrame

@@ -10,7 +11,7 @@
 
 
 class MultiprocessingNode(BaseNode):
-    def fetch_sources(self, job: Job, *args, **kwargs) -> dict[str, DataFrame]:
+    def fetch_sources(self, job: Job, *args: list[Any], **kwargs: dict[Any, Any]) -> dict[str, DataFrame]:
         manager = multiprocessing.Manager()
         out = manager.dict()

@@ -20,7 +21,7 @@ def fetch_sources(self, job: Job, *args, **kwargs) -> dict[str, DataFrame]:
             p.join()
         return out
 
-    def get_data(self, job: Job) -> DataFrame:
+    def get_data(self, job: Job, *args: list[Any], **kwargs: dict[Any, Any]) -> DataFrame:
         data = self.fetch_sources(job)
         if self._concept_id is not None:
             data = data[self._concept_id]
@@ -31,7 +32,7 @@ class MultiprocessingPipe(BasePipe):
     def __init__(self, source: MultiprocessingNode, sink: MultiprocessingNode) -> None:
         super().__init__(source, sink)
 
-    def read(self, job: Job, managed_dict, *args, **kwargs):
+    def read(self, job: Job, managed_dict, *args: list[Any], **kwargs: dict[Any, Any]) -> DataFrame:
         def _read(result: dict):
             df = self._source.get_data(job)
             result[self._source._concept_id] = df
@@ -40,6 +41,6 @@ def _read(result: dict):
         p.start()
         return p
 
-    def write(self, job: Job, data, *args, **kwargs):
+    def write(self, job: Job, data: DataFrame, *args: list[Any], **kwargs: dict[Any, Any]) -> None:
         # Nothing special
         return data
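`fetch_sources` collects results through a `multiprocessing.Manager().dict()` rather than a plain dict: each `Process` works on its own copy of ordinary objects, so writes made in a child would be invisible to the parent. The manager proxy forwards writes to a broker process, which is what makes the pattern in `MultiprocessingNode` work. A minimal self-contained sketch:

```python
import multiprocessing


def _worker(result: dict, key: str) -> None:
    # The proxy forwards this write to the manager process,
    # so it survives after the child exits.
    result[key] = 42


if __name__ == "__main__":
    manager = multiprocessing.Manager()
    out = manager.dict()  # a plain {} would be copied per child and the write lost

    p = multiprocessing.Process(target=_worker, args=(out, "heart_rate"))
    p.start()
    p.join()
    print(dict(out))  # {'heart_rate': 42}
```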
10 changes: 7 additions & 3 deletions icu_pipeline/sink/pandas.py
@@ -1,12 +1,16 @@
-from icu_pipeline.job import Job
+from typing import Any
+
+from pandera.typing import DataFrame
+
 from icu_pipeline.graph import Node
+from icu_pipeline.job import Job
 
 
 class PandasSink(Node):
     def __init__(self) -> None:
-        super().__init__(None)
+        super().__init__("concept_id")
 
-    def get_data(self, job: Job, *args, **kwargs):
+    def get_data(self, job: Job, *args: list[Any], **kwargs: dict[Any, Any]) -> DataFrame:
         data = self.fetch_sources(job, *args, **kwargs)
         # TODO - Merge the Data?
         return data
3 changes: 2 additions & 1 deletion icu_pipeline/source/__init__.py
@@ -7,6 +7,7 @@
 from pandera.typing import DataFrame, Series
 
 from conceptbase.config import MapperConfig
+from icu_pipeline.job import Job
 from icu_pipeline.logger import ICULogger
 from icu_pipeline.schema.fhir import AbstractFHIRSinkSchema

@@ -87,7 +88,7 @@ def __init__(
         self._source_config = source_config
 
     @abstractmethod
-    def get_data(self, job) -> DataFrame:
+    def get_data(self, job: Job) -> DataFrame:
         """
         Retrieves the data to be mapped.
13 changes: 6 additions & 7 deletions icu_pipeline/source/database/mapper.py
@@ -1,16 +1,15 @@
-from typing import Any, TypeVar, Generic
+from typing import Any, Generic, TypeVar
 
-import pandas as pd
-from sqlalchemy import Engine, create_engine
+from pandas import DataFrame
 from psycopg import sql
 from psycopg.sql import Composable
+from sqlalchemy import Engine, create_engine
 
-from icu_pipeline.source import DataSource, SourceConfig
-from icu_pipeline.schema.fhir import AbstractFHIRSinkSchema
-from icu_pipeline.source import AbstractSourceMapper
-from icu_pipeline.job import Job
-
+from icu_pipeline.job import Job
 from icu_pipeline.logger import ICULogger
+from icu_pipeline.schema.fhir import AbstractFHIRSinkSchema
+from icu_pipeline.source import AbstractSourceMapper, DataSource, SourceConfig
 
 logger = ICULogger.get_logger()

@@ -171,7 +170,7 @@ def _build_join_identifier(identifier: str) -> Composable:
 
         return query
 
-    def get_data(self, job: Job) -> pd.DataFrame:
+    def get_data(self, job: Job) -> DataFrame:
         """
         Retrieves data from the database.
3 changes: 2 additions & 1 deletion icu_pipeline/source/database/sampler.py
@@ -32,6 +32,7 @@ class AbstractDatabaseSourceSampler(AbstractSourceSampler):
     Retrieves data from the database. This method should be implemented by subclasses.
     """
 
+    IDENTIFIER: list[str]  # the identifier columns for the table
     SQL_QUERY: str | Composable  # the SQL query to be executed
 
     def __init__(self, source_config: SourceConfig) -> None:
@@ -66,7 +67,7 @@ def build_query(
             FROM {schema}.{table}
             LIMIT {limit}
         """
 
         query = sql.SQL(raw_query).format(
             fields=sql.SQL(", ").join([sql.SQL(i) for i in self.IDENTIFIER]),
             schema=sql.Identifier(schema),
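`build_query` composes SQL with `psycopg.sql` so identifiers are quoted safely instead of being interpolated as strings, while trusted fragments (the `IDENTIFIER` column list, a class constant) pass through `sql.SQL`. A sketch of the same pattern with hypothetical schema/table names; whether the real code passes `limit` as a `sql.Literal` or some other composable is an assumption here:

```python
from psycopg import sql

raw_query = """
    SELECT {fields}
    FROM {schema}.{table}
    LIMIT {limit}
"""

query = sql.SQL(raw_query).format(
    # IDENTIFIER entries are trusted class constants, hence sql.SQL(...)
    fields=sql.SQL(", ").join([sql.SQL(c) for c in ["subject_id", "hadm_id"]]),
    schema=sql.Identifier("mimiciv_hosp"),  # hypothetical schema name
    table=sql.Identifier("admissions"),     # hypothetical table name
    limit=sql.Literal(100),                 # assumption: limit as a literal
)

# `query` is a sql.Composed; a cursor executes it directly:
# with psycopg.connect(dsn) as conn, conn.cursor() as cur:
#     cur.execute(query)
```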
17 changes: 9 additions & 8 deletions tools/pipeline.ipynb
@@ -2,14 +2,15 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [
     {
-     "name": "stderr",
+     "name": "stdout",
      "output_type": "stream",
      "text": [
-      "\u001b[32mINFO:root:New Logger Created\u001b[0m\n"
+      "The autoreload extension is already loaded. To reload it, use:\n",
+      "  %reload_ext autoreload\n"
      ]
     }
    ],
@@ -42,7 +43,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -75,7 +76,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 12,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -94,7 +95,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 13,
    "metadata": {},
    "outputs": [
     {
@@ -117,7 +118,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 14,
    "metadata": {},
    "outputs": [
     {
@@ -301,7 +302,7 @@
     "[6819 rows x 4 columns]"
    ]
   },
-  "execution_count": 5,
+  "execution_count": 14,
   "metadata": {},
   "output_type": "execute_result"
  }
