Skip to content

Commit

Permalink
Merge pull request #185 from abstractqqq/better_numeric_profile
Browse files Browse the repository at this point in the history
  • Loading branch information
abstractqqq authored Jun 18, 2024
2 parents 4e60feb + eb5ed75 commit cccd3cb
Show file tree
Hide file tree
Showing 7 changed files with 1,809 additions and 1,641 deletions.
3 changes: 3 additions & 0 deletions docs/dia.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
## Data Inspection Assistant and Diagnosis

::: polars_ds.diagnosis
3,267 changes: 1,675 additions & 1,592 deletions examples/diagnosis.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ use_directory_urls: false

nav:
- Home: index.md
- Complex Extension: complex.md
- Diagnosis: dia.md
- Numerical Extension: num.md
- Stats Extension: stats.md
- String Extension: str2.md
- ML Metrics/Loss Extension: metrics.md
- Graph Extension: graph.md
- Complex Extension: complex.md
- Additional Expressions: polars_ds.md

theme:
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ authors = [
{name = "Tianren Qin", email = "[email protected]"},
]
dependencies = [
"polars >= 0.20.6, !=0.20.12, <1.0",
"polars >= 0.20.6, !=0.20.12",
'typing-extensions; python_version <= "3.11"',
]

keywords = ["polars-extension", "scientific-computing", "data-science"]

[project.optional-dependencies]
plot = [
"great-tables>=0.5",
"great-tables>=0.9",
"graphviz>=0.20",
"plotly>=5.0,<6"
]
Expand Down
157 changes: 112 additions & 45 deletions python/polars_ds/diagnosis.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,29 @@ def __init__(self, df: Union[pl.DataFrame, pl.LazyFrame]):
)
self.other_types: List[str] = [c for c in self._frame.columns if c not in self.simple_types]

def numeric_profile(self, n_bins: int = 20, iqr_multiplier: float = 1.5) -> GT:
def special_values_report(self) -> pl.DataFrame:
"""
Checks null, NaN, and non-finite values for float columns. Note that for integers, only null_count
can possibly be non-zero.
"""
to_check = self.numerics
frames = [
self._frame.select(
pl.lit(c, dtype=pl.String).alias("column"),
pl.col(c).null_count().alias("null_count"),
(pl.col(c).null_count() / pl.len()).alias("null%"),
pl.col(c).is_nan().sum().alias("NaN_count"),
(pl.col(c).is_nan().sum() / pl.len()).alias("NaN%"),
pl.col(c).is_infinite().sum().alias("inf_count"),
(pl.col(c).is_infinite().sum() / pl.len()).alias("Inf%"),
)
for c in to_check
]
return pl.concat(pl.collect_all(frames))

def numeric_profile(
self, n_bins: int = 20, iqr_multiplier: float = 1.5, histogram: bool = True, gt: bool = True
) -> GT:
"""
Creates a numerical profile with a histogram plot. Notice that the histograms may have
completely different scales on the x-axis.
Expand All @@ -80,59 +102,104 @@ def numeric_profile(self, n_bins: int = 20, iqr_multiplier: float = 1.5) -> GT:
Inter Quartile Ranger multiplier. Inter quantile range is the range between
Q1 and Q3, and this multiplier will enlarge the range by a certain amount and
use this to count outliers.
histogram
Whether to show a histogram or not
gt
Whether to show the table as a formatted Great Table or not
"""
to_check = self.numerics

cuts = [i / n_bins for i in range(n_bins)]
cuts[0] -= 1e-5
cuts[-1] += 1e-5
frames = []
for c in to_check:
temp = self._frame.select(
pl.lit(c).alias("column"),
pl.col(c).count().alias("non_null_cnt"),
(pl.col(c).null_count() / pl.len()).alias("null%"),
pl.col(c).mean().alias("mean"),
pl.col(c).std().alias("std"),
pl.col(c).min().cast(pl.Float64).alias("min"),
pl.col(c).quantile(0.25).cast(pl.Float64).alias("q1"),
pl.col(c).median().cast(pl.Float64).round(2).alias("median"),
pl.col(c).quantile(0.75).cast(pl.Float64).alias("q3"),
pl.col(c).max().cast(pl.Float64).alias("max"),
(pl.col(c).quantile(0.75) - pl.col(c).quantile(0.25)).cast(pl.Float64).alias("IQR"),
pl.any_horizontal(
pl.col(c)
< pl.col(c).quantile(0.25)
- iqr_multiplier * (pl.col(c).quantile(0.75) - pl.col(c).quantile(0.25)),
pl.col(c)
> pl.col(c).quantile(0.75)
+ iqr_multiplier * (pl.col(c).quantile(0.75) - pl.col(c).quantile(0.25)),
)
.sum()
.alias("outlier_cnt"),
pl.struct(
((pl.col(c) - pl.col(c).min()) / (pl.col(c).max() - pl.col(c).min()))
.filter(pl.col(c).is_finite())
.cut(breaks=cuts, left_closed=True, include_breaks=True)
.struct.field("brk")
.value_counts()
.sort()
.struct.field("count")
.implode()
).alias("histogram"),
)
frames.append(temp)

