Skip to content

Commit

Permalink
Merge pull request #281 from abstractqqq/split_dfs
Browse files Browse the repository at this point in the history
added basic splitting
  • Loading branch information
abstractqqq authored Nov 1, 2024
2 parents 30726ae + f637648 commit fb5c136
Show file tree
Hide file tree
Showing 7 changed files with 792 additions and 686 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ jobs:
python -c 'from polars_ds import linear_models'
python -c 'from polars_ds.ts_features import *'
python -c 'from polars_ds.spatial import *'
python -c 'from polars_ds.sample_and_split import *'
- name: Upload sdist
uses: actions/upload-artifact@v4
Expand Down
2 changes: 1 addition & 1 deletion docs/sample.md → docs/sample_and_split.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
## Polars Native Machine Learning Pipeline

::: polars_ds.sample
::: polars_ds.sample_and_split
1,211 changes: 605 additions & 606 deletions examples/diagnosis.ipynb

Large diffs are not rendered by default.

186 changes: 112 additions & 74 deletions examples/sample_and_split.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ nav:
- Home: index.md
- Diagnosis: dia.md
- Pipeline: pipeline.md
- Sample: sample.md
- Sample and Split: sample_and_split.md
- Numerical Functions: num.md
- Statistics: stats.md
- String Related: string.md
Expand Down
3 changes: 1 addition & 2 deletions python/polars_ds/diagnosis.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,13 @@
from . import query_cond_entropy, query_principal_components, query_r2
from .type_alias import CorrMethod, PolarsFrame
from .stats import corr
from .sample import sample
from .sample_and_split import sample

alt.data_transformers.enable("vegafusion")


# DIA = Data Inspection Assistant / DIAgnosis
class DIA:

"""
Data Inspection Assistant. Most plots are powered by plotly/great_tables. Plotly may require
additional package downloads.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@ def _sampler_expr(value: float | int, seed: int | None = None) -> pl.Expr:
if isinstance(value, float):
if value >= 1.0 or value <= 0.0:
raise ValueError("Sample rate must be in (0, 1) range.")
return pl.int_range(0, pl.len()).shuffle(seed) < pl.len() * value
return pl.int_range(0, pl.len(), dtype=pl.UInt32).shuffle(seed) < (pl.len() * value).cast(
pl.UInt32
)
elif isinstance(value, int):
if value <= 0:
raise ValueError("Sample count must be > 0.")
return pl.int_range(0, pl.len()).shuffle(seed) < value
return pl.int_range(0, pl.len(), dtype=pl.UInt32).shuffle(seed) < pl.lit(
value, dtype=pl.UInt32
)
elif isinstance(value, pl.Expr):
return NotImplemented
else:
Expand Down Expand Up @@ -183,3 +187,68 @@ def random_cols(
n = random.randrange(0, math.comb(pool_size, k))
rand_cols = next(islice(to_sample, n, None), None)
return out + list(rand_cols)


def split_by_ratio(
    df: PolarsFrame, split_ratio: float | List[float], seed: int | None = None
) -> List[pl.DataFrame]:
    """
    Split the dataframe into multiple. If split_ratio is a single number, it is treated as the
    ratio for the "train" set and the second is always the "test" set. If split_ratio is a list
    of floats, then they must sum to 1 (within floating point tolerance) and the return will be
    dataframes split into the corresponding ratios. The last split always receives all remaining
    rows, so no row is lost to truncation.

    This will collect your LazyFrames.

    Parameters
    ----------
    df
        Either a lazy or eager Polars dataframe
    split_ratio
        Either a single float or a list of floats.
    seed
        The random seed

    Raises
    ------
    ValueError
        If a single ratio is outside (0, 1), or a list of ratios does not sum to 1.
    """
    if isinstance(split_ratio, float):
        if split_ratio <= 0.0 or split_ratio >= 1.0:
            raise ValueError("Split ratio must be > 0 and < 1.")

        df_eager = (
            df.lazy()
            .with_row_index(name="__id")
            .collect()
            .with_columns(
                (
                    pl.col("__id").shuffle(seed=seed) < (pl.len() * split_ratio).cast(pl.Int64)
                ).alias("__tt")
            )
        )
        frames = df_eager.partition_by("__tt", as_dict=True)
        # On very small frames one partition may be empty and absent from the dict;
        # fall back to an empty frame with the same schema instead of raising KeyError.
        empty = df_eager.clear()
        train = frames.get((True,), empty).drop(["__id", "__tt"])
        test = frames.get((False,), empty).drop(["__id", "__tt"])
        return [train, test]
    else:
        # Exact float equality (sum != 1) wrongly rejects lists like [0.3, 0.3, 0.4];
        # compare with a small absolute tolerance instead.
        if abs(sum(split_ratio) - 1.0) > 1e-9:
            raise ValueError("Sum of the ratios is not 1.")

        df_eager = (
            df.lazy()
            .with_row_index(name="__id")
            .with_columns(pl.col("__id").shuffle(seed=seed).alias("__tt"))
            .sort("__tt")
            .collect()
        )

        n = len(df_eager)
        start = 0
        dfs: List[pl.DataFrame] = []
        last = len(split_ratio) - 1
        for i, v in enumerate(split_ratio):
            # Truncation in int(n * v) can drop rows; give the final split all
            # remaining rows so the splits together cover the whole frame.
            length = (n - start) if i == last else int(n * v)
            dfs.append(df_eager.slice(start, length=length).drop(["__id", "__tt"]))
            start += length

        return dfs

0 comments on commit fb5c136

Please sign in to comment.