diff --git a/docs/metrics.md b/docs/metrics.md
new file mode 100644
index 00000000..3ecc2610
--- /dev/null
+++ b/docs/metrics.md
@@ -0,0 +1,3 @@
+## Extension for ML Metrics/Losses
+
+::: polars_ds.metrics
\ No newline at end of file
diff --git a/docs/polars_ds.md b/docs/polars_ds.md
index 43e4e9a4..c40fba1b 100644
--- a/docs/polars_ds.md
+++ b/docs/polars_ds.md
@@ -2,4 +2,4 @@
::: polars_ds
options:
- filters: ["!(NumExt|StatsExt|StrExt|ComplexExt)", "^__init__$"]
\ No newline at end of file
+ filters: ["!(NumExt|StatsExt|StrExt|ComplexExt|MetricExt)", "^__init__$"]
\ No newline at end of file
diff --git a/examples/basics.ipynb b/examples/basics.ipynb
index c602f4d4..e59ffd03 100644
--- a/examples/basics.ipynb
+++ b/examples/basics.ipynb
@@ -7,7 +7,7 @@
"metadata": {},
"outputs": [],
"source": [
- "import polars_ds\n",
+ "import polars_ds as pld\n",
"import polars as pl\n",
"import numpy as np "
]
@@ -36,7 +36,7 @@
" white-space: pre-wrap;\n",
"}\n",
"\n",
- "shape: (5, 10)
f | dummy | a | b | x1 | x2 | y | actual | predicted | dummy_groups |
---|
f64 | str | f64 | f64 | i64 | i64 | i64 | i32 | f64 | str |
0.0 | "a" | 0.457934 | 0.6423 | 0 | 100000 | -100000 | 1 | 0.890699 | "a" |
0.841471 | "a" | 0.459135 | 0.735028 | 1 | 100001 | -99999 | 0 | 0.388504 | "a" |
0.909297 | "a" | 0.307611 | 0.634786 | 2 | 100002 | -99998 | 1 | 0.642528 | "a" |
0.14112 | "a" | 0.95301 | 0.074787 | 3 | 100003 | -99997 | 0 | 0.327906 | "a" |
-0.756802 | "a" | 0.472305 | 0.905882 | 4 | 100004 | -99996 | 1 | 0.227964 | "a" |
"
+ "shape: (5, 10)f | dummy | a | b | x1 | x2 | y | actual | predicted | dummy_groups |
---|
f64 | str | f64 | f64 | i64 | i64 | i64 | i32 | f64 | str |
0.0 | "a" | 0.329054 | 0.431276 | 0 | 100000 | -100000 | 0 | 0.194481 | "a" |
0.841471 | "a" | 0.302734 | 0.186427 | 1 | 100001 | -99999 | 1 | 0.615612 | "a" |
0.909297 | "a" | 0.066187 | 0.955182 | 2 | 100002 | -99998 | 1 | 0.953673 | "a" |
0.14112 | "a" | 0.052694 | 0.399579 | 3 | 100003 | -99997 | 0 | 0.90706 | "a" |
-0.756802 | "a" | 0.618107 | 0.307307 | 4 | 100004 | -99996 | 0 | 0.115548 | "a" |
"
],
"text/plain": [
"shape: (5, 10)\n",
@@ -45,11 +45,11 @@
"│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │\n",
"│ f64 ┆ str ┆ f64 ┆ f64 ┆ ┆ i64 ┆ i32 ┆ f64 ┆ str │\n",
"╞═══════════╪═══════╪══════════╪══════════╪═══╪═════════╪════════╪═══════════╪══════════════╡\n",
- "│ 0.0 ┆ a ┆ 0.457934 ┆ 0.6423 ┆ … ┆ -100000 ┆ 1 ┆ 0.890699 ┆ a │\n",
- "│ 0.841471 ┆ a ┆ 0.459135 ┆ 0.735028 ┆ … ┆ -99999 ┆ 0 ┆ 0.388504 ┆ a │\n",
- "│ 0.909297 ┆ a ┆ 0.307611 ┆ 0.634786 ┆ … ┆ -99998 ┆ 1 ┆ 0.642528 ┆ a │\n",
- "│ 0.14112 ┆ a ┆ 0.95301 ┆ 0.074787 ┆ … ┆ -99997 ┆ 0 ┆ 0.327906 ┆ a │\n",
- "│ -0.756802 ┆ a ┆ 0.472305 ┆ 0.905882 ┆ … ┆ -99996 ┆ 1 ┆ 0.227964 ┆ a │\n",
+ "│ 0.0 ┆ a ┆ 0.329054 ┆ 0.431276 ┆ … ┆ -100000 ┆ 0 ┆ 0.194481 ┆ a │\n",
+ "│ 0.841471 ┆ a ┆ 0.302734 ┆ 0.186427 ┆ … ┆ -99999 ┆ 1 ┆ 0.615612 ┆ a │\n",
+ "│ 0.909297 ┆ a ┆ 0.066187 ┆ 0.955182 ┆ … ┆ -99998 ┆ 1 ┆ 0.953673 ┆ a │\n",
+ "│ 0.14112 ┆ a ┆ 0.052694 ┆ 0.399579 ┆ … ┆ -99997 ┆ 0 ┆ 0.90706 ┆ a │\n",
+ "│ -0.756802 ┆ a ┆ 0.618107 ┆ 0.307307 ┆ … ┆ -99996 ┆ 0 ┆ 0.115548 ┆ a │\n",
"└───────────┴───────┴──────────┴──────────┴───┴─────────┴────────┴───────────┴──────────────┘"
]
},
@@ -301,7 +301,7 @@
" white-space: pre-wrap;\n",
"}\n",
"\n",
- "shape: (2, 2)dummy | list_float |
---|
str | list[f64] |
"a" | [2.0, -1.0] |
"b" | [2.0, -1.0] |
"
+ "shape: (2, 2)dummy | list_float |
---|
str | list[f64] |
"b" | [2.0, -1.0] |
"a" | [2.0, -1.0] |
"
],
"text/plain": [
"shape: (2, 2)\n",
@@ -310,8 +310,8 @@
"│ --- ┆ --- │\n",
"│ str ┆ list[f64] │\n",
"╞═══════╪═════════════╡\n",
- "│ a ┆ [2.0, -1.0] │\n",
"│ b ┆ [2.0, -1.0] │\n",
+ "│ a ┆ [2.0, -1.0] │\n",
"└───────┴─────────────┘"
]
},
@@ -383,7 +383,7 @@
" white-space: pre-wrap;\n",
"}\n",
"\n",
- "shape: (2, 8)dummy_groups | l2 | log loss | precision | recall | f | average_precision | roc_auc |
---|
str | f64 | f64 | f64 | f64 | f64 | f64 | f64 |
"a" | 0.331461 | 0.994676 | 0.504398 | 0.503277 | 0.251918 | 0.506909 | 0.503755 |
"b" | 0.332576 | 0.999094 | 0.500683 | 0.498081 | 0.249689 | 0.500449 | 0.501698 |
"
+ "shape: (2, 8)dummy_groups | l2 | log loss | precision | recall | f | average_precision | roc_auc |
---|
str | f64 | f64 | f64 | f64 | f64 | f64 | f64 |
"b" | 0.335546 | 1.005173 | 0.498574 | 0.496142 | 0.248677 | 0.495425 | 0.495449 |
"a" | 0.334022 | 0.997736 | 0.500377 | 0.502168 | 0.250635 | 0.501932 | 0.498258 |
"
],
"text/plain": [
"shape: (2, 8)\n",
@@ -393,8 +393,8 @@
"│ str ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ --- ┆ f64 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ f64 ┆ │\n",
"╞══════════════╪══════════╪══════════╪═══════════╪══════════╪══════════╪════════════════╪══════════╡\n",
- "│ a ┆ 0.331461 ┆ 0.994676 ┆ 0.504398 ┆ 0.503277 ┆ 0.251918 ┆ 0.506909 ┆ 0.503755 │\n",
- "│ b ┆ 0.332576 ┆ 0.999094 ┆ 0.500683 ┆ 0.498081 ┆ 0.249689 ┆ 0.500449 ┆ 0.501698 │\n",
+ "│ b ┆ 0.335546 ┆ 1.005173 ┆ 0.498574 ┆ 0.496142 ┆ 0.248677 ┆ 0.495425 ┆ 0.495449 │\n",
+ "│ a ┆ 0.334022 ┆ 0.997736 ┆ 0.500377 ┆ 0.502168 ┆ 0.250635 ┆ 0.501932 ┆ 0.498258 │\n",
"└──────────────┴──────────┴──────────┴───────────┴──────────┴──────────┴────────────────┴──────────┘"
]
},
@@ -405,9 +405,9 @@
],
"source": [
"df.group_by(\"dummy_groups\").agg(\n",
- " pl.col(\"actual\").num.l2_loss(pl.col(\"predicted\")).alias(\"l2\"),\n",
- " pl.col(\"actual\").num.bce(pl.col(\"predicted\")).alias(\"log loss\"),\n",
- " pl.col(\"actual\").num.binary_metrics_combo(pl.col(\"predicted\")).alias(\"combo\")\n",
+ " pl.col(\"actual\").metric.l2_loss(pl.col(\"predicted\")).alias(\"l2\"),\n",
+ " pl.col(\"actual\").metric.bce(pl.col(\"predicted\")).alias(\"log loss\"),\n",
+ " pl.col(\"actual\").metric.binary_metrics_combo(pl.col(\"predicted\")).alias(\"combo\")\n",
").unnest(\"combo\")\n"
]
},
@@ -482,7 +482,7 @@
" white-space: pre-wrap;\n",
"}\n",
"\n",
- "shape: (5, 1)sen |
---|
str |
"hello" |
"church" |
"world" |
"going" |
"to" |
"
+ "shape: (5, 1)sen |
---|
str |
"to" |
"hello" |
"going" |
"world" |
"church" |
"
],
"text/plain": [
"shape: (5, 1)\n",
@@ -491,11 +491,11 @@
"│ --- │\n",
"│ str │\n",
"╞════════╡\n",
+ "│ to │\n",
"│ hello │\n",
- "│ church │\n",
- "│ world │\n",
"│ going │\n",
- "│ to │\n",
+ "│ world │\n",
+ "│ church │\n",
"└────────┘"
]
},
@@ -527,7 +527,7 @@
" white-space: pre-wrap;\n",
"}\n",
"\n",
- "shape: (4, 1)sen |
---|
str |
"hello" |
"world" |
"church" |
"go" |
"
+ "shape: (4, 1)sen |
---|
str |
"go" |
"hello" |
"church" |
"world" |
"
],
"text/plain": [
"shape: (4, 1)\n",
@@ -536,10 +536,10 @@
"│ --- │\n",
"│ str │\n",
"╞════════╡\n",
+ "│ go │\n",
"│ hello │\n",
- "│ world │\n",
"│ church │\n",
- "│ go │\n",
+ "│ world │\n",
"└────────┘"
]
},
@@ -862,7 +862,7 @@
" white-space: pre-wrap;\n",
"}\n",
"\n",
- "shape: (5, 1)a |
---|
f64 |
null |
null |
-1.497465 |
-0.859263 |
-0.596512 |
"
+ "shape: (5, 1)a |
---|
f64 |
null |
null |
0.375437 |
0.9494 |
-0.651141 |
"
],
"text/plain": [
"shape: (5, 1)\n",
@@ -873,9 +873,9 @@
"╞═══════════╡\n",
"│ null │\n",
"│ null │\n",
- "│ -1.497465 │\n",
- "│ -0.859263 │\n",
- "│ -0.596512 │\n",
+ "│ 0.375437 │\n",
+ "│ 0.9494 │\n",
+ "│ -0.651141 │\n",
"└───────────┘"
]
},
@@ -908,7 +908,7 @@
" white-space: pre-wrap;\n",
"}\n",
"\n",
- "shape: (1_000, 2)a | random |
---|
f64 | f64 |
null | null |
null | null |
-1.497465 | 0.263812 |
-0.859263 | -0.600834 |
-0.596512 | 0.085847 |
-0.797408 | 1.611438 |
0.065377 | 0.711696 |
0.486381 | -0.431943 |
0.657458 | 1.799759 |
0.948048 | 0.321395 |
-0.210284 | -0.933072 |
-1.698786 | -0.240009 |
… | … |
-1.144679 | 1.277798 |
0.105859 | 0.085334 |
-0.409642 | 0.500561 |
0.588062 | 1.535893 |
-0.571369 | 1.467995 |
-0.071992 | 0.424841 |
-0.836861 | 0.652322 |
0.31963 | -1.395188 |
-0.911452 | -0.475192 |
0.625393 | 0.053465 |
-0.063977 | 2.109493 |
0.768323 | 1.230715 |
"
+ "shape: (1_000, 2)a | random |
---|
f64 | f64 |
null | null |
null | null |
0.375437 | 0.771642 |
0.9494 | 0.545358 |
-0.651141 | -1.091522 |
-1.834427 | 0.610844 |
-0.620926 | 1.264071 |
1.812079 | -0.381095 |
2.110361 | -0.321377 |
0.777085 | 1.193875 |
-0.876686 | 0.913566 |
-0.523285 | -0.524509 |
… | … |
0.112199 | 1.768952 |
-0.477742 | -0.477829 |
-0.129456 | 1.431202 |
1.146672 | -0.529259 |
-0.277331 | 0.138509 |
-1.206057 | 0.644518 |
0.33921 | 0.530259 |
-0.666568 | -0.536235 |
-0.390784 | -0.150115 |
0.182855 | 0.676322 |
-0.158676 | 0.551754 |
0.014593 | 0.556663 |
"
],
"text/plain": [
"shape: (1_000, 2)\n",
@@ -919,13 +919,13 @@
"╞═══════════╪═══════════╡\n",
"│ null ┆ null │\n",
"│ null ┆ null │\n",
- "│ -1.497465 ┆ 0.263812 │\n",
- "│ -0.859263 ┆ -0.600834 │\n",
+ "│ 0.375437 ┆ 0.771642 │\n",
+ "│ 0.9494 ┆ 0.545358 │\n",
"│ … ┆ … │\n",
- "│ -0.911452 ┆ -0.475192 │\n",
- "│ 0.625393 ┆ 0.053465 │\n",
- "│ -0.063977 ┆ 2.109493 │\n",
- "│ 0.768323 ┆ 1.230715 │\n",
+ "│ -0.390784 ┆ -0.150115 │\n",
+ "│ 0.182855 ┆ 0.676322 │\n",
+ "│ -0.158676 ┆ 0.551754 │\n",
+ "│ 0.014593 ┆ 0.556663 │\n",
"└───────────┴───────────┘"
]
},
@@ -956,7 +956,7 @@
" white-space: pre-wrap;\n",
"}\n",
"\n",
- "shape: (1_000, 2)a | random_str |
---|
f64 | str |
null | null |
null | null |
-1.497465 | "n2" |
-0.859263 | "B" |
-0.596512 | "y1oL" |
-0.797408 | "KK0" |
0.065377 | "gM5" |
0.486381 | "4pBH" |
0.657458 | "r" |
0.948048 | "5" |
-0.210284 | "hiyl" |
-1.698786 | "rp" |
… | … |
-1.144679 | "T" |
0.105859 | "l" |
-0.409642 | "RP" |
0.588062 | "weP" |
-0.571369 | "VV" |
-0.071992 | "T1s7" |
-0.836861 | "1FR" |
0.31963 | "hyG" |
-0.911452 | "V" |
0.625393 | "B" |
-0.063977 | "C" |
0.768323 | "WJO" |
"
+ "shape: (1_000, 2)a | random_str |
---|
f64 | str |
null | null |
null | null |
0.375437 | "QcAN" |
0.9494 | "9" |
-0.651141 | "wA" |
-1.834427 | "81" |
-0.620926 | "rhXk" |
1.812079 | "KfKa" |
2.110361 | "zs" |
0.777085 | "L" |
-0.876686 | "UMF" |
-0.523285 | "YmCw" |
… | … |
0.112199 | "T" |
-0.477742 | "fAKV" |
-0.129456 | "xgUF" |
1.146672 | "9k" |
-0.277331 | "4M" |
-1.206057 | "Z" |
0.33921 | "unhi" |
-0.666568 | "wB" |
-0.390784 | "u" |
0.182855 | "WqkX" |
-0.158676 | "KD" |
0.014593 | "YZn3" |
"
],
"text/plain": [
"shape: (1_000, 2)\n",
@@ -967,13 +967,13 @@
"╞═══════════╪════════════╡\n",
"│ null ┆ null │\n",
"│ null ┆ null │\n",
- "│ -1.497465 ┆ n2 │\n",
- "│ -0.859263 ┆ B │\n",
+ "│ 0.375437 ┆ QcAN │\n",
+ "│ 0.9494 ┆ 9 │\n",
"│ … ┆ … │\n",
- "│ -0.911452 ┆ V │\n",
- "│ 0.625393 ┆ B │\n",
- "│ -0.063977 ┆ C │\n",
- "│ 0.768323 ┆ WJO │\n",
+ "│ -0.390784 ┆ u │\n",
+ "│ 0.182855 ┆ WqkX │\n",
+ "│ -0.158676 ┆ KD │\n",
+ "│ 0.014593 ┆ YZn3 │\n",
"└───────────┴────────────┘"
]
},
@@ -1004,7 +1004,7 @@
" white-space: pre-wrap;\n",
"}\n",
"\n",
- "shape: (1_000, 2)a | random_str |
---|
f64 | str |
null | null |
null | null |
-1.497465 | "TlshF" |
-0.859263 | "vrsiJ" |
-0.596512 | "tXvOF" |
-0.797408 | "XsFtP" |
0.065377 | "lN3oB" |
0.486381 | "yYIhn" |
0.657458 | "pXyXn" |
0.948048 | "z8a8H" |
-0.210284 | "nLyLe" |
-1.698786 | "F1Dqt" |
… | … |
-1.144679 | "fNiwB" |
0.105859 | "kd5bW" |
-0.409642 | "Gkkb3" |
0.588062 | "ZTh77" |
-0.571369 | "ZY2JK" |
-0.071992 | "7ERcF" |
-0.836861 | "8eNdj" |
0.31963 | "jbJhc" |
-0.911452 | "Rb0H9" |
0.625393 | "NtIB4" |
-0.063977 | "3FH6H" |
0.768323 | "7GXoP" |
"
+ "shape: (1_000, 2)a | random_str |
---|
f64 | str |
null | null |
null | null |
0.375437 | "G35lQ" |
0.9494 | "m8OqI" |
-0.651141 | "S7CWK" |
-1.834427 | "IgOkR" |
-0.620926 | "NbTmT" |
1.812079 | "Trx1u" |
2.110361 | "VCvz1" |
0.777085 | "iNPCp" |
-0.876686 | "Wexmv" |
-0.523285 | "J6TII" |
… | … |
0.112199 | "BxFXn" |
-0.477742 | "rLOKm" |
-0.129456 | "yyYQI" |
1.146672 | "TyGA0" |
-0.277331 | "0fCBu" |
-1.206057 | "ajFgx" |
0.33921 | "9x7wb" |
-0.666568 | "GQ9wB" |
-0.390784 | "zX288" |
0.182855 | "QZhnh" |
-0.158676 | "mHfiC" |
0.014593 | "txRJL" |
"
],
"text/plain": [
"shape: (1_000, 2)\n",
@@ -1015,13 +1015,13 @@
"╞═══════════╪════════════╡\n",
"│ null ┆ null │\n",
"│ null ┆ null │\n",
- "│ -1.497465 ┆ TlshF │\n",
- "│ -0.859263 ┆ vrsiJ │\n",
+ "│ 0.375437 ┆ G35lQ │\n",
+ "│ 0.9494 ┆ m8OqI │\n",
"│ … ┆ … │\n",
- "│ -0.911452 ┆ Rb0H9 │\n",
- "│ 0.625393 ┆ NtIB4 │\n",
- "│ -0.063977 ┆ 3FH6H │\n",
- "│ 0.768323 ┆ 7GXoP │\n",
+ "│ -0.390784 ┆ zX288 │\n",
+ "│ 0.182855 ┆ QZhnh │\n",
+ "│ -0.158676 ┆ mHfiC │\n",
+ "│ 0.014593 ┆ txRJL │\n",
"└───────────┴────────────┘"
]
},
@@ -1053,7 +1053,7 @@
" white-space: pre-wrap;\n",
"}\n",
"\n",
- "shape: (5, 3)a | test1 | test2 |
---|
f64 | f64 | f64 |
null | 2.015709 | null |
null | 0.100036 | null |
-1.497465 | 0.754484 | 2.671293 |
-0.859263 | -0.107083 | 0.509611 |
-0.596512 | 1.619745 | 1.599823 |
"
+ "shape: (5, 3)a | test1 | test2 |
---|
f64 | f64 | f64 |
null | -0.316657 | null |
null | 0.857924 | null |
0.375437 | 0.817816 | 0.840231 |
0.9494 | -0.145348 | 1.610263 |
-0.651141 | -1.057825 | 1.090286 |
"
],
"text/plain": [
"shape: (5, 3)\n",
@@ -1062,11 +1062,11 @@
"│ --- ┆ --- ┆ --- │\n",
"│ f64 ┆ f64 ┆ f64 │\n",
"╞═══════════╪═══════════╪══════════╡\n",
- "│ null ┆ 2.015709 ┆ null │\n",
- "│ null ┆ 0.100036 ┆ null │\n",
- "│ -1.497465 ┆ 0.754484 ┆ 2.671293 │\n",
- "│ -0.859263 ┆ -0.107083 ┆ 0.509611 │\n",
- "│ -0.596512 ┆ 1.619745 ┆ 1.599823 │\n",
+ "│ null ┆ -0.316657 ┆ null │\n",
+ "│ null ┆ 0.857924 ┆ null │\n",
+ "│ 0.375437 ┆ 0.817816 ┆ 0.840231 │\n",
+ "│ 0.9494 ┆ -0.145348 ┆ 1.610263 │\n",
+ "│ -0.651141 ┆ -1.057825 ┆ 1.090286 │\n",
"└───────────┴───────────┴──────────┘"
]
},
@@ -1099,7 +1099,7 @@
" white-space: pre-wrap;\n",
"}\n",
"\n",
- "shape: (1, 4)t-tests: statistics | t-tests: pvalue | normality_test: statistics | normality_test: pvalue |
---|
f64 | f64 | f64 | f64 |
-0.241144 | 0.809478 | 0.162122 | 0.922138 |
"
+ "shape: (1, 4)t-tests: statistics | t-tests: pvalue | normality_test: statistics | normality_test: pvalue |
---|
f64 | f64 | f64 | f64 |
-0.340638 | 0.733424 | 0.489244 | 0.783001 |
"
],
"text/plain": [
"shape: (1, 4)\n",
@@ -1108,7 +1108,7 @@
"│ --- ┆ --- ┆ --- ┆ --- │\n",
"│ f64 ┆ f64 ┆ f64 ┆ f64 │\n",
"╞═════════════════════╪═════════════════╪════════════════════════════╪════════════════════════╡\n",
- "│ -0.241144 ┆ 0.809478 ┆ 0.162122 ┆ 0.922138 │\n",
+ "│ -0.340638 ┆ 0.733424 ┆ 0.489244 ┆ 0.783001 │\n",
"└─────────────────────┴─────────────────┴────────────────────────────┴────────────────────────┘"
]
},
@@ -1152,7 +1152,7 @@
" white-space: pre-wrap;\n",
"}\n",
"\n",
- "shape: (5, 5)market_id | group1 | group2 | category_1 | category_2 |
---|
i64 | f64 | f64 | i64 | i64 |
0 | 0.442217 | 0.026176 | 1 | 9 |
1 | 0.818966 | 0.583352 | 3 | 3 |
2 | 0.41901 | 0.616824 | 0 | 3 |
0 | 0.006135 | 0.754259 | 1 | 2 |
1 | 0.781227 | 0.799708 | 2 | 7 |
"
+ "shape: (5, 5)market_id | group1 | group2 | category_1 | category_2 |
---|
i64 | f64 | f64 | i64 | i64 |
0 | 0.279683 | 0.693541 | 3 | 7 |
1 | 0.165041 | 0.408362 | 3 | 7 |
2 | 0.36538 | 0.943691 | 1 | 5 |
0 | 0.182031 | 0.066145 | 1 | 8 |
1 | 0.451541 | 0.953597 | 3 | 1 |
"
],
"text/plain": [
"shape: (5, 5)\n",
@@ -1161,11 +1161,11 @@
"│ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n",
"│ i64 ┆ f64 ┆ f64 ┆ i64 ┆ i64 │\n",
"╞═══════════╪══════════╪══════════╪════════════╪════════════╡\n",
- "│ 0 ┆ 0.442217 ┆ 0.026176 ┆ 1 ┆ 9 │\n",
- "│ 1 ┆ 0.818966 ┆ 0.583352 ┆ 3 ┆ 3 │\n",
- "│ 2 ┆ 0.41901 ┆ 0.616824 ┆ 0 ┆ 3 │\n",
- "│ 0 ┆ 0.006135 ┆ 0.754259 ┆ 1 ┆ 2 │\n",
- "│ 1 ┆ 0.781227 ┆ 0.799708 ┆ 2 ┆ 7 │\n",
+ "│ 0 ┆ 0.279683 ┆ 0.693541 ┆ 3 ┆ 7 │\n",
+ "│ 1 ┆ 0.165041 ┆ 0.408362 ┆ 3 ┆ 7 │\n",
+ "│ 2 ┆ 0.36538 ┆ 0.943691 ┆ 1 ┆ 5 │\n",
+ "│ 0 ┆ 0.182031 ┆ 0.066145 ┆ 1 ┆ 8 │\n",
+ "│ 1 ┆ 0.451541 ┆ 0.953597 ┆ 3 ┆ 1 │\n",
"└───────────┴──────────┴──────────┴────────────┴────────────┘"
]
},
@@ -1199,13 +1199,13 @@
"output_type": "stream",
"text": [
"shape: (1, 3)\n",
- "┌─────────────────────┬────────────────────┬───────────────────┐\n",
- "│ t-test ┆ chi2-test ┆ f-test │\n",
- "│ --- ┆ --- ┆ --- │\n",
- "│ struct[2] ┆ struct[2] ┆ struct[2] │\n",
- "╞═════════════════════╪════════════════════╪═══════════════════╡\n",
- "│ {2.115272,0.034431} ┆ {32.5811,0.631989} ┆ {1.881212,0.1108} │\n",
- "└─────────────────────┴────────────────────┴───────────────────┘\n"
+ "┌──────────────────────┬──────────────────────┬─────────────────────┐\n",
+ "│ t-test ┆ chi2-test ┆ f-test │\n",
+ "│ --- ┆ --- ┆ --- │\n",
+ "│ struct[2] ┆ struct[2] ┆ struct[2] │\n",
+ "╞══════════════════════╪══════════════════════╪═════════════════════╡\n",
+ "│ {-1.650504,0.098871} ┆ {36.895283,0.427337} ┆ {0.309968,0.871476} │\n",
+ "└──────────────────────┴──────────────────────┴─────────────────────┘\n"
]
}
],
@@ -1234,19 +1234,19 @@
" white-space: pre-wrap;\n",
"}\n",
"\n",
- "shape: (3, 4)market_id | t-test | chi2-test | f-test |
---|
i64 | struct[2] | struct[2] | struct[2] |
0 | {1.598889,0.10994} | {36.896559,0.427279} | {-1.900136,-0.1079} |
1 | {1.806081,0.070996} | {44.404524,0.158728} | {-0.346406,-0.846645} |
2 | {0.257484,0.796821} | {26.412285,0.878774} | {-0.350121,-0.844067} |
"
+ "shape: (3, 4)market_id | t-test | chi2-test | f-test |
---|
i64 | struct[2] | struct[2] | struct[2] |
0 | {-1.237199,0.2161} | {32.480736,0.636743} | {-0.373944,-0.82735} |
1 | {0.056369,0.955051} | {42.838918,0.201173} | {-0.29107,-0.883896} |
2 | {-1.668119,0.095386} | {52.244235,0.039157} | {-0.690292,-0.598664} |
"
],
"text/plain": [
"shape: (3, 4)\n",
- "┌───────────┬─────────────────────┬──────────────────────┬───────────────────────┐\n",
- "│ market_id ┆ t-test ┆ chi2-test ┆ f-test │\n",
- "│ --- ┆ --- ┆ --- ┆ --- │\n",
- "│ i64 ┆ struct[2] ┆ struct[2] ┆ struct[2] │\n",
- "╞═══════════╪═════════════════════╪══════════════════════╪═══════════════════════╡\n",
- "│ 0 ┆ {1.598889,0.10994} ┆ {36.896559,0.427279} ┆ {-1.900136,-0.1079} │\n",
- "│ 1 ┆ {1.806081,0.070996} ┆ {44.404524,0.158728} ┆ {-0.346406,-0.846645} │\n",
- "│ 2 ┆ {0.257484,0.796821} ┆ {26.412285,0.878774} ┆ {-0.350121,-0.844067} │\n",
- "└───────────┴─────────────────────┴──────────────────────┴───────────────────────┘"
+ "┌───────────┬──────────────────────┬──────────────────────┬───────────────────────┐\n",
+ "│ market_id ┆ t-test ┆ chi2-test ┆ f-test │\n",
+ "│ --- ┆ --- ┆ --- ┆ --- │\n",
+ "│ i64 ┆ struct[2] ┆ struct[2] ┆ struct[2] │\n",
+ "╞═══════════╪══════════════════════╪══════════════════════╪═══════════════════════╡\n",
+ "│ 0 ┆ {-1.237199,0.2161} ┆ {32.480736,0.636743} ┆ {-0.373944,-0.82735} │\n",
+ "│ 1 ┆ {0.056369,0.955051} ┆ {42.838918,0.201173} ┆ {-0.29107,-0.883896} │\n",
+ "│ 2 ┆ {-1.668119,0.095386} ┆ {52.244235,0.039157} ┆ {-0.690292,-0.598664} │\n",
+ "└───────────┴──────────────────────┴──────────────────────┴───────────────────────┘"
]
},
"execution_count": 29,
@@ -1307,7 +1307,7 @@
" white-space: pre-wrap;\n",
"}\n",
"\n",
- "shape: (5, 1)"
+ "shape: (5, 1)"
],
"text/plain": [
"shape: (5, 1)\n",
@@ -1316,11 +1316,11 @@
"│ --- │\n",
"│ u32 │\n",
"╞═════╡\n",
- "│ 11 │\n",
+ "│ 10 │\n",
+ "│ 5 │\n",
+ "│ 5 │\n",
+ "│ 10 │\n",
"│ 3 │\n",
- "│ 8 │\n",
- "│ 11 │\n",
- "│ 9 │\n",
"└─────┘"
]
},
@@ -1357,7 +1357,7 @@
" white-space: pre-wrap;\n",
"}\n",
"\n",
- "shape: (5, 1)"
+ "shape: (5, 1)"
],
"text/plain": [
"shape: (5, 1)\n",
@@ -1366,11 +1366,11 @@
"│ --- │\n",
"│ u32 │\n",
"╞═════╡\n",
- "│ 2 │\n",
- "│ 65 │\n",
+ "│ 117 │\n",
"│ 18 │\n",
- "│ 105 │\n",
- "│ 1 │\n",
+ "│ 154 │\n",
+ "│ 355 │\n",
+ "│ 184 │\n",
"└─────┘"
]
},
@@ -1406,7 +1406,7 @@
" white-space: pre-wrap;\n",
"}\n",
"\n",
- "shape: (5, 7)id | val1 | val2 | val3 | r | rh | best friends |
---|
i64 | f64 | f64 | f64 | f64 | f64 | list[u64] |
0 | 0.946279 | 0.738426 | 0.362952 | 0.421689 | 1.876017 | [0, 958, … 313] |
1 | 0.741668 | 0.981829 | 0.883407 | 0.533584 | 6.247939 | [1, 460, … 906] |
2 | 0.758454 | 0.726879 | 0.96088 | 0.110808 | 4.053351 | [2, 568, … 834] |
3 | 0.96918 | 0.571118 | 0.91598 | 0.474751 | 4.305541 | [3, 641, … 82] |
4 | 0.432479 | 0.708105 | 0.345469 | 0.349741 | 1.646762 | [4, 379, … 389] |
"
+ "shape: (5, 7)id | val1 | val2 | val3 | r | rh | best friends |
---|
i64 | f64 | f64 | f64 | f64 | f64 | list[u64] |
0 | 0.628277 | 0.24721 | 0.823609 | 0.401848 | 4.325744 | [0, 629, … 857] |
1 | 0.359997 | 0.586155 | 0.179896 | 0.06172 | 4.930872 | [1, 399, … 635] |
2 | 0.974186 | 0.90477 | 0.223832 | 0.119878 | 9.149884 | [2, 902, … 53] |
3 | 0.15599 | 0.46002 | 0.756781 | 0.680912 | 6.204805 | [3, 391, … 898] |
4 | 0.254709 | 0.305821 | 0.035771 | 0.770254 | 7.867602 | [4, 94, … 816] |
"
],
"text/plain": [
"shape: (5, 7)\n",
@@ -1415,11 +1415,11 @@
"│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n",
"│ i64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ list[u64] │\n",
"╞═════╪══════════╪══════════╪══════════╪══════════╪══════════╪═════════════════╡\n",
- "│ 0 ┆ 0.946279 ┆ 0.738426 ┆ 0.362952 ┆ 0.421689 ┆ 1.876017 ┆ [0, 958, … 313] │\n",
- "│ 1 ┆ 0.741668 ┆ 0.981829 ┆ 0.883407 ┆ 0.533584 ┆ 6.247939 ┆ [1, 460, … 906] │\n",
- "│ 2 ┆ 0.758454 ┆ 0.726879 ┆ 0.96088 ┆ 0.110808 ┆ 4.053351 ┆ [2, 568, … 834] │\n",
- "│ 3 ┆ 0.96918 ┆ 0.571118 ┆ 0.91598 ┆ 0.474751 ┆ 4.305541 ┆ [3, 641, … 82] │\n",
- "│ 4 ┆ 0.432479 ┆ 0.708105 ┆ 0.345469 ┆ 0.349741 ┆ 1.646762 ┆ [4, 379, … 389] │\n",
+ "│ 0 ┆ 0.628277 ┆ 0.24721 ┆ 0.823609 ┆ 0.401848 ┆ 4.325744 ┆ [0, 629, … 857] │\n",
+ "│ 1 ┆ 0.359997 ┆ 0.586155 ┆ 0.179896 ┆ 0.06172 ┆ 4.930872 ┆ [1, 399, … 635] │\n",
+ "│ 2 ┆ 0.974186 ┆ 0.90477 ┆ 0.223832 ┆ 0.119878 ┆ 9.149884 ┆ [2, 902, … 53] │\n",
+ "│ 3 ┆ 0.15599 ┆ 0.46002 ┆ 0.756781 ┆ 0.680912 ┆ 6.204805 ┆ [3, 391, … 898] │\n",
+ "│ 4 ┆ 0.254709 ┆ 0.305821 ┆ 0.035771 ┆ 0.770254 ┆ 7.867602 ┆ [4, 94, … 816] │\n",
"└─────┴──────────┴──────────┴──────────┴──────────┴──────────┴─────────────────┘"
]
},
@@ -1457,7 +1457,7 @@
" white-space: pre-wrap;\n",
"}\n",
"\n",
- "shape: (5, 6)id | val1 | val2 | val3 | r | rh |
---|
i64 | f64 | f64 | f64 | f64 | f64 |
4 | 0.432479 | 0.708105 | 0.345469 | 0.349741 | 1.646762 |
6 | 0.617363 | 0.456153 | 0.625416 | 0.603103 | 4.263803 |
7 | 0.329858 | 0.303322 | 0.728363 | 0.853283 | 9.724125 |
15 | 0.607162 | 0.131546 | 0.57634 | 0.986423 | 3.408407 |
16 | 0.216038 | 0.581832 | 0.682663 | 0.169455 | 9.051951 |
"
+ "shape: (5, 6)id | val1 | val2 | val3 | r | rh |
---|
i64 | f64 | f64 | f64 | f64 | f64 |
0 | 0.628277 | 0.24721 | 0.823609 | 0.401848 | 4.325744 |
1 | 0.359997 | 0.586155 | 0.179896 | 0.06172 | 4.930872 |
3 | 0.15599 | 0.46002 | 0.756781 | 0.680912 | 6.204805 |
6 | 0.490409 | 0.513039 | 0.853946 | 0.568392 | 8.889839 |
8 | 0.728796 | 0.803011 | 0.33399 | 0.025543 | 1.634557 |
"
],
"text/plain": [
"shape: (5, 6)\n",
@@ -1466,11 +1466,11 @@
"│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n",
"│ i64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 │\n",
"╞═════╪══════════╪══════════╪══════════╪══════════╪══════════╡\n",
- "│ 4 ┆ 0.432479 ┆ 0.708105 ┆ 0.345469 ┆ 0.349741 ┆ 1.646762 │\n",
- "│ 6 ┆ 0.617363 ┆ 0.456153 ┆ 0.625416 ┆ 0.603103 ┆ 4.263803 │\n",
- "│ 7 ┆ 0.329858 ┆ 0.303322 ┆ 0.728363 ┆ 0.853283 ┆ 9.724125 │\n",
- "│ 15 ┆ 0.607162 ┆ 0.131546 ┆ 0.57634 ┆ 0.986423 ┆ 3.408407 │\n",
- "│ 16 ┆ 0.216038 ┆ 0.581832 ┆ 0.682663 ┆ 0.169455 ┆ 9.051951 │\n",
+ "│ 0 ┆ 0.628277 ┆ 0.24721 ┆ 0.823609 ┆ 0.401848 ┆ 4.325744 │\n",
+ "│ 1 ┆ 0.359997 ┆ 0.586155 ┆ 0.179896 ┆ 0.06172 ┆ 4.930872 │\n",
+ "│ 3 ┆ 0.15599 ┆ 0.46002 ┆ 0.756781 ┆ 0.680912 ┆ 6.204805 │\n",
+ "│ 6 ┆ 0.490409 ┆ 0.513039 ┆ 0.853946 ┆ 0.568392 ┆ 8.889839 │\n",
+ "│ 8 ┆ 0.728796 ┆ 0.803011 ┆ 0.33399 ┆ 0.025543 ┆ 1.634557 │\n",
"└─────┴──────────┴──────────┴──────────┴──────────┴──────────┘"
]
},
@@ -1507,7 +1507,7 @@
" white-space: pre-wrap;\n",
"}\n",
"\n",
- "shape: (5, 6)id | val1 | val2 | val3 | r | rh |
---|
i64 | f64 | f64 | f64 | f64 | f64 |
103 | 0.427545 | 0.537531 | 0.001032 | 0.252425 | 0.520078 |
135 | 0.505411 | 0.475068 | 0.63048 | 0.308946 | 7.205082 |
245 | 0.506767 | 0.557298 | 0.850589 | 0.949393 | 7.133072 |
261 | 0.534525 | 0.475023 | 0.691376 | 0.212086 | 8.101366 |
275 | 0.46973 | 0.526947 | 0.404464 | 0.120119 | 9.040323 |
"
+ "shape: (5, 6)id | val1 | val2 | val3 | r | rh |
---|
i64 | f64 | f64 | f64 | f64 | f64 |
6 | 0.490409 | 0.513039 | 0.853946 | 0.568392 | 8.889839 |
12 | 0.424308 | 0.494102 | 0.532104 | 0.352054 | 4.542275 |
45 | 0.502154 | 0.583782 | 0.14791 | 0.25895 | 7.132816 |
65 | 0.544417 | 0.502539 | 0.229982 | 0.349951 | 2.068083 |
74 | 0.496812 | 0.421964 | 0.761863 | 0.979747 | 1.770039 |
"
],
"text/plain": [
"shape: (5, 6)\n",
@@ -1516,11 +1516,11 @@
"│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n",
"│ i64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 │\n",
"╞═════╪══════════╪══════════╪══════════╪══════════╪══════════╡\n",
- "│ 103 ┆ 0.427545 ┆ 0.537531 ┆ 0.001032 ┆ 0.252425 ┆ 0.520078 │\n",
- "│ 135 ┆ 0.505411 ┆ 0.475068 ┆ 0.63048 ┆ 0.308946 ┆ 7.205082 │\n",
- "│ 245 ┆ 0.506767 ┆ 0.557298 ┆ 0.850589 ┆ 0.949393 ┆ 7.133072 │\n",
- "│ 261 ┆ 0.534525 ┆ 0.475023 ┆ 0.691376 ┆ 0.212086 ┆ 8.101366 │\n",
- "│ 275 ┆ 0.46973 ┆ 0.526947 ┆ 0.404464 ┆ 0.120119 ┆ 9.040323 │\n",
+ "│ 6 ┆ 0.490409 ┆ 0.513039 ┆ 0.853946 ┆ 0.568392 ┆ 8.889839 │\n",
+ "│ 12 ┆ 0.424308 ┆ 0.494102 ┆ 0.532104 ┆ 0.352054 ┆ 4.542275 │\n",
+ "│ 45 ┆ 0.502154 ┆ 0.583782 ┆ 0.14791 ┆ 0.25895 ┆ 7.132816 │\n",
+ "│ 65 ┆ 0.544417 ┆ 0.502539 ┆ 0.229982 ┆ 0.349951 ┆ 2.068083 │\n",
+ "│ 74 ┆ 0.496812 ┆ 0.421964 ┆ 0.761863 ┆ 0.979747 ┆ 1.770039 │\n",
"└─────┴──────────┴──────────┴──────────┴──────────┴──────────┘"
]
},
@@ -1557,7 +1557,7 @@
" white-space: pre-wrap;\n",
"}\n",
"\n",
- "shape: (5, 6)id | val1 | val2 | val3 | r | rh |
---|
i64 | f64 | f64 | f64 | f64 | f64 |
135 | 0.505411 | 0.475068 | 0.63048 | 0.308946 | 7.205082 |
245 | 0.506767 | 0.557298 | 0.850589 | 0.949393 | 7.133072 |
261 | 0.534525 | 0.475023 | 0.691376 | 0.212086 | 8.101366 |
275 | 0.46973 | 0.526947 | 0.404464 | 0.120119 | 9.040323 |
323 | 0.549024 | 0.453889 | 0.808333 | 0.470312 | 7.627118 |
"
+ "shape: (5, 6)id | val1 | val2 | val3 | r | rh |
---|
i64 | f64 | f64 | f64 | f64 | f64 |
6 | 0.490409 | 0.513039 | 0.853946 | 0.568392 | 8.889839 |
133 | 0.498075 | 0.517661 | 0.196447 | 0.756156 | 9.906723 |
164 | 0.496588 | 0.493433 | 0.39586 | 0.202915 | 1.926145 |
215 | 0.490266 | 0.520783 | 0.488036 | 0.521546 | 5.905562 |
248 | 0.469321 | 0.526669 | 0.117629 | 0.563881 | 9.793812 |
"
],
"text/plain": [
"shape: (5, 6)\n",
@@ -1566,11 +1566,11 @@
"│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n",
"│ i64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 │\n",
"╞═════╪══════════╪══════════╪══════════╪══════════╪══════════╡\n",
- "│ 135 ┆ 0.505411 ┆ 0.475068 ┆ 0.63048 ┆ 0.308946 ┆ 7.205082 │\n",
- "│ 245 ┆ 0.506767 ┆ 0.557298 ┆ 0.850589 ┆ 0.949393 ┆ 7.133072 │\n",
- "│ 261 ┆ 0.534525 ┆ 0.475023 ┆ 0.691376 ┆ 0.212086 ┆ 8.101366 │\n",
- "│ 275 ┆ 0.46973 ┆ 0.526947 ┆ 0.404464 ┆ 0.120119 ┆ 9.040323 │\n",
- "│ 323 ┆ 0.549024 ┆ 0.453889 ┆ 0.808333 ┆ 0.470312 ┆ 7.627118 │\n",
+ "│ 6 ┆ 0.490409 ┆ 0.513039 ┆ 0.853946 ┆ 0.568392 ┆ 8.889839 │\n",
+ "│ 133 ┆ 0.498075 ┆ 0.517661 ┆ 0.196447 ┆ 0.756156 ┆ 9.906723 │\n",
+ "│ 164 ┆ 0.496588 ┆ 0.493433 ┆ 0.39586 ┆ 0.202915 ┆ 1.926145 │\n",
+ "│ 215 ┆ 0.490266 ┆ 0.520783 ┆ 0.488036 ┆ 0.521546 ┆ 5.905562 │\n",
+ "│ 248 ┆ 0.469321 ┆ 0.526669 ┆ 0.117629 ┆ 0.563881 ┆ 9.793812 │\n",
"└─────┴──────────┴──────────┴──────────┴──────────┴──────────┘"
]
},
@@ -1583,7 +1583,8 @@
"df.filter(\n",
" pld.query_radius(\n",
" [0.5, 0.5],\n",
- " pl.col(\"val1\"), pl.col(\"val2\"), # Columns used as the coordinates in n-d space\n",
+ " # Columns used as the coordinates in n-d space\n",
+ " pl.col(\"val1\"), pl.col(\"val2\"), \n",
" # radius can also be an existing column in the dataframe.\n",
" radius = pl.col(\"rh\"), \n",
" dist = \"h\" \n",
diff --git a/mkdocs.yml b/mkdocs.yml
index 49d1445c..a6b2a7d0 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -9,6 +9,7 @@ nav:
- Numerical Extension: num.md
- Stats Extension: stats.md
- String Extension: str2.md
+- ML Metrics/Loss Extension: metrics.md
- Additional Expressions: polars_ds.md
theme:
diff --git a/python/polars_ds/__init__.py b/python/polars_ds/__init__.py
index b30b0db6..07715eb6 100644
--- a/python/polars_ds/__init__.py
+++ b/python/polars_ds/__init__.py
@@ -6,9 +6,10 @@
from polars_ds.complex import ComplexExt # noqa: E402
from polars_ds.str2 import StrExt # noqa: E402
from polars_ds.stats import StatsExt # noqa: E402
+from polars_ds.metrics import MetricExt # noqa: E402
version = "0.2.1"
-__all__ = ["NumExt", "StrExt", "StatsExt", "ComplexExt"]
+__all__ = ["NumExt", "StrExt", "StatsExt", "ComplexExt", "MetricExt"]
def query_radius(
diff --git a/python/polars_ds/metrics.py b/python/polars_ds/metrics.py
new file mode 100644
index 00000000..874280cf
--- /dev/null
+++ b/python/polars_ds/metrics.py
@@ -0,0 +1,297 @@
+import polars as pl
+# from typing import Union, Optional
+
+from polars.utils.udfs import _get_shared_lib_location
+
+_lib = _get_shared_lib_location(__file__)
+
+
+@pl.api.register_expr_namespace("metric")
+class MetricExt:
+
+ """
+ All the metrics/losses provided here is meant for model evaluation outside training,
+ e.g. for report generation, model performance monitoring, etc., not for actual use in ML models.
+ All metrics follow the convention by treating self as the actual column, and pred as the column
+ of predictions.
+
+ Polars Namespace: metric
+
+ Example: pl.col("a").metric.hubor_loss(pl.col("pred"), delta = 0.5)
+ """
+
+ def __init__(self, expr: pl.Expr):
+ self._expr: pl.Expr = expr
+
+ def hubor_loss(self, pred: pl.Expr, delta: float) -> pl.Expr:
+ """
+ Computes huber loss between this and the other expression. This assumes
+ this expression is actual, and the input is predicted, although the order
+ does not matter in this case.
+
+ Parameters
+ ----------
+ pred
+ An expression represeting the column with predicted probability.
+ """
+ temp = (self._expr - pred).abs()
+ return (
+ pl.when(temp <= delta).then(0.5 * temp.pow(2)).otherwise(delta * (temp - 0.5 * delta))
+ / self._expr.count()
+ )
+
+ def l1_loss(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
+ """
+ Computes L1 loss (absolute difference) between this and the other `pred` expression.
+
+ Parameters
+ ----------
+ pred
+ An expression represeting the column with predicted probability.
+ normalize
+ If true, divide the result by length of the series
+ """
+ temp = (self._expr - pred).abs().sum()
+ if normalize:
+ return temp / self._expr.count()
+ return temp
+
+ def l2_loss(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
+ """
+ Computes L2 loss (normalized L2 distance) between this and the other `pred` expression. This
+ is the norm without 1/p power.
+
+ Parameters
+ ----------
+ pred
+ An expression represeting the column with predicted probability.
+ normalize
+ If true, divide the result by length of the series
+ """
+ temp = self._expr - pred
+ temp = temp.dot(temp)
+ if normalize:
+ return temp / self._expr.count()
+ return temp
+
+ def msle(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
+ """
+ Computes the mean square log error between this and the other `pred` expression.
+
+ Parameters
+ ----------
+ pred
+ An expression represeting the column with predicted probability.
+ normalize
+ If true, divide the result by length of the series
+ """
+ diff = self._expr.log1p() - pred.log1p()
+ out = diff.dot(diff)
+ if normalize:
+ return out / self._expr.count()
+ return out
+
+ def chebyshev_loss(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
+ """
+ Alias for l_inf_loss.
+ """
+ return self.l_inf_dist(pred, normalize)
+
+ def l_inf_loss(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
+ """
+ Computes L^infinity loss between this and the other `pred` expression
+
+ Parameters
+ ----------
+ pred
+ An expression represeting the column with predicted probability.
+ normalize
+ If true, divide the result by length of the series
+ """
+ temp = self._expr - pred
+ out = pl.max_horizontal(temp.min().abs(), temp.max().abs())
+ if normalize:
+ return out / self._expr.count()
+ return out
+
+ def mape(self, pred: pl.Expr, weighted: bool = False) -> pl.Expr:
+ """
+ Computes mean absolute percentage error between self and the other `pred` expression.
+ If weighted, it will compute the weighted version as defined here:
+
+ https://en.wikipedia.org/wiki/Mean_absolute_percentage_error
+
+ Parameters
+ ----------
+ pred
+ An expression represeting the column with predicted probability.
+ weighted
+ If true, computes wMAPE in the wikipedia article
+ """
+ if weighted:
+ return (self._expr - pred).abs().sum() / self._expr.abs().sum()
+ else:
+ return (1 - pred / self._expr).abs().mean()
+
+ def smape(self, pred: pl.Expr) -> pl.Expr:
+ """
+ Computes symmetric mean absolute percentage error between self and other `pred` expression.
+ The value is always between 0 and 1. This is the third version in the wikipedia without
+ the 100 factor.
+
+ https://en.wikipedia.org/wiki/Symmetric_mean_absolute_percentage_error
+
+ Parameters
+ ----------
+ pred
+ A Polars expression representing predictions
+ """
+ numerator = (self._expr - pred).abs()
+ denominator = 1.0 / (self._expr.abs() + pred.abs())
+ return (1.0 / self._expr.count()) * numerator.dot(denominator)
+
+ def log_loss(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
+ """
+ Computes log loss, aka binary cross entropy loss, between self and other `pred` expression.
+
+ Parameters
+ ----------
+ pred
+ An expression represeting the column with predicted probability.
+ normalize
+ Whether to divide by N.
+ """
+ out = self._expr.dot(pred.log()) + (1 - self._expr).dot((1 - pred).log())
+ if normalize:
+ return -(out / self._expr.count())
+ return -out
+
+ def pinball_loss(self, pred: pl.Expr, tau: float = 0.5) -> pl.Expr:
+ """
+ This loss yields an estimator of the tau conditional quantile in quantile regression models.
+ This will treat self as y_true.
+
+ Parameters
+ ----------
+ pred
+ An expression represeting the column which is the prediction.
+ tau
+ A float in [0,1] represeting the conditional quantile level
+ """
+ return pl.max_horizontal(tau * (self._expr - pred), (tau - 1) * (self._expr - pred))
+
+ def bce(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
+ """
+ Binary cross entropy. Alias for log_loss.
+ """
+ return self.log_loss(pred, normalize)
+
+ def categorical_cross_entropy(
+ self, pred: pl.Expr, normalize: bool = True, dense: bool = True
+ ) -> pl.Expr:
+ """
+ Returns the categorical cross entropy. If you want to avoid numerical error due to log, please
+ set pred = pred + epsilon.
+
+ Parameters
+ ----------
+ pred
+ An expression represeting the predicted probabilities for the classes
+ normalize
+ Whether to divide by N.
+ dense
+ If true, self has to be a dense vector (a single number for each row). If false, self has to be
+ a column of lists with only one 1 and 0s otherwise.
+ """
+ if dense:
+ y_prob = pred.list.get(self._expr)
+ else:
+ y_prob = pred.list.get(self._expr.list.arg_max())
+ if normalize:
+ return -y_prob.log().sum() / self._expr.count()
+ return -y_prob.log().sum()
+
+ def kl_divergence(self, pred: pl.Expr) -> pl.Expr:
+ """
+ Computes the discrete KL Divergence.
+
+ Parameters
+ ----------
+ pred
+ An expression represeting the predicted probabilities for the classes
+
+ Reference
+ ---------
+ https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
+ """
+ return self._expr * (self._expr / pred).log()
+
+ def log_cosh(self, pred: pl.Expr) -> pl.Expr:
+ """
+ Computes log cosh of the the prediction error (pred - self (y_true))
+ """
+ return (pred - self._expr).cosh().log()
+
+ def roc_auc(self, pred: pl.Expr) -> pl.Expr:
+ """
+ Computes ROC AUC using self as actual and pred as predictions.
+
+ Self must be binary and castable to type UInt32. If self is not all 0s and 1s or not binary,
+ the result will not make sense, or some error may occur.
+
+ Parameters
+ ----------
+ pred
+ An expression represeting the column with predicted probability.
+ """
+ y = self._expr.cast(pl.UInt32)
+ return y.register_plugin(
+ lib=_lib,
+ symbol="pl_roc_auc",
+ args=[pred],
+ is_elementwise=False,
+ returns_scalar=True,
+ )
+
+ def gini(self, pred: pl.Expr) -> pl.Expr:
+ """
+ Computes the Gini coefficient. This is 2 * AUC - 1.
+
+ Self must be binary and castable to type UInt32. If self is not all 0s and 1s or not binary,
+ the result will not make sense, or some error may occur.
+
+ Parameters
+ ----------
+ pred
+ An expression represeting the column with predicted probability.
+ """
+ return self.roc_auc(pred) * 2 - 1
+
+ def binary_metrics_combo(self, pred: pl.Expr, threshold: float = 0.5) -> pl.Expr:
+ """
+ Computes the following binary classificaition metrics using self as actual and pred as predictions:
+ precision, recall, f, average_precision and roc_auc. The return will be a struct with values
+ having the names as given here.
+
+ Self must be binary and castable to type UInt32. If self is not all 0s and 1s,
+ the result will not make sense, or some error may occur.
+
+ Average precision is computed using Sum (R_n - R_n-1)*P_n-1, which is not the textbook definition,
+ but is consistent with Scikit-learn. For more information, see
+ https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html
+
+ Parameters
+ ----------
+ pred
+ An expression represeting the column with predicted probability.
+ threshold
+ The threshold used to compute precision, recall and f (f score).
+ """
+ y = self._expr.cast(pl.UInt32)
+ return y.register_plugin(
+ lib=_lib,
+ symbol="pl_combo_b",
+ args=[pred, pl.lit(threshold, dtype=pl.Float64)],
+ is_elementwise=False,
+ returns_scalar=True,
+ )
diff --git a/python/polars_ds/num.py b/python/polars_ds/num.py
index b4901816..30240627 100644
--- a/python/polars_ds/num.py
+++ b/python/polars_ds/num.py
@@ -12,6 +12,8 @@ class NumExt:
"""
This class contains tools for dealing with well-known numerical operations and other metrics inside Polars DataFrame.
+ All the metrics/losses provided here is meant for use in cases like evaluating models outside training,
+ not for actual use in ML models.
Polars Namespace: num
@@ -183,209 +185,6 @@ def is_equidistant(self, tol: float = 1e-6) -> pl.Expr:
"""
return (self._expr.diff(null_behavior="drop").abs() <= tol).all()
- def hubor_loss(self, pred: pl.Expr, delta: float) -> pl.Expr:
- """
- Computes huber loss between this and the other expression. This assumes
- this expression is actual, and the input is predicted, although the order
- does not matter in this case.
-
- Parameters
- ----------
- pred
- An expression represeting the column with predicted probability.
- """
- temp = (self._expr - pred).abs()
- return (
- pl.when(temp <= delta).then(0.5 * temp.pow(2)).otherwise(delta * (temp - 0.5 * delta))
- / self._expr.count()
- )
-
- def mad(self, pred: pl.Expr) -> pl.Expr:
- """Computes mean absolute deivation between this and the other `pred` expression."""
- return (self._expr - pred).abs().mean()
-
- def l1_loss(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
- """
- Computes L1 loss (absolute difference) between this and the other `pred` expression.
-
- Parameters
- ----------
- pred
- An expression represeting the column with predicted probability.
- normalize
- If true, divide the result by length of the series
- """
- temp = (self._expr - pred).abs().sum()
- if normalize:
- return temp / self._expr.count()
- return temp
-
- def l2_loss(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
- """
- Computes L2 loss (normalized L2 distance) between this and the other `pred` expression. This
- is the norm without 1/p power.
-
- Parameters
- ----------
- pred
- An expression represeting the column with predicted probability.
- normalize
- If true, divide the result by length of the series
- """
- temp = self._expr - pred
- temp = temp.dot(temp)
- if normalize:
- return temp / self._expr.count()
- return temp
-
- def msle(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
- """
- Computes the mean square log error between this and the other `pred` expression.
-
- Parameters
- ----------
- pred
- An expression represeting the column with predicted probability.
- normalize
- If true, divide the result by length of the series
- """
- diff = self._expr.log1p() - pred.log1p()
- out = diff.dot(diff)
- if normalize:
- return out / self._expr.count()
- return out
-
- def chebyshev_loss(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
- """
- Alias for l_inf_loss.
- """
- return self.l_inf_dist(pred, normalize)
-
- def l_inf_loss(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
- """
- Computes L^infinity loss between this and the other `pred` expression
-
- Parameters
- ----------
- pred
- An expression represeting the column with predicted probability.
- normalize
- If true, divide the result by length of the series
- """
- temp = self._expr - pred
- out = pl.max_horizontal(temp.min().abs(), temp.max().abs())
- if normalize:
- return out / self._expr.count()
- return out
-
- def mape(self, pred: pl.Expr, weighted: bool = False) -> pl.Expr:
- """
- Computes mean absolute percentage error between self and the other `pred` expression.
- If weighted, it will compute the weighted version as defined here:
-
- https://en.wikipedia.org/wiki/Mean_absolute_percentage_error
-
- Parameters
- ----------
- pred
- An expression represeting the column with predicted probability.
- weighted
- If true, computes wMAPE in the wikipedia article
- """
- if weighted:
- return (self._expr - pred).abs().sum() / self._expr.abs().sum()
- else:
- return (1 - pred / self._expr).abs().mean()
-
- def smape(self, pred: pl.Expr) -> pl.Expr:
- """
- Computes symmetric mean absolute percentage error between self and other `pred` expression.
- The value is always between 0 and 1. This is the third version in the wikipedia without
- the 100 factor.
-
- https://en.wikipedia.org/wiki/Symmetric_mean_absolute_percentage_error
-
- Parameters
- ----------
- pred
- A Polars expression representing predictions
- """
- numerator = (self._expr - pred).abs()
- denominator = 1.0 / (self._expr.abs() + pred.abs())
- return (1.0 / self._expr.count()) * numerator.dot(denominator)
-
- def log_loss(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
- """
- Computes log loss, aka binary cross entropy loss, between self and other `pred` expression.
-
- Parameters
- ----------
- pred
- An expression represeting the column with predicted probability.
- normalize
- Whether to divide by N.
- """
- out = self._expr.dot(pred.log()) + (1 - self._expr).dot((1 - pred).log())
- if normalize:
- return -(out / self._expr.count())
- return -out
-
- def bce(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
- """
- Binary cross entropy. Alias for log_loss.
- """
- return self.log_loss(pred, normalize)
-
- def roc_auc(self, pred: pl.Expr) -> pl.Expr:
- """
- Computes ROC AUC using self as actual and pred as predictions.
-
- Self must be binary and castable to type UInt32. If self is not all 0s and 1s or not binary,
- the result will not make sense, or some error may occur.
-
- Parameters
- ----------
- pred
- An expression represeting the column with predicted probability.
- """
- y = self._expr.cast(pl.UInt32)
- return y.register_plugin(
- lib=_lib,
- symbol="pl_roc_auc",
- args=[pred],
- is_elementwise=False,
- returns_scalar=True,
- )
-
- def binary_metrics_combo(self, pred: pl.Expr, threshold: float = 0.5) -> pl.Expr:
- """
- Computes the following binary classificaition metrics using self as actual and pred as predictions:
- precision, recall, f, average_precision and roc_auc. The return will be a struct with values
- having the names as given here.
-
- Self must be binary and castable to type UInt32. If self is not all 0s and 1s,
- the result will not make sense, or some error may occur.
-
- Average precision is computed using Sum (R_n - R_n-1)*P_n-1, which is not the textbook definition,
- but is consistent with Scikit-learn. For more information, see
- https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html
-
- Parameters
- ----------
- pred
- An expression represeting the column with predicted probability.
- threshold
- The threshold used to compute precision, recall and f (f score).
- """
- y = self._expr.cast(pl.UInt32)
- return y.register_plugin(
- lib=_lib,
- symbol="pl_combo_b",
- args=[pred, pl.lit(threshold, dtype=pl.Float64)],
- is_elementwise=False,
- returns_scalar=True,
- )
-
def trapz(self, x: Union[float, pl.Expr]) -> pl.Expr:
"""
Treats self as y axis, integrates along x using the trapezoidal rule. If x is not a single
diff --git a/tests/test.ipynb b/tests/test.ipynb
index 60b7b417..aa0b28b7 100644
--- a/tests/test.ipynb
+++ b/tests/test.ipynb
@@ -2,27 +2,60 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"id": "529f4422-5c3a-4bd6-abe0-a15edfc62abb",
"metadata": {},
"outputs": [],
"source": [
"import polars as pl\n",
"import numpy as np\n",
- "# import polars_ds as pld"
+ "import polars_ds as pld"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "f0aef69b",
+ "id": "8d9720ab",
"metadata": {},
"outputs": [],
"source": [
"df = pl.DataFrame({\n",
- " \n",
+ " \"y\": [[1,0,0],[0,1,0],[0,0,1],[1,0,0],[0,1,0]],\n",
+ " \"pred\":[[0.1, 0.5, 0.4], [0.2, 0.6, 0.2], [0.4, 0.1, 0.5], [0.9, 0.05, 0.05], [0.2, 0.5, 0.3]]\n",
"})"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "92b14f40",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df.select(\n",
+ " pl.col(\"pred\").list.get(pl.col(\"y\"))\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f0aef69b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df.select(\n",
+ " pl.col(\"y\").metric.categorical_cross_entropy(pl.col(\"pred\"), normalize=True, dense = False)\n",
+ ").item(0,0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "209c3e1c",
+ "metadata": {},
+ "outputs": [],
+ "source": []
}
],
"metadata": {
diff --git a/tests/test_ext.py b/tests/test_ext.py
index 604a6108..82b3b038 100644
--- a/tests/test_ext.py
+++ b/tests/test_ext.py
@@ -111,6 +111,39 @@ def test_cond_entropy(df, res):
assert_frame_equal(df.lazy().select(pl.col("y").num.cond_entropy(pl.col("a"))).collect(), res)
+@pytest.mark.parametrize(
+    "df, res",
+    [
+        (
+            pl.DataFrame(
+                {
+                    "y": [0, 1, 2, 0, 1],
+                    "pred": [
+                        [0.1, 0.5, 0.4],
+                        [0.2, 0.6, 0.2],
+                        [0.4, 0.1, 0.5],
+                        [0.9, 0.05, 0.05],
+                        [0.2, 0.5, 0.3],
+                    ],
+                }
+            ),
+            pl.DataFrame({"a": [0.8610131187075506]}),
+        ),
+    ],
+)
+def test_cross_entropy(df, res):
+    """Check `metric.categorical_cross_entropy` with dense class labels and default
+    normalization against a precomputed value, in both eager and lazy mode."""
+    assert_frame_equal(
+        df.select(pl.col("y").metric.categorical_cross_entropy(pl.col("pred")).alias("a")), res
+    )
+
+    assert_frame_equal(
+        df.lazy()
+        .select(pl.col("y").metric.categorical_cross_entropy(pl.col("pred")).alias("a"))
+        .collect(),
+        res,
+    )
+
+
@pytest.mark.parametrize(
"df, res",
[
@@ -784,7 +817,9 @@ def test_precision_recall_roc_auc():
)
for threshold in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
res = df.select(
- pl.col("y").num.binary_metrics_combo(pl.col("a"), threshold=threshold).alias("metrics")
+ pl.col("y")
+ .metric.binary_metrics_combo(pl.col("a"), threshold=threshold)
+ .alias("metrics")
).unnest("metrics")
precision_res = res.get_column("precision")[0]
recall_res = res.get_column("recall")[0]