if histogram:
columns_needed = [
[
pl.lit(c, dtype=pl.String).alias("column"),
pl.col(c).count().alias("non_null_cnt"),
(pl.col(c).null_count() / pl.len()).alias("null%"),
pl.col(c).mean().cast(pl.Float64).alias("mean"),
pl.col(c).std().cast(pl.Float64).alias("std"),
pl.col(c).min().cast(pl.Float64).cast(pl.Float64).alias("min"),
pl.col(c).quantile(0.25).cast(pl.Float64).alias("q1"),
pl.col(c).median().cast(pl.Float64).round(2).alias("median"),
pl.col(c).quantile(0.75).cast(pl.Float64).alias("q3"),
pl.col(c).max().cast(pl.Float64).alias("max"),
(pl.col(c).quantile(0.75) - pl.col(c).quantile(0.25))
.cast(pl.Float64)
.alias("IQR"),
pl.any_horizontal(
pl.col(c)
< pl.col(c).quantile(0.25)
- iqr_multiplier * (pl.col(c).quantile(0.75) - pl.col(c).quantile(0.25)),
pl.col(c)
> pl.col(c).quantile(0.75)
+ iqr_multiplier * (pl.col(c).quantile(0.75) - pl.col(c).quantile(0.25)),
)
.sum()
.alias("outlier_cnt"),
pl.struct(
((pl.col(c) - pl.col(c).min()) / (pl.col(c).max() - pl.col(c).min()))
.filter(pl.col(c).is_finite())
.cut(breaks=cuts, left_closed=True, include_breaks=True)
.struct.rename_fields(["brk", "category"])
.struct.field("brk")
.value_counts()
.sort()
.struct.field("count")
.implode()
).alias("histogram"),
]
for c in to_check
]
else:
columns_needed = [
[
pl.lit(c, dtype=pl.String).alias("column"),
pl.col(c).count().alias("non_null_cnt"),
(pl.col(c).null_count() / pl.len()).alias("null%"),
pl.col(c).mean().cast(pl.Float64).alias("mean"),
pl.col(c).std().cast(pl.Float64).alias("std"),
pl.col(c).min().cast(pl.Float64).cast(pl.Float64).alias("min"),
pl.col(c).quantile(0.25).cast(pl.Float64).alias("q1"),
pl.col(c).median().cast(pl.Float64).round(2).alias("median"),
pl.col(c).quantile(0.75).cast(pl.Float64).alias("q3"),
pl.col(c).max().cast(pl.Float64).alias("max"),
(pl.col(c).quantile(0.75) - pl.col(c).quantile(0.25))
.cast(pl.Float64)
.alias("IQR"),
pl.any_horizontal(
pl.col(c)
< pl.col(c).quantile(0.25)
- iqr_multiplier * (pl.col(c).quantile(0.75) - pl.col(c).quantile(0.25)),
pl.col(c)
> pl.col(c).quantile(0.75)
+ iqr_multiplier * (pl.col(c).quantile(0.75) - pl.col(c).quantile(0.25)),
)
.sum()
.alias("outlier_cnt"),
]
for c in to_check
]

frames = [self._frame.select(*cols) for cols in columns_needed]
df_final = pl.concat(pl.collect_all(frames))
return (
GT(df_final, rowname_col="column")
.tab_stubhead("column")
.fmt_percent(columns="null%")
.fmt_number(
columns=["mean", "std", "min", "q1", "median", "q3", "max", "IQR"], decimals=3

if gt:
gt_out = (
GT(df_final, rowname_col="column")
.tab_stubhead("column")
.fmt_percent(columns="null%")
.fmt_number(
columns=["mean", "std", "min", "q1", "median", "q3", "max", "IQR"], decimals=3
)
)
.fmt_nanoplot(columns="histogram", plot_type="bar")
)
if histogram:
return gt_out.fmt_nanoplot(columns="histogram", plot_type="bar")
return gt_out
else:
return df_final

def plot_null_distribution(
self,
Expand Down
5 changes: 4 additions & 1 deletion python/polars_ds/num.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,7 @@ def query_permute_entropy(
.head(t.len() - n_dims + 1)
.list.eval(pl.element().arg_sort())
.value_counts() # groupby and count, but returns a struct
.struct.field("count") # extract the field named "counts"
.struct.field("count") # extract the field named "count"
.entropy(base=base, normalize=True)
)
else:
Expand Down Expand Up @@ -1059,6 +1059,9 @@ def query_psi(

vc = (
valid_ref.qcut(n_bins, left_closed=False, allow_duplicates=True, include_breaks=True)
.struct.rename_fields(
["brk", "category"]
) # Use "breakpoints" in the future. Skip this rename. After polars v1
.struct.field("brk")
.value_counts()
.sort()
Expand Down
11 changes: 11 additions & 0 deletions tests/test.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@
"df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df.select(\n",
" pl.col(\"x1\").cut(breaks=[0.2, 0.4], left_closed=True, include_breaks=True)\n",
").unnest(\"x1\")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down

0 comments on commit cccd3cb

Please sign in to comment.