Skip to content

Commit

Permalink
Merge pull request #610 from zen-xu/detect-pkg
Browse files Browse the repository at this point in the history
refactor: Optimize the detection of whether a package is installed.
  • Loading branch information
wangxiaoying authored Apr 18, 2024
2 parents 8a8a4c4 + 1cae814 commit 9687211
Showing 1 changed file with 16 additions and 28 deletions.
44 changes: 16 additions & 28 deletions connectorx-python/connectorx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

from typing import Any, Literal, TYPE_CHECKING, overload

import importlib
from importlib.metadata import version

from typing import Any, Literal, TYPE_CHECKING, overload

from .connectorx import (
read_sql as _read_sql,
partition_sql as _partition_sql,
Expand Down Expand Up @@ -311,10 +313,7 @@ def read_sql(
if return_type == "pandas":
df = df.to_pandas(date_as_object=False, split_blocks=False)
if return_type == "polars":
try:
import polars as pl
except ModuleNotFoundError:
raise ValueError("You need to install polars first")
pl = try_import_module("polars")

try:
# api change for polars >= 0.8.*
Expand Down Expand Up @@ -350,10 +349,7 @@ def read_sql(
conn, protocol = rewrite_conn(conn, protocol)

if return_type in {"modin", "dask", "pandas"}:
try:
import pandas
except ModuleNotFoundError:
raise ValueError("You need to install pandas first")
try_import_module("pandas")

result = _read_sql(
conn,
Expand All @@ -368,25 +364,14 @@ def read_sql(
df.set_index(index_col, inplace=True)

if return_type == "modin":
try:
import modin.pandas as mpd
except ModuleNotFoundError:
raise ValueError("You need to install modin first")

mpd = try_import_module("modin.pandas")
df = mpd.DataFrame(df)
elif return_type == "dask":
try:
import dask.dataframe as dd
except ModuleNotFoundError:
raise ValueError("You need to install dask first")

dd = try_import_module("dask.dataframe")
df = dd.from_pandas(df, npartitions=1)

elif return_type in {"arrow", "arrow2", "polars", "polars2"}:
try:
import pyarrow
except ModuleNotFoundError:
raise ValueError("You need to install pyarrow first")
try_import_module("pyarrow")

result = _read_sql(
conn,
Expand All @@ -397,11 +382,7 @@ def read_sql(
)
df = reconstruct_arrow(result)
if return_type in {"polars", "polars2"}:
try:
import polars as pl
except ModuleNotFoundError:
raise ValueError("You need to install polars first")

pl = try_import_module("polars")
try:
df = pl.DataFrame.from_arrow(df)
except AttributeError:
Expand Down Expand Up @@ -488,3 +469,10 @@ def remove_ending_semicolon(query: str) -> str:
if query.endswith(";"):
query = query[:-1]
return query


def try_import_module(name: str):
try:
return importlib.import_module(name)
except ModuleNotFoundError:
raise ValueError(f"You need to install {name.split('.')[0]} first")

0 comments on commit 9687211

Please sign in to comment.