Skip to content

Commit

Permalink
Merge pull request #609 from zen-xu/return-annotation
Browse files Browse the repository at this point in the history
typing: optimize type annotations
  • Loading branch information
wangxiaoying authored Apr 18, 2024
2 parents 59bb016 + 677faf1 commit 8dc238d
Showing 1 changed file with 110 additions and 13 deletions.
123 changes: 110 additions & 13 deletions connectorx-python/connectorx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any
from typing import Any, Literal, TYPE_CHECKING, overload

from importlib.metadata import version

Expand All @@ -11,6 +11,13 @@
get_meta as _get_meta,
)

if TYPE_CHECKING:
import pandas as pd
import polars as pl
import modin.pandas as mpd
import dask.dataframe as dd
import pyarrow as pa

__version__ = version(__name__)

import os
Expand All @@ -27,8 +34,10 @@
"CX_REWRITER_PATH", os.path.join(dir_path, "dependencies/federated-rewriter.jar")
)

Protocol = Literal["csv", "binary", "cursor", "simple", "text"]

def rewrite_conn(conn: str, protocol: str | None = None):

def rewrite_conn(conn: str, protocol: Protocol | None = None) -> tuple[str, Protocol]:
if not protocol:
# note: redshift/clickhouse are not compatible with the 'binary' protocol, and use other database
# drivers to connect. set a compatible protocol and masquerade as the appropriate backend.
Expand All @@ -47,8 +56,8 @@ def rewrite_conn(conn: str, protocol: str | None = None):
def get_meta(
conn: str,
query: str,
protocol: str | None = None,
):
protocol: Protocol | None = None,
) -> pd.DataFrame:
"""
Get metadata (header) of the given query (only for pandas)
Expand All @@ -75,7 +84,7 @@ def partition_sql(
partition_on: str,
partition_num: int,
partition_range: tuple[int, int] | None = None,
):
) -> list[str]:
"""
Partition the sql query
Expand Down Expand Up @@ -106,11 +115,11 @@ def read_sql_pandas(
sql: list[str] | str,
con: str | dict[str, str],
index_col: str | None = None,
protocol: str | None = None,
protocol: Protocol | None = None,
partition_on: str | None = None,
partition_range: tuple[int, int] | None = None,
partition_num: int | None = None,
):
) -> pd.DataFrame:
"""
Run the SQL query, download the data from database into a dataframe.
First several parameters are in the same name and order with `pandas.read_sql`.
Expand Down Expand Up @@ -142,17 +151,103 @@ def read_sql_pandas(
)


# default return pd.DataFrame
@overload
def read_sql(
conn: str | dict[str, str],
query: list[str] | str,
*,
return_type: str = "pandas",
protocol: str | None = None,
protocol: Protocol | None = None,
partition_on: str | None = None,
partition_range: tuple[int, int] | None = None,
partition_num: int | None = None,
index_col: str | None = None,
):
) -> pd.DataFrame: ...


@overload
def read_sql(
conn: str | dict[str, str],
query: list[str] | str,
*,
return_type: Literal["pandas"],
protocol: Protocol | None = None,
partition_on: str | None = None,
partition_range: tuple[int, int] | None = None,
partition_num: int | None = None,
index_col: str | None = None,
) -> pd.DataFrame: ...


@overload
def read_sql(
conn: str | dict[str, str],
query: list[str] | str,
*,
return_type: Literal["arrow", "arrow2"],
protocol: Protocol | None = None,
partition_on: str | None = None,
partition_range: tuple[int, int] | None = None,
partition_num: int | None = None,
index_col: str | None = None,
) -> pa.Table: ...


@overload
def read_sql(
conn: str | dict[str, str],
query: list[str] | str,
*,
return_type: Literal["modin"],
protocol: Protocol | None = None,
partition_on: str | None = None,
partition_range: tuple[int, int] | None = None,
partition_num: int | None = None,
index_col: str | None = None,
) -> mpd.DataFrame: ...


@overload
def read_sql(
conn: str | dict[str, str],
query: list[str] | str,
*,
return_type: Literal["dask"],
protocol: Protocol | None = None,
partition_on: str | None = None,
partition_range: tuple[int, int] | None = None,
partition_num: int | None = None,
index_col: str | None = None,
) -> dd.DataFrame: ...


@overload
def read_sql(
conn: str | dict[str, str],
query: list[str] | str,
*,
return_type: Literal["polars", "polars2"],
protocol: Protocol | None = None,
partition_on: str | None = None,
partition_range: tuple[int, int] | None = None,
partition_num: int | None = None,
index_col: str | None = None,
) -> pl.DataFrame: ...


def read_sql(
conn: str | dict[str, str],
query: list[str] | str,
*,
return_type: Literal[
"pandas", "polars", "polars2", "arrow", "arrow2", "modin", "dask"
] = "pandas",
protocol: Protocol | None = None,
partition_on: str | None = None,
partition_range: tuple[int, int] | None = None,
partition_num: int | None = None,
index_col: str | None = None,
) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table:
"""
Run the SQL query, download the data from database into a dataframe.
Expand Down Expand Up @@ -318,7 +413,9 @@ def read_sql(
return df


def reconstruct_arrow(result: tuple[list[str], list[list[tuple[int, int]]]]):
def reconstruct_arrow(
result: tuple[list[str], list[list[tuple[int, int]]]],
) -> pa.Table:
import pyarrow as pa

names, ptrs = result
Expand All @@ -334,7 +431,7 @@ def reconstruct_arrow(result: tuple[list[str], list[list[tuple[int, int]]]]):
return pa.Table.from_batches(rbs)


def reconstruct_pandas(df_infos: dict[str, Any]):
def reconstruct_pandas(df_infos: dict[str, Any]) -> pd.DataFrame:
import pandas as pd

data = df_infos["data"]
Expand Down Expand Up @@ -388,6 +485,6 @@ def remove_ending_semicolon(query: str) -> str:
SQL query
"""
if query.endswith(';'):
if query.endswith(";"):
query = query[:-1]
return query

0 comments on commit 8dc238d

Please sign in to comment